|
|
@@ -36,7 +36,7 @@ torchscript.ModelFactory = class {
|
|
|
var message = error && error.message ? error.message : error.toString();
|
|
|
message = message.endsWith('.') ? message.substring(0, message.length - 1) : message;
|
|
|
throw new torchscript.Error(message + " in '" + identifier + "'.");
|
|
|
- }
|
|
|
+ }
|
|
|
});
|
|
|
}
|
|
|
catch (error) {
|
|
|
@@ -107,6 +107,29 @@ torchscript.Graph = class {
|
|
|
|
|
|
var context = new torchscript.GraphContext(container, mainModule);
|
|
|
|
|
|
+ container.parameters = {};
|
|
|
+ var queue = [ mainModule ];
|
|
|
+ while (queue.length > 0) {
|
|
|
+ var module = queue.shift();
|
|
|
+ if (module.parameters) {
|
|
|
+ for (var parameter of module.parameters) {
|
|
|
+ if (parameter.tensorId) {
|
|
|
+ var tensorId = parseInt(parameter.tensorId, 10);
|
|
|
+ parameter.initializer = container.tensors[tensorId];
|
|
|
+ if (parameter.outputs && parameter.outputs.length == 1) {
|
|
|
+ container.parameters[parameter.outputs[0]] = parameter;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if (module.submodules) {
|
|
|
+ for (var submodule of module.submodules) {
|
|
|
+ submodule.parent = module;
|
|
|
+ queue.push(submodule);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
for (var input of context.inputs) {
|
|
|
this._inputs.push(new torchscript.Argument(input, true, [
|
|
|
new torchscript.Connection(input, null, null)
|
|
|
@@ -119,21 +142,20 @@ torchscript.Graph = class {
|
|
|
}
|
|
|
|
|
|
for (var node of context.nodes) {
|
|
|
- this._nodes.push(new torchscript.Node(metadata, container, '', null, node));
|
|
|
+ this._nodes.push(new torchscript.Node(metadata, container, null, node));
|
|
|
}
|
|
|
|
|
|
- this._loadModule(metadata, container, '', mainModule);
|
|
|
+ this._loadModule(metadata, container, mainModule);
|
|
|
}
|
|
|
|
|
|
- _loadModule(metadata, container, group, module) {
|
|
|
- if (module.parameters && module.parameters.length > 0) {
|
|
|
- var node = new torchscript.Node(metadata, container, group, module, null);
|
|
|
+ _loadModule(metadata, container, module) {
|
|
|
+ if (module.parameters && module.parameters.length > 0 && !module.hide) {
|
|
|
+ var node = new torchscript.Node(metadata, container, module, null);
|
|
|
this._nodes.push(node);
|
|
|
}
|
|
|
if (module.submodules) {
|
|
|
- var subgroup = group ? [ group, module.name ].join('.') : module.name;
|
|
|
for (var submodule of module.submodules) {
|
|
|
- this._loadModule(metadata, container, subgroup, submodule);
|
|
|
+ this._loadModule(metadata, container, submodule);
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
@@ -210,20 +232,22 @@ torchscript.Connection = class {
|
|
|
|
|
|
torchscript.Node = class {
|
|
|
|
|
|
- constructor(metadata, container, group, module, node) {
|
|
|
+ constructor(metadata, container, module, node) {
|
|
|
this._metadata = metadata;
|
|
|
this._attributes = [];
|
|
|
this._inputs = [];
|
|
|
this._outputs = [];
|
|
|
|
|
|
+ var input = null;
|
|
|
+ var connection = null;
|
|
|
+ var parameter = null;
|
|
|
+
|
|
|
if (module) {
|
|
|
this._operator = 'Module';
|
|
|
- this._name = group ? [ group, module.name ].join('.') : module.name;
|
|
|
if (module.parameters) {
|
|
|
- for (var parameter of module.parameters) {
|
|
|
- var tensorId = parseInt(parameter.tensorId, 10);
|
|
|
+ for (parameter of module.parameters) {
|
|
|
this._inputs.push(new torchscript.Argument(parameter.name, true, [
|
|
|
- new torchscript.Connection('', null, container.tensors[tensorId])
|
|
|
+ new torchscript.Connection('', null, parameter.initializer || null)
|
|
|
]));
|
|
|
if (parameter.outputs) {
|
|
|
this._outputs.push(new torchscript.Argument(parameter.name, true,
|
|
|
@@ -240,13 +264,49 @@ torchscript.Node = class {
|
|
|
|
|
|
var schema = metadata.getSchema(this._operator);
|
|
|
|
|
|
+ module = null;
|
|
|
+ var match = true;
|
|
|
+ var count = 0;
|
|
|
+ for (input of node.inputs) {
|
|
|
+ for (connection of input) {
|
|
|
+ parameter = container.parameters[connection.id];
|
|
|
+ if (parameter) {
|
|
|
+ if (parameter.module && (module == null || module == parameter.module)) {
|
|
|
+ module = parameter.module;
|
|
|
+ count++;
|
|
|
+ }
|
|
|
+ else {
|
|
|
+ match = false;
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if (!match) {
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if (module && module.parameters.length == count && match) {
|
|
|
+ module.hide = true;
|
|
|
+ for (input of node.inputs) {
|
|
|
+ for (connection of input) {
|
|
|
+ parameter = container.parameters[connection.id];
|
|
|
+ if (parameter && parameter.initializer) {
|
|
|
+ connection.initializer = parameter.initializer;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ else {
|
|
|
+ module = null;
|
|
|
+ }
|
|
|
+
|
|
|
for (var inputIndex = 0; inputIndex < node.inputs.length; inputIndex++) {
|
|
|
var inputName = inputIndex.toString();
|
|
|
if (schema && schema.inputs && schema.inputs.length > inputIndex) {
|
|
|
inputName = schema.inputs[inputIndex].name;
|
|
|
}
|
|
|
this._inputs.push(new torchscript.Argument(inputName, true,
|
|
|
- node.inputs[inputIndex].map((input) => new torchscript.Connection(input, null, null))
|
|
|
+ node.inputs[inputIndex].map((input) => new torchscript.Connection(input.id, null, input.initializer || null))
|
|
|
));
|
|
|
}
|
|
|
|
|
|
@@ -280,6 +340,17 @@ torchscript.Node = class {
|
|
|
this._attributes.push(new torchscript.Attribute(this, attributeSchema, attributeName, attributeValue));
|
|
|
}
|
|
|
}
|
|
|
+
|
|
|
+ if (module) {
|
|
|
+ if (module.name) {
|
|
|
+ var current = module;
|
|
|
+ this._name = current.name;
|
|
|
+ while (current.parent != null) {
|
|
|
+ current = current.parent;
|
|
|
+ this._name = [ current.name, this._name ].join('.')
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
get name() {
|
|
|
@@ -399,7 +470,13 @@ torchscript.Attribute = class {
|
|
|
case 'int32[]':
|
|
|
case 'int64[]':
|
|
|
if (this._value.type == 'list' && this._value.value.every((item) => item.type === 'number')) {
|
|
|
- this._value = this._value.value.map((item) => parseInt(item.value, 10));
|
|
|
+ this._value = this._value.value.map((item) => {
|
|
|
+ var number = parseInt(item.value, 10);
|
|
|
+ if (!Number.isNaN(item.value - number)) {
|
|
|
+ return number;
|
|
|
+ }
|
|
|
+ return item.value;
|
|
|
+ });
|
|
|
}
|
|
|
break;
|
|
|
}
|
|
|
@@ -411,6 +488,11 @@ torchscript.Attribute = class {
|
|
|
if (JSON.stringify(schema.default) == JSON.stringify(this._value)) {
|
|
|
this._visible = false;
|
|
|
}
|
|
|
+ else if (Array.isArray(this._value) &&
|
|
|
+ !Array.isArray(schema.default) &&
|
|
|
+ this.value.every((item) => item == schema.default)) {
|
|
|
+ this._visible = false;
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
@@ -438,6 +520,7 @@ torchscript.Tensor = class {
|
|
|
this._type = new torchscript.TensorType(tensor.dataType, new torchscript.TensorShape(tensor.dims));
|
|
|
var key = container.prefix + tensor.data.key;
|
|
|
var entry = container.entries.find((entry) => entry.name == key);
|
|
|
+ this._name = tensor.data.key;
|
|
|
this._data = entry.data;
|
|
|
this._littleEndian = true;
|
|
|
}
|
|
|
@@ -706,6 +789,7 @@ torchscript.GraphContext = class {
|
|
|
|
|
|
constructor(container, mainModule) {
|
|
|
|
|
|
+ this._container = container;
|
|
|
this._mainModule = mainModule;
|
|
|
|
|
|
this._inputs = [];
|
|
|
@@ -714,6 +798,7 @@ torchscript.GraphContext = class {
|
|
|
|
|
|
this._moduleMap = {};
|
|
|
this._connectionMap = {};
|
|
|
+ this._numToTensorMap = {};
|
|
|
|
|
|
if (mainModule.torchscriptArena && mainModule.torchscriptArena.key) {
|
|
|
var codeKey = container.prefix + mainModule.torchscriptArena.key;
|
|
|
@@ -751,6 +836,9 @@ torchscript.GraphContext = class {
|
|
|
|
|
|
while (this._body.length > 0) {
|
|
|
var statement = this._body.shift();
|
|
|
+ if (this._attributeStatement(statement)) {
|
|
|
+ continue;
|
|
|
+ }
|
|
|
if (this._moduleStatement(statement)) {
|
|
|
continue;
|
|
|
}
|
|
|
@@ -763,7 +851,7 @@ torchscript.GraphContext = class {
|
|
|
if (this._returnStatement(statement)) {
|
|
|
continue;
|
|
|
}
|
|
|
- debugger;
|
|
|
+ throw new torchscript.Error("Unknown statement '" + JSON.stringify(statement) + "'.");
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
@@ -839,9 +927,8 @@ torchscript.GraphContext = class {
|
|
|
_nodeExpression(expression, target) {
|
|
|
if (expression.type == 'call' && (target.type == 'identifier' || target.type == 'identifier_list')) {
|
|
|
var name = this._name(expression.target);
|
|
|
- var namespaces = [ 'torch.', 'ops.prim.' ];
|
|
|
- var namespace = namespaces.find((n) => name.startsWith(n));
|
|
|
- if (namespace) {
|
|
|
+ var namespace = 'torch.';
|
|
|
+ if (name.startsWith(namespace)) {
|
|
|
var node = {};
|
|
|
node.name = name.substring(namespace.length);
|
|
|
node.inputs = [];
|
|
|
@@ -856,7 +943,7 @@ torchscript.GraphContext = class {
|
|
|
delete this._connectionMap[argument.value];
|
|
|
}
|
|
|
if (argument.type == 'identifier') {
|
|
|
- node.inputs.push([ argument.value ]);
|
|
|
+ node.inputs.push([ { id: argument.value } ]);
|
|
|
args.shift();
|
|
|
continue;
|
|
|
}
|
|
|
@@ -865,13 +952,13 @@ torchscript.GraphContext = class {
|
|
|
for (var input of argument.value) {
|
|
|
var variable = this._variable();
|
|
|
if (this._nodeExpression(input, variable)) {
|
|
|
- connections.push(variable.value);
|
|
|
+ connections.push({ id: variable.value });
|
|
|
}
|
|
|
- if (this._connectionExpression(input, variable)) {
|
|
|
- connections.push(variable.value);
|
|
|
+ else if (this._connectionExpression(input, variable)) {
|
|
|
+ connections.push({ id: variable.value });
|
|
|
}
|
|
|
else if (input.type == 'identifier') {
|
|
|
- connections.push(input.value);
|
|
|
+ connections.push({ id: input.value });
|
|
|
}
|
|
|
else {
|
|
|
connections = null;
|
|
|
@@ -893,26 +980,41 @@ torchscript.GraphContext = class {
|
|
|
if (argument.type == '=') {
|
|
|
break;
|
|
|
}
|
|
|
- var variable = this._variable();
|
|
|
+ variable = this._variable();
|
|
|
if (this._nodeExpression(argument, variable)) {
|
|
|
- node.inputs.push([ variable.value ]);
|
|
|
+ node.inputs.push([ { id: variable.value } ]);
|
|
|
args.shift();
|
|
|
continue;
|
|
|
}
|
|
|
if (this._connectionExpression(argument, variable)) {
|
|
|
- node.inputs.push([ variable.value ]);
|
|
|
+ node.inputs.push([ { id: variable.value } ]);
|
|
|
args.shift();
|
|
|
continue;
|
|
|
}
|
|
|
- // TODO CONSTANTS.cx
|
|
|
- if (argument.type == '.' && argument.target.type == 'identifier' && argument.target.value == 'CONSTANTS') {
|
|
|
- node.inputs.push([ JSON.stringify(args[0]) ]);
|
|
|
+ if (argument.type == '.' &&
|
|
|
+ argument.target.type == 'identifier' &&
|
|
|
+ argument.target.value == 'CONSTANTS' &&
|
|
|
+ argument.member.type == 'identifier' &&
|
|
|
+ argument.member.value.startsWith('c')) {
|
|
|
+ var constantId = [ argument.target.value, argument.member.value ].join('.');
|
|
|
+ var constantIndex = parseInt(argument.member.value.substring(1), 10);
|
|
|
+ var constantTensor = this._container.tensors[constantIndex];
|
|
|
+ node.inputs.push([ { id: constantId, initializer: constantTensor } ]);
|
|
|
args.shift();
|
|
|
continue;
|
|
|
}
|
|
|
throw new torchscript.Error('Unknown function argument.');
|
|
|
}
|
|
|
while (args.length > 0) {
|
|
|
+ if (args[0].type == 'list') {
|
|
|
+ for (var i = 0; i < args[0].value.length; i++) {
|
|
|
+ args[0].value[i] = this._attributeExpression(args[0].value[i]);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ var intExpression = this._attributeExpression(args[0]);
|
|
|
+ if (intExpression) {
|
|
|
+ args[0] = intExpression;
|
|
|
+ }
|
|
|
node.attributes.push(args[0]);
|
|
|
args.shift();
|
|
|
}
|
|
|
@@ -940,6 +1042,71 @@ torchscript.GraphContext = class {
|
|
|
return false;
|
|
|
}
|
|
|
|
|
|
+ _attributeExpression(expression) {
|
|
|
+ if (expression.type == 'identifier') {
|
|
|
+ if (this._numToTensorMap[expression.value]) {
|
|
|
+ return { type: 'number', value: this._numToTensorMap[expression.value] };
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if (expression.type == 'call' &&
|
|
|
+ expression.target.type == 'identifier' &&
|
|
|
+ expression.target.value == 'int' &&
|
|
|
+ expression.arguments.length == 1)
|
|
|
+ {
|
|
|
+ var replace = this._attributeExpression(expression.arguments[0]);
|
|
|
+ if (replace) {
|
|
|
+ return replace;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return expression;
|
|
|
+ }
|
|
|
+
|
|
|
+ _attributeStatement(statement) {
|
|
|
+ if (statement.type == '=' &&
|
|
|
+ statement.target.type == 'identifier') {
|
|
|
+ if (statement.expression.type == 'call' &&
|
|
|
+ this._name(statement.expression.target) == 'ops.prim.NumToTensor' &&
|
|
|
+ statement.expression.arguments.length == 1) {
|
|
|
+ var size = statement.expression.arguments[0];
|
|
|
+ if (size.type == 'call' &&
|
|
|
+ size.arguments.length == 2 &&
|
|
|
+ this._name(size.target) == 'torch.size' &&
|
|
|
+ size.arguments[0].type == 'identifier' &&
|
|
|
+ size.arguments[1].type == 'number') {
|
|
|
+ this._numToTensorMap[statement.target.value] = this._name(size.target) + '(' + size.arguments.map((a) => a.value.toString()).join(',') + ')';;
|
|
|
+ return true;
|
|
|
+ }
|
|
|
+ if (size.type == 'identifier') {
|
|
|
+ var duplicate1 = this._numToTensorMap[size.value];
|
|
|
+ if (duplicate1) {
|
|
|
+ this._numToTensorMap[statement.target.value] = duplicate1;
|
|
|
+ return true;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if (statement.expression.type == 'call' &&
|
|
|
+ statement.expression.arguments.length == 2 &&
|
|
|
+ this._name(statement.expression.target) == 'torch.size' &&
|
|
|
+ statement.expression.arguments[0].type == 'identifier' &&
|
|
|
+ statement.expression.arguments[1].type == 'number') {
|
|
|
+ this._numToTensorMap[statement.target.value] = this._name(statement.expression.target) + '(' + statement.expression.arguments.map((a) => a.value.toString()).join(',') + ')';;
|
|
|
+ return true;
|
|
|
+ }
|
|
|
+ if (statement.expression.type == 'call' &&
|
|
|
+ statement.expression.target.type == 'identifier' &&
|
|
|
+ statement.expression.target.value == 'int' &&
|
|
|
+ statement.expression.arguments.length == 1 &&
|
|
|
+ statement.expression.arguments[0].type == 'identifier') {
|
|
|
+ var duplicate2 = this._numToTensorMap[statement.expression.arguments[0].value];
|
|
|
+ if (duplicate2) {
|
|
|
+ this._numToTensorMap[statement.target.value] = duplicate2;
|
|
|
+ return true;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return false;
|
|
|
+ }
|
|
|
+
|
|
|
_module(expression) {
|
|
|
var module;
|
|
|
var submodule;
|
|
|
@@ -999,24 +1166,27 @@ torchscript.GraphContext = class {
|
|
|
_connectionExpression(expression, target) {
|
|
|
expression = this._moduleTensor(expression);
|
|
|
if (expression.type === '.' && expression.member.type == 'identifier') {
|
|
|
- var module = this._module(expression.target);
|
|
|
- if (module && module.parameters) {
|
|
|
- for (var parameter of module.parameters) {
|
|
|
+ var targetModule = this._module(expression.target);
|
|
|
+ if (targetModule && targetModule.parameters) {
|
|
|
+ for (var parameter of targetModule.parameters) {
|
|
|
+ parameter.module = targetModule;
|
|
|
if (parameter.name === expression.member.value) {
|
|
|
parameter.outputs = parameter.outputs || [];
|
|
|
parameter.outputs.push(target.value);
|
|
|
return true;
|
|
|
}
|
|
|
}
|
|
|
- module.unresolvedParameters = module.unresolvedParameters || [];
|
|
|
- for (var unresolvedParameter of module.unresolvedParameters) {
|
|
|
+ targetModule.unresolvedParameters = targetModule.unresolvedParameters || [];
|
|
|
+ for (var unresolvedParameter of targetModule.unresolvedParameters) {
|
|
|
+ unresolvedParameter.module = targetModule;
|
|
|
if (unresolvedParameter.name === expression.member.value) {
|
|
|
unresolvedParameter.outputs = unresolvedParameter.outputs || [];
|
|
|
unresolvedParameter.outputs.push(target.value);
|
|
|
return true;
|
|
|
}
|
|
|
}
|
|
|
- module.unresolvedParameters.push({
|
|
|
+ targetModule.unresolvedParameters.push({
|
|
|
+ module: targetModule,
|
|
|
name: expression.member.value,
|
|
|
outputs: [ target.value ]
|
|
|
});
|
|
|
@@ -1041,8 +1211,7 @@ torchscript.GraphContext = class {
|
|
|
}
|
|
|
|
|
|
_variable() {
|
|
|
- var value = '_gen' + Math.random().toString(36).substring(7);
|
|
|
- return { type: 'identifier', value: value };
|
|
|
+ return { type: 'identifier', value: '_gen' + Math.random().toString(36).substring(7) };
|
|
|
}
|
|
|
|
|
|
_name(expression) {
|
|
|
@@ -1052,15 +1221,14 @@ torchscript.GraphContext = class {
|
|
|
if (expression.type == '.') {
|
|
|
return [ this._name(expression.target), this._name(expression.member) ].join('.');
|
|
|
}
|
|
|
- throw new torchscript.Error('Failed to resolve name.');
|
|
|
+ throw new torchscript.Error("Failed to resolve name '" + JSON.stringify(expression) + "'.");
|
|
|
}
|
|
|
|
|
|
_moduleTensor(expression) {
|
|
|
- if (expression.type == 'call' && expression.arguments.length == 1) {
|
|
|
- var name = this._name(expression.target);
|
|
|
- if (name == 'torch.t') {
|
|
|
- return expression.arguments[0];
|
|
|
- }
|
|
|
+ if (expression.type == 'call' &&
|
|
|
+ expression.arguments.length == 1 &&
|
|
|
+ this._name(expression.target) == 'torch.t') {
|
|
|
+ return expression.arguments[0];
|
|
|
}
|
|
|
return expression;
|
|
|
}
|