|
|
@@ -24,27 +24,29 @@ torchscript.ModelFactory = class {
|
|
|
}
|
|
|
|
|
|
open(context, host) {
|
|
|
- var identifier = context.identifier;
|
|
|
- try {
|
|
|
- var container = torchscript.ModelFactory._openContainer(context.buffer);
|
|
|
- return torchscript.Metadata.open(host).then((metadata) => {
|
|
|
- try {
|
|
|
- return new torchscript.Model(metadata, container);
|
|
|
- }
|
|
|
- catch (error) {
|
|
|
- host.exception(error, false);
|
|
|
- var message = error && error.message ? error.message : error.toString();
|
|
|
- message = message.endsWith('.') ? message.substring(0, message.length - 1) : message;
|
|
|
- throw new torchscript.Error(message + " in '" + identifier + "'.");
|
|
|
- }
|
|
|
- });
|
|
|
- }
|
|
|
- catch (error) {
|
|
|
- host.exception(error, false);
|
|
|
- var message = error && error.message ? error.message : error.toString();
|
|
|
- message = message.endsWith('.') ? message.substring(0, message.length - 1) : message;
|
|
|
- return Promise.reject(new torchscript.Error(message + " in '" + identifier + "'."));
|
|
|
- }
|
|
|
+ return host.require('./python').then((python) => {
|
|
|
+ var identifier = context.identifier;
|
|
|
+ try {
|
|
|
+ var container = torchscript.ModelFactory._openContainer(context.buffer);
|
|
|
+ return torchscript.Metadata.open(host).then((metadata) => {
|
|
|
+ try {
|
|
|
+ return new torchscript.Model(metadata, python, container);
|
|
|
+ }
|
|
|
+ catch (error) {
|
|
|
+ host.exception(error, false);
|
|
|
+ var message = error && error.message ? error.message : error.toString();
|
|
|
+ message = message.endsWith('.') ? message.substring(0, message.length - 1) : message;
|
|
|
+ throw new torchscript.Error(message + " in '" + identifier + "'.");
|
|
|
+ }
|
|
|
+ });
|
|
|
+ }
|
|
|
+ catch (error) {
|
|
|
+ host.exception(error, false);
|
|
|
+ var message = error && error.message ? error.message : error.toString();
|
|
|
+ message = message.endsWith('.') ? message.substring(0, message.length - 1) : message;
|
|
|
+ return Promise.reject(new torchscript.Error(message + " in '" + identifier + "'."));
|
|
|
+ }
|
|
|
+ });
|
|
|
}
|
|
|
|
|
|
static _openContainer(buffer) {
|
|
|
@@ -67,7 +69,7 @@ torchscript.ModelFactory = class {
|
|
|
|
|
|
torchscript.Model = class {
|
|
|
|
|
|
- constructor(metadata, container) {
|
|
|
+ constructor(metadata, python, container) {
|
|
|
var textDecoder = new TextDecoder('utf-8');
|
|
|
var model = JSON.parse(textDecoder.decode(container.model.data));
|
|
|
var version = JSON.parse(textDecoder.decode(container.version.data));
|
|
|
@@ -79,7 +81,7 @@ torchscript.Model = class {
|
|
|
}
|
|
|
}
|
|
|
this._graphs = [];
|
|
|
- this._graphs.push(new torchscript.Graph(metadata, container, model.mainModule, model.tensors));
|
|
|
+ this._graphs.push(new torchscript.Graph(metadata, python, container, model.mainModule, model.tensors));
|
|
|
}
|
|
|
|
|
|
get format() {
|
|
|
@@ -97,7 +99,7 @@ torchscript.Model = class {
|
|
|
|
|
|
torchscript.Graph = class {
|
|
|
|
|
|
- constructor(metadata, container, mainModule, tensors) {
|
|
|
+ constructor(metadata, python, container, mainModule, tensors) {
|
|
|
this._name = mainModule.name;
|
|
|
this._inputs = [];
|
|
|
this._outputs = [];
|
|
|
@@ -105,7 +107,7 @@ torchscript.Graph = class {
|
|
|
|
|
|
container.tensors = tensors.map((tensor) => new torchscript.Tensor(tensor, container));
|
|
|
|
|
|
- var context = new torchscript.GraphContext(container, mainModule);
|
|
|
+ var context = new torchscript.GraphContext(container, python, mainModule);
|
|
|
|
|
|
container.parameters = {};
|
|
|
var queue = [ mainModule ];
|
|
|
@@ -787,7 +789,7 @@ torchscript.Metadata = class {
|
|
|
|
|
|
torchscript.GraphContext = class {
|
|
|
|
|
|
- constructor(container, mainModule) {
|
|
|
+ constructor(container, python, mainModule) {
|
|
|
|
|
|
this._container = container;
|
|
|
this._mainModule = mainModule;
|
|
|
@@ -807,11 +809,11 @@ torchscript.GraphContext = class {
|
|
|
var codeEntry = codeEntries[0];
|
|
|
var textDecoder = new TextDecoder('utf-8');
|
|
|
var code = textDecoder.decode(codeEntry.data);
|
|
|
- var reader = new torchscript.PythonReader(code);
|
|
|
- var statements = reader.statements();
|
|
|
- var method = statements.find((statement) => statement.type == 'def' && statement.name == 'forward');
|
|
|
+ var reader = new python.Parser(code);
|
|
|
+ var program = reader.parse();
|
|
|
+ var method = program.body.find((statement) => statement.type == 'def' && statement.name == 'forward');
|
|
|
if (method) {
|
|
|
- this._body = method.body;
|
|
|
+ this._body = method.body.statements;
|
|
|
var methodParameters = method.parameters;
|
|
|
if (methodParameters.length > 0 && methodParameters[0].name == 'self') {
|
|
|
methodParameters.shift();
|
|
|
@@ -876,7 +878,7 @@ torchscript.GraphContext = class {
|
|
|
if (this._body.length > 0) {
|
|
|
var statement = this._body[0];
|
|
|
if (statement.expression.type == 'identifier' && statement.expression.value == parameter.name) {
|
|
|
- if (statement.type === '=' && statement.target.type === 'identifier_list') {
|
|
|
+ if (statement.type === '=' && statement.target.type === 'tuple') {
|
|
|
for (var input of statement.target.value) {
|
|
|
if (input) {
|
|
|
this._inputs.push(input.value);
|
|
|
@@ -925,7 +927,7 @@ torchscript.GraphContext = class {
|
|
|
}
|
|
|
|
|
|
_nodeExpression(expression, target) {
|
|
|
- if (expression.type == 'call' && (target.type == 'identifier' || target.type == 'identifier_list')) {
|
|
|
+ if (expression.type == 'call' && (target.type == 'identifier' || target.type == 'tuple')) {
|
|
|
var name = this._name(expression.target);
|
|
|
var namespace = 'torch.';
|
|
|
if (name.startsWith(namespace)) {
|
|
|
@@ -943,6 +945,9 @@ torchscript.GraphContext = class {
|
|
|
delete this._argumentMap[argument.value];
|
|
|
}
|
|
|
if (argument.type == 'identifier') {
|
|
|
+ if (argument.value === 'False' || argument.value === 'True') {
|
|
|
+ break;
|
|
|
+ }
|
|
|
node.inputs.push([ { id: argument.value } ]);
|
|
|
args.shift();
|
|
|
continue;
|
|
|
@@ -1021,7 +1026,7 @@ torchscript.GraphContext = class {
|
|
|
if (target.type == 'identifier') {
|
|
|
node.outputs.push(target.value);
|
|
|
}
|
|
|
- if (target.type == 'identifier_list') {
|
|
|
+ if (target.type == 'tuple') {
|
|
|
for (var identifier of target.value) {
|
|
|
node.outputs.push(identifier.value);
|
|
|
}
|
|
|
@@ -1234,418 +1239,6 @@ torchscript.GraphContext = class {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-torchscript.PythonReader = class {
|
|
|
-
|
|
|
- constructor(text) {
|
|
|
- this._text = text;
|
|
|
- this._position = 0;
|
|
|
- this._lineEnd = -1;
|
|
|
- this._lineStart = 0;
|
|
|
- this._line = -1;
|
|
|
- this._indentation = [];
|
|
|
- }
|
|
|
-
|
|
|
- whitespace() {
|
|
|
- for (;;) {
|
|
|
- while (this._position > this._lineEnd) {
|
|
|
- this._lineStart = this._lineEnd + 1;
|
|
|
- this._position = this._lineStart;
|
|
|
- if (this._position >= this._text.length) {
|
|
|
- return false;
|
|
|
- }
|
|
|
- this._lineEnd = this._text.indexOf("\n", this._position);
|
|
|
- if (this._lineEnd === -1) {
|
|
|
- this._lineEnd = this._text.length;
|
|
|
- }
|
|
|
- this._line++;
|
|
|
- }
|
|
|
- var c = this._text[this._position];
|
|
|
- switch (c) {
|
|
|
- case " ":
|
|
|
- case "\r":
|
|
|
- case "\t":
|
|
|
- this._position++;
|
|
|
- break;
|
|
|
- case "#":
|
|
|
- this._position = this._lineEnd;
|
|
|
- break;
|
|
|
- default:
|
|
|
- return true;
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- tokenize() {
|
|
|
- if (!this.whitespace()) {
|
|
|
- this._token = { type: 'eof', value: "" };
|
|
|
- return this._token;
|
|
|
- }
|
|
|
- var c = this._text[this._position];
|
|
|
- if (c == '\n') {
|
|
|
- this._token = { type: 'newline', value: c };
|
|
|
- return this._token;
|
|
|
- }
|
|
|
- if (c === '=' || c === '(' || c === ')' || c === ":" || c === "," || c === '[' || c === ']') {
|
|
|
- this._token = { type: 'separator', value: c };
|
|
|
- return this._token;
|
|
|
- }
|
|
|
- var position = this._position + 1;
|
|
|
- if (c >= "a" && c <= "z" || c >= "A" && c <= "Z" || c === "_") {
|
|
|
- while (position < this._lineEnd) {
|
|
|
- c = this._text[position];
|
|
|
- if (c >= "a" && c <= "z" || c >= "A" && c <= "Z" || c >= "0" && c <= "9" || c === "_" || c === "+" || c === "-") {
|
|
|
- position++;
|
|
|
- continue;
|
|
|
- }
|
|
|
- break;
|
|
|
- }
|
|
|
- var identifier = this._text.substring(this._position, position);
|
|
|
- if (identifier == 'True' || identifier == 'False') {
|
|
|
- this._token = { type: 'boolean', value: identifier };
|
|
|
- }
|
|
|
- else {
|
|
|
- this._token = { type: 'identifier', value: identifier };
|
|
|
- }
|
|
|
- return this._token;
|
|
|
- }
|
|
|
- if (c === "-") {
|
|
|
- if (position < this._lineEnd) {
|
|
|
- if (this._text[position] === '>') {
|
|
|
- position++;
|
|
|
- this._token = { type: 'arrow', value: '->' };
|
|
|
- return this._token;
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
- if (c >= "0" && c <= "9" || c === "-" || c === "+") {
|
|
|
- while (position < this._lineEnd) {
|
|
|
- c = this._text[position];
|
|
|
- if (c >= "a" && c <= "z" || c >= "A" && c <= "Z" || c >= "0" && c <= "9" || c === "_" || c === "+" || c === "-" || c === ".") {
|
|
|
- position++;
|
|
|
- continue;
|
|
|
- }
|
|
|
- break;
|
|
|
- }
|
|
|
- this._token = { type: 'number', value: this._text.substring(this._position, position) };
|
|
|
- return this._token;
|
|
|
- }
|
|
|
- if (c === "\"" || c === "'") {
|
|
|
- var quote = c;
|
|
|
- while (position < this._lineEnd) {
|
|
|
- c = this._text[position];
|
|
|
- if (c === "\\" && position < this._lineEnd) {
|
|
|
- position += 2;
|
|
|
- continue;
|
|
|
- }
|
|
|
- position++;
|
|
|
- if (c === quote) {
|
|
|
- break;
|
|
|
- }
|
|
|
- }
|
|
|
- this._token = { type: 'string', value: this._text.substring(this._position, position) };
|
|
|
- return this._token;
|
|
|
- }
|
|
|
- if (c === '.') {
|
|
|
- this._token = { type: 'dot', value: c };
|
|
|
- return this._token;
|
|
|
- }
|
|
|
- throw new torchscript.Error("Unexpected token '" + c + "'" + this.location());
|
|
|
- }
|
|
|
-
|
|
|
- peek() {
|
|
|
- if (!this._cache) {
|
|
|
- this._token = this.tokenize();
|
|
|
- this._cache = true;
|
|
|
- }
|
|
|
- return this._token;
|
|
|
- }
|
|
|
-
|
|
|
- read() {
|
|
|
- if (!this._cache) {
|
|
|
- this._token = this.tokenize();
|
|
|
- }
|
|
|
- this._position += this._token.value.length;
|
|
|
- this._cache = false;
|
|
|
- return this._token;
|
|
|
- }
|
|
|
-
|
|
|
- match(value) {
|
|
|
- if (this.peek().value === value) {
|
|
|
- this.read();
|
|
|
- return true;
|
|
|
- }
|
|
|
- return false;
|
|
|
- }
|
|
|
-
|
|
|
- expect(value) {
|
|
|
- var token = this.read();
|
|
|
- if (token.value !== value) {
|
|
|
- throw new torchscript.Error("Unexpected '" + token + "' instead of '" + value + "'" + this.location());
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- location() {
|
|
|
- return " at " + (this._line + 1).toString() + ":" + (this._position - this._lineStart + 1).toString();
|
|
|
- }
|
|
|
-
|
|
|
- letter(c) {
|
|
|
- return (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z');
|
|
|
- }
|
|
|
-
|
|
|
- number(c) {
|
|
|
- return c >= '0' && c <= '9';
|
|
|
- }
|
|
|
-
|
|
|
- identifier() {
|
|
|
- var token = this.peek();
|
|
|
- if (token.type == 'identifier') {
|
|
|
- this.read();
|
|
|
- return token;
|
|
|
- }
|
|
|
- return null;
|
|
|
- }
|
|
|
-
|
|
|
- literal() {
|
|
|
- var token = this.peek();
|
|
|
- if (token.type == 'string' || token.type == 'number' || token.type == 'boolean') {
|
|
|
- this.read();
|
|
|
- return token;
|
|
|
- }
|
|
|
- return null;
|
|
|
- }
|
|
|
-
|
|
|
- typeArguments() {
|
|
|
- var list = [];
|
|
|
- this.expect('[');
|
|
|
- while (!this.match(']')) {
|
|
|
- var type = this.type();
|
|
|
- if (type == null) {
|
|
|
- throw new torchscript.Error('Expected type ' + this.location());
|
|
|
- }
|
|
|
- list.push(type);
|
|
|
- if (!this.match(',')) {
|
|
|
- this.expect(']');
|
|
|
- break;
|
|
|
- }
|
|
|
- }
|
|
|
- return list;
|
|
|
- }
|
|
|
-
|
|
|
- type() {
|
|
|
- var identifier = this.identifier();
|
|
|
- if (identifier) {
|
|
|
- var type = { type: 'type', value: identifier.value };
|
|
|
- if (this.peek().value === '[') {
|
|
|
- type.arguments = this.typeArguments();
|
|
|
- }
|
|
|
- return type;
|
|
|
- }
|
|
|
- return null;
|
|
|
- }
|
|
|
-
|
|
|
- parameter() {
|
|
|
- var identifier = this.identifier();
|
|
|
- if (identifier != null) {
|
|
|
- var parameterType = null
|
|
|
- if (this.match(':')) {
|
|
|
- parameterType = this.type();
|
|
|
- }
|
|
|
- return { type: 'parameter', name: identifier.value, parameterType: parameterType };
|
|
|
- }
|
|
|
- return null;
|
|
|
- }
|
|
|
-
|
|
|
- parameters() {
|
|
|
- var list = [];
|
|
|
- this.expect('(');
|
|
|
- while (!this.match(')')) {
|
|
|
- this.match('\n');
|
|
|
- list.push(this.parameter());
|
|
|
- this.match('\n');
|
|
|
- if (!this.match(',')) {
|
|
|
- this.expect(')');
|
|
|
- break;
|
|
|
- }
|
|
|
- }
|
|
|
- return list;
|
|
|
- }
|
|
|
-
|
|
|
- arguments() {
|
|
|
- var list = [];
|
|
|
- this.expect('(');
|
|
|
- while (!this.match(')')) {
|
|
|
- var expression = this.expression();
|
|
|
- if (expression == null) {
|
|
|
- throw new torchscript.Error('Expected expression ' + this.location());
|
|
|
- }
|
|
|
- list.push(expression);
|
|
|
- if (!this.match(',')) {
|
|
|
- this.expect(')');
|
|
|
- break;
|
|
|
- }
|
|
|
- }
|
|
|
- return list;
|
|
|
- }
|
|
|
-
|
|
|
- expression() {
|
|
|
- var stack = [];
|
|
|
- for (;;) {
|
|
|
- var identifier = this.identifier();
|
|
|
- if (identifier) {
|
|
|
- stack.push(identifier);
|
|
|
- continue;
|
|
|
- }
|
|
|
- var literal = this.literal();
|
|
|
- if (literal) {
|
|
|
- stack.push(literal);
|
|
|
- continue;
|
|
|
- }
|
|
|
- if (this.match('.')) {
|
|
|
- stack.push({
|
|
|
- type: '.',
|
|
|
- target: stack.pop(),
|
|
|
- member: this.identifier(),
|
|
|
- });
|
|
|
- continue;
|
|
|
- }
|
|
|
- if (this.peek().value === '(') {
|
|
|
- if (stack.length == 0) {
|
|
|
- stack.push({ type: 'tuple', value: this.arguments() });
|
|
|
- }
|
|
|
- else {
|
|
|
- stack.push({ type: 'call', target: stack.pop(), arguments: this.arguments() });
|
|
|
- }
|
|
|
- continue;
|
|
|
- }
|
|
|
- if (this.peek().value === '[') {
|
|
|
- stack.push({ type: 'list', value: this.expressions() });
|
|
|
- continue;
|
|
|
- }
|
|
|
- if (this.match('=')) {
|
|
|
- stack.push({ type: '=', target: stack.pop(), expression: this.expression() });
|
|
|
- continue;
|
|
|
- }
|
|
|
- break;
|
|
|
- }
|
|
|
-
|
|
|
- if (stack.length == 1) {
|
|
|
- return stack.pop();
|
|
|
- }
|
|
|
- if (stack.length != 0) {
|
|
|
- throw new torchscript.Error('Unexpected expression ' + this.location());
|
|
|
- }
|
|
|
- return null;
|
|
|
- }
|
|
|
-
|
|
|
- expressions() {
|
|
|
- var list = [];
|
|
|
- this.expect('[');
|
|
|
- while (!this.match(']')) {
|
|
|
- var expression = this.expression();
|
|
|
- if (expression == null) {
|
|
|
- throw new torchscript.Error('Expected expression ' + this.location());
|
|
|
- }
|
|
|
- list.push(expression);
|
|
|
- if (!this.match(',')) {
|
|
|
- this.expect(']');
|
|
|
- break;
|
|
|
- }
|
|
|
- }
|
|
|
- return list;
|
|
|
- }
|
|
|
-
|
|
|
- statement() {
|
|
|
- var stack = [];
|
|
|
- while (this.peek().type !== 'eof') {
|
|
|
-
|
|
|
- if (this.match('def')) {
|
|
|
- var node = { type: 'def' };
|
|
|
- node.name = this.identifier().value;
|
|
|
- node.parameters = this.parameters();
|
|
|
- if (this.match('->')) {
|
|
|
- node.returnType = this.type();
|
|
|
- }
|
|
|
- this.expect(':');
|
|
|
- this.expect('\n');
|
|
|
- var position = this._position;
|
|
|
- while (this.match('\n')) {
|
|
|
- position = this._position;
|
|
|
- }
|
|
|
- this.peek();
|
|
|
- this._indentation.push(this._text.substring(position, this._position));
|
|
|
- this._position = position;
|
|
|
- node.body = this.statements();
|
|
|
- this._indentation.pop();
|
|
|
- stack.push(node);
|
|
|
- break;
|
|
|
- }
|
|
|
-
|
|
|
- if (this.match('return')) {
|
|
|
- stack.push({ type: 'return', expression: this.expression() });
|
|
|
- break;
|
|
|
- }
|
|
|
-
|
|
|
- var expression = this.expression();
|
|
|
- if (expression) {
|
|
|
- if (expression.type == 'identifier') {
|
|
|
- if (this.peek().value === ',') {
|
|
|
- var list = [ expression ];
|
|
|
- while (this.match(',')) {
|
|
|
- var identifier = this.identifier();
|
|
|
- if (!identifier) {
|
|
|
- if (this.peek().value != '=') {
|
|
|
- throw new torchscript.Error('Expected identifier' + this.location());
|
|
|
- }
|
|
|
- }
|
|
|
- list.push(identifier);
|
|
|
- }
|
|
|
- expression = { type: 'identifier_list', value: list };
|
|
|
- if (this.match('=')) {
|
|
|
- expression = { type: '=', target: expression, expression: this.expression() };
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
- if (expression.type == '=') {
|
|
|
- stack.push(expression);
|
|
|
- this.match('\n');
|
|
|
- break;
|
|
|
- }
|
|
|
- throw new torchscript.Error('Unhandled expression ' + this.location);
|
|
|
- }
|
|
|
-
|
|
|
- if (this.match('\n')) {
|
|
|
- break;
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- if (stack.length == 1) {
|
|
|
- return stack.pop();
|
|
|
- }
|
|
|
- if (stack.length != 0) {
|
|
|
- throw new torchscript.Error('Unexpected statement ' + this.location());
|
|
|
- }
|
|
|
- return null;
|
|
|
- }
|
|
|
-
|
|
|
- statements() {
|
|
|
- var indentation = this._indentation.join('');
|
|
|
- var stack = [];
|
|
|
- while (this._position < this._text.length) {
|
|
|
- if (this._text.substring(this._position, this._position + indentation.length) !== indentation) {
|
|
|
- return stack;
|
|
|
- }
|
|
|
- this._position = this._position + indentation.length;
|
|
|
-
|
|
|
- var statement = this.statement();
|
|
|
- if (statement) {
|
|
|
- stack.push(statement);
|
|
|
- continue;
|
|
|
- }
|
|
|
- }
|
|
|
- return stack;
|
|
|
- }
|
|
|
-}
|
|
|
-
|
|
|
torchscript.Error = class extends Error {
|
|
|
constructor(message) {
|
|
|
super(message);
|