Lutz Roeder 6 лет назад
Родитель
Сommit
5bf5a65733
3 измененных файлов с 1454 добавлено и 446 удалено
  1. 1 1
      setup.py
  2. 1415 0
      src/python.js
  3. 38 445
      src/torchscript.js

+ 1 - 1
setup.py

@@ -103,7 +103,7 @@ setuptools.setup(
             'tf.js', 'tf-metadata.json', 'tf-proto.js', 
             'tflite.js', 'tflite-metadata.json', 'tflite-schema.js', 
             'torch.js', 'torch-metadata.json',
-            'torchscript.js', 'torchscript-metadata.json',
+            'torchscript.js', 'torchscript-metadata.json', 'python.js',
             'index.html', 'index.js',
             'view-grapher.css', 'view-grapher.js',
             'view-sidebar.css', 'view-sidebar.js',

Разница между файлами не показана из-за своего большого размера
+ 1415 - 0
src/python.js


+ 38 - 445
src/torchscript.js

@@ -24,27 +24,29 @@ torchscript.ModelFactory = class {
     }
 
     open(context, host) {
-        var identifier = context.identifier;
-        try {
-            var container = torchscript.ModelFactory._openContainer(context.buffer);
-            return torchscript.Metadata.open(host).then((metadata) => {
-                try {
-                    return new torchscript.Model(metadata, container);
-                }
-                catch (error) {
-                    host.exception(error, false);
-                    var message = error && error.message ? error.message : error.toString();
-                    message = message.endsWith('.') ? message.substring(0, message.length - 1) : message;
-                    throw new torchscript.Error(message + " in '" + identifier + "'.");
-                }
-            });
-        }
-        catch (error) {
-            host.exception(error, false);
-            var message = error && error.message ? error.message : error.toString();
-            message = message.endsWith('.') ? message.substring(0, message.length - 1) : message;
-            return Promise.reject(new torchscript.Error(message + " in '" + identifier + "'."));
-        }
+        return host.require('./python').then((python) => {
+            var identifier = context.identifier;
+            try {
+                var container = torchscript.ModelFactory._openContainer(context.buffer);
+                return torchscript.Metadata.open(host).then((metadata) => {
+                    try {
+                        return new torchscript.Model(metadata, python, container);
+                    }
+                    catch (error) {
+                        host.exception(error, false);
+                        var message = error && error.message ? error.message : error.toString();
+                        message = message.endsWith('.') ? message.substring(0, message.length - 1) : message;
+                        throw new torchscript.Error(message + " in '" + identifier + "'.");
+                    }
+                });
+            }
+            catch (error) {
+                host.exception(error, false);
+                var message = error && error.message ? error.message : error.toString();
+                message = message.endsWith('.') ? message.substring(0, message.length - 1) : message;
+                return Promise.reject(new torchscript.Error(message + " in '" + identifier + "'."));
+            }
+        });
     }
 
     static _openContainer(buffer) {
@@ -67,7 +69,7 @@ torchscript.ModelFactory = class {
 
 torchscript.Model = class { 
 
-    constructor(metadata, container) {
+    constructor(metadata, python, container) {
         var textDecoder = new TextDecoder('utf-8');
         var model = JSON.parse(textDecoder.decode(container.model.data));
         var version = JSON.parse(textDecoder.decode(container.version.data));
@@ -79,7 +81,7 @@ torchscript.Model = class {
             }
         }
         this._graphs = [];
-        this._graphs.push(new torchscript.Graph(metadata, container, model.mainModule, model.tensors));
+        this._graphs.push(new torchscript.Graph(metadata, python, container, model.mainModule, model.tensors));
     }
 
     get format() {
@@ -97,7 +99,7 @@ torchscript.Model = class {
 
 torchscript.Graph = class {
 
-    constructor(metadata, container, mainModule, tensors) {
+    constructor(metadata, python, container, mainModule, tensors) {
         this._name = mainModule.name;
         this._inputs = [];
         this._outputs = [];
@@ -105,7 +107,7 @@ torchscript.Graph = class {
 
         container.tensors = tensors.map((tensor) => new torchscript.Tensor(tensor, container));
 
-        var context = new torchscript.GraphContext(container, mainModule);
+        var context = new torchscript.GraphContext(container, python, mainModule);
 
         container.parameters = {};
         var queue = [ mainModule ];
@@ -787,7 +789,7 @@ torchscript.Metadata = class {
 
 torchscript.GraphContext = class {
 
-    constructor(container, mainModule) {
+    constructor(container, python, mainModule) {
 
         this._container = container;
         this._mainModule = mainModule;
@@ -807,11 +809,11 @@ torchscript.GraphContext = class {
                 var codeEntry = codeEntries[0];
                 var textDecoder = new TextDecoder('utf-8');
                 var code = textDecoder.decode(codeEntry.data);
-                var reader = new torchscript.PythonReader(code);
-                var statements = reader.statements();
-                var method = statements.find((statement) => statement.type == 'def' && statement.name == 'forward');
+                var reader = new python.Parser(code);
+                var program = reader.parse();
+                var method = program.body.find((statement) => statement.type == 'def' && statement.name == 'forward');
                 if (method) {
-                    this._body = method.body;
+                    this._body = method.body.statements;
                     var methodParameters = method.parameters;
                     if (methodParameters.length > 0 && methodParameters[0].name == 'self') {
                         methodParameters.shift();
@@ -876,7 +878,7 @@ torchscript.GraphContext = class {
             if (this._body.length > 0) {
                 var statement = this._body[0];
                 if (statement.expression.type == 'identifier' && statement.expression.value == parameter.name) {
-                    if (statement.type === '=' && statement.target.type === 'identifier_list') {
+                    if (statement.type === '=' && statement.target.type === 'tuple') {
                         for (var input of statement.target.value) {
                             if (input) {
                                 this._inputs.push(input.value);
@@ -925,7 +927,7 @@ torchscript.GraphContext = class {
     }
 
     _nodeExpression(expression, target) {
-        if (expression.type == 'call' && (target.type == 'identifier' || target.type == 'identifier_list')) {
+        if (expression.type == 'call' && (target.type == 'identifier' || target.type == 'tuple')) {
             var name = this._name(expression.target);
             var namespace = 'torch.';
             if (name.startsWith(namespace)) {
@@ -943,6 +945,9 @@ torchscript.GraphContext = class {
                         delete this._argumentMap[argument.value];
                     }
                     if (argument.type == 'identifier') {
+                        if (argument.value === 'False' || argument.value === 'True') {
+                            break;
+                        }
                         node.inputs.push([ { id: argument.value } ]);
                         args.shift();
                         continue;
@@ -1021,7 +1026,7 @@ torchscript.GraphContext = class {
                 if (target.type == 'identifier') {
                     node.outputs.push(target.value);
                 }
-                if (target.type == 'identifier_list') {
+                if (target.type == 'tuple') {
                     for (var identifier of target.value) {
                         node.outputs.push(identifier.value);
                     }
@@ -1234,418 +1239,6 @@ torchscript.GraphContext = class {
     }
 }
 
-torchscript.PythonReader = class {
-
-    constructor(text) {
-        this._text = text;
-        this._position = 0;
-        this._lineEnd = -1;
-        this._lineStart = 0;
-        this._line = -1;
-        this._indentation = [];
-    }
-
-    whitespace() {
-        for (;;) {
-            while (this._position > this._lineEnd) {
-                this._lineStart = this._lineEnd + 1;
-                this._position = this._lineStart;
-                if (this._position >= this._text.length) {
-                    return false;
-                }
-                this._lineEnd = this._text.indexOf("\n", this._position);
-                if (this._lineEnd === -1) {
-                    this._lineEnd = this._text.length;
-                }
-                this._line++;
-            }
-            var c = this._text[this._position];
-            switch (c) {
-                case " ":
-                case "\r":
-                case "\t":
-                    this._position++;
-                    break;
-                case "#":
-                    this._position = this._lineEnd;
-                    break;
-                default:
-                    return true;
-            }
-        }
-    }
-    
-    tokenize() {
-        if (!this.whitespace()) {
-            this._token = { type: 'eof', value: "" };
-            return this._token;
-        }
-        var c = this._text[this._position];
-        if (c == '\n') {
-            this._token = { type: 'newline', value: c };
-            return this._token;
-        }
-        if (c === '=' || c === '(' || c === ')' || c === ":" || c === "," || c === '[' || c === ']') {
-            this._token = { type: 'separator', value: c };
-            return this._token;
-        }
-        var position = this._position + 1;
-        if (c >= "a" && c <= "z" || c >= "A" && c <= "Z" || c === "_") {
-            while (position < this._lineEnd) {
-                c = this._text[position];
-                if (c >= "a" && c <= "z" || c >= "A" && c <= "Z" || c >= "0" && c <= "9" || c === "_" || c === "+" || c === "-") {
-                    position++;
-                    continue;
-                }
-                break;
-            }
-            var identifier = this._text.substring(this._position, position);
-            if (identifier == 'True' || identifier == 'False') {
-                this._token = { type: 'boolean', value: identifier };
-            }
-            else {
-                this._token = { type: 'identifier', value: identifier };
-            }
-            return this._token;
-        }
-        if (c === "-") {
-            if (position < this._lineEnd) {
-                if (this._text[position] === '>') {
-                    position++;
-                    this._token = { type: 'arrow', value: '->' };
-                    return this._token;
-                }
-            }
-        }
-        if (c >= "0" && c <= "9" || c === "-" || c === "+") {
-            while (position < this._lineEnd) {
-                c = this._text[position];
-                if (c >= "a" && c <= "z" || c >= "A" && c <= "Z" || c >= "0" && c <= "9" || c === "_" || c === "+" || c === "-" || c === ".") {
-                    position++;
-                    continue;
-                }
-                break;
-            }
-            this._token = { type: 'number', value: this._text.substring(this._position, position) };
-            return this._token;
-        }
-        if (c === "\"" || c === "'") {
-            var quote = c;
-            while (position < this._lineEnd) {
-                c = this._text[position];
-                if (c === "\\" && position < this._lineEnd) {
-                    position += 2;
-                    continue;
-                }
-                position++;
-                if (c === quote) {
-                    break;
-                }
-            }
-            this._token = { type: 'string', value: this._text.substring(this._position, position) };
-            return this._token;
-        }
-        if (c === '.') {
-            this._token = { type: 'dot', value: c };
-            return this._token;
-        }
-        throw new torchscript.Error("Unexpected token '" + c + "'" + this.location());
-    }
-
-    peek() {
-        if (!this._cache) {
-            this._token = this.tokenize();
-            this._cache = true;
-        }
-        return this._token;
-    }
-    
-    read() {
-        if (!this._cache) {
-            this._token = this.tokenize();
-        }
-        this._position += this._token.value.length;
-        this._cache = false;
-        return this._token;
-    }
-
-    match(value) {
-        if (this.peek().value === value) {
-            this.read();
-            return true;
-        }
-        return false;
-    }
-
-    expect(value) {
-        var token = this.read();
-        if (token.value !== value) {
-            throw new torchscript.Error("Unexpected '" + token + "' instead of '" + value + "'" + this.location());
-        }
-    }
-
-    location() {
-        return " at " + (this._line + 1).toString() + ":" + (this._position - this._lineStart + 1).toString();
-    }
-
-    letter(c) {
-        return (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z');
-    }
-
-    number(c) {
-        return c >= '0' && c <= '9';
-    }
-
-    identifier() {
-        var token = this.peek();
-        if (token.type == 'identifier') {
-            this.read();
-            return token;
-        }
-        return null;
-    }
-
-    literal() {
-        var token = this.peek();
-        if (token.type == 'string' || token.type == 'number' || token.type == 'boolean') {
-            this.read();
-            return token;
-        }
-        return null;
-    }
-
-    typeArguments() {
-        var list = [];
-        this.expect('[');
-        while (!this.match(']')) {
-            var type = this.type();
-            if (type == null) {
-                throw new torchscript.Error('Expected type ' + this.location());
-            }
-            list.push(type);
-            if (!this.match(',')) {
-                this.expect(']');
-                break;
-            }
-        }
-        return list;
-    }
-
-    type() {
-        var identifier = this.identifier();
-        if (identifier) {
-            var type = { type: 'type', value: identifier.value };
-            if (this.peek().value === '[') {
-                type.arguments = this.typeArguments();
-            }
-            return type;
-        }
-        return null;
-    }
-
-    parameter() {
-        var identifier = this.identifier();
-        if (identifier != null) {
-            var parameterType = null
-            if (this.match(':')) {
-                parameterType = this.type();
-            }
-            return { type: 'parameter', name: identifier.value, parameterType: parameterType };
-        }
-        return null;
-    }
-
-    parameters() {
-        var list = [];
-        this.expect('(');
-        while (!this.match(')')) {
-            this.match('\n');
-            list.push(this.parameter());
-            this.match('\n');
-            if (!this.match(',')) {
-                this.expect(')');
-                break;
-            }
-        }
-        return list;
-    }
-
-    arguments() {
-        var list = [];
-        this.expect('(');
-        while (!this.match(')')) {
-            var expression = this.expression();
-            if (expression == null) {
-                throw new torchscript.Error('Expected expression ' + this.location());
-            }
-            list.push(expression);
-            if (!this.match(',')) {
-                this.expect(')');
-                break;
-            }
-        }
-        return list;
-    }
-
-    expression() {
-        var stack = [];
-        for (;;) {
-            var identifier = this.identifier();
-            if (identifier) {
-                stack.push(identifier);
-                continue;
-            }
-            var literal = this.literal();
-            if (literal) {
-                stack.push(literal);
-                continue;
-            }
-            if (this.match('.')) {
-                stack.push({
-                    type: '.',
-                    target: stack.pop(),
-                    member: this.identifier(),
-                });
-                continue;
-            }
-            if (this.peek().value === '(') {
-                if (stack.length == 0) {
-                    stack.push({ type: 'tuple', value: this.arguments() });
-                }
-                else {
-                    stack.push({ type: 'call', target: stack.pop(), arguments: this.arguments() });
-                }
-                continue;
-            }
-            if (this.peek().value === '[') {
-                stack.push({ type: 'list', value: this.expressions() });
-                continue;
-            }
-            if (this.match('=')) {
-                stack.push({ type: '=', target: stack.pop(), expression: this.expression() });
-                continue;
-            }
-            break;
-        }
-
-        if (stack.length == 1) {
-            return stack.pop();
-        }
-        if (stack.length != 0) {
-            throw new torchscript.Error('Unexpected expression ' + this.location());
-        }
-        return null;
-    }
-
-    expressions() {
-        var list = [];
-        this.expect('[');
-        while (!this.match(']')) {
-            var expression = this.expression();
-            if (expression == null) {
-                throw new torchscript.Error('Expected expression ' + this.location());
-            }
-            list.push(expression);
-            if (!this.match(',')) {
-                this.expect(']');
-                break;
-            }
-        }
-        return list;
-    }
-
-    statement() {
-        var stack = [];
-        while (this.peek().type !== 'eof') {
-
-            if (this.match('def')) {
-                var node = { type: 'def' };
-                node.name = this.identifier().value;
-                node.parameters = this.parameters();
-                if (this.match('->')) {
-                    node.returnType = this.type();
-                }
-                this.expect(':');
-                this.expect('\n');
-                var position = this._position;
-                while (this.match('\n')) {
-                    position = this._position;
-                }
-                this.peek();
-                this._indentation.push(this._text.substring(position, this._position));
-                this._position = position;
-                node.body = this.statements();
-                this._indentation.pop();
-                stack.push(node);
-                break;
-            }
-
-            if (this.match('return')) {
-                stack.push({ type: 'return', expression: this.expression() });
-                break; 
-            }
-
-            var expression = this.expression();
-            if (expression) {
-                if (expression.type == 'identifier') {
-                    if (this.peek().value === ',') {
-                        var list = [ expression ];
-                        while (this.match(',')) {
-                            var identifier = this.identifier();
-                            if (!identifier) {
-                                if (this.peek().value != '=') {
-                                    throw new torchscript.Error('Expected identifier' + this.location());
-                                }
-                            }
-                            list.push(identifier);
-                        }
-                        expression = { type: 'identifier_list', value: list };
-                        if (this.match('=')) {
-                            expression = { type: '=', target: expression, expression: this.expression() };
-                        }
-                    }
-                }
-                if (expression.type == '=') {
-                    stack.push(expression);
-                    this.match('\n');
-                    break;
-                }
-                throw new torchscript.Error('Unhandled expression ' + this.location);
-            }
-
-            if (this.match('\n')) {
-                break;
-            }
-        }
-
-        if (stack.length == 1) {
-            return stack.pop();
-        }
-        if (stack.length != 0) {
-            throw new torchscript.Error('Unexpected statement ' + this.location());
-        }
-        return null;
-    }
-
-    statements() {
-        var indentation = this._indentation.join('');
-        var stack = [];
-        while (this._position < this._text.length) {
-            if (this._text.substring(this._position, this._position + indentation.length) !== indentation) {
-                return stack;
-            }
-            this._position = this._position + indentation.length;
-
-            var statement = this.statement();
-            if (statement) { 
-                stack.push(statement);
-                continue;
-            }
-        }
-        return stack;
-    }
-}
-
 torchscript.Error = class extends Error {
     constructor(message) {
         super(message);

Некоторые файлы не были показаны из-за большого количества измененных файлов