|
|
@@ -117,7 +117,7 @@ torchscript.ModelFactory = class {
|
|
|
}
|
|
|
return obj;
|
|
|
};
|
|
|
- functionTable['torch._utils._rebuild_tensor_v2'] = function (storage, storage_offset, size, stride, requires_grad, backward_hooks) {
|
|
|
+ functionTable['torch._utils._rebuild_tensor_v2'] = function(storage, storage_offset, size, stride, requires_grad, backward_hooks) {
|
|
|
return {
|
|
|
__type__: storage.__type__.replace('Storage', 'Tensor'),
|
|
|
storage: storage,
|
|
|
@@ -128,6 +128,21 @@ torchscript.ModelFactory = class {
|
|
|
backward_hooks: backward_hooks
|
|
|
};
|
|
|
};
|
|
|
+ functionTable['torch._utils._rebuild_qtensor'] = function(storage, storage_offset, size, stride, quantizer_params, requires_grad, backward_hooks) {
|
|
|
+ return {
|
|
|
+ __type__: storage.__type__.replace('Storage', 'Tensor'),
|
|
|
+ storage: storage,
|
|
|
+ storage_offset: storage_offset,
|
|
|
+ size: size,
|
|
|
+ stride: stride,
|
|
|
+ quantizer_params: quantizer_params,
|
|
|
+ requires_grad:requires_grad,
|
|
|
+ backward_hooks: backward_hooks
|
|
|
+ };
|
|
|
+ }
|
|
|
+ functionTable['torch.jit._pickle.build_intlist'] = function(data) {
|
|
|
+ return data;
|
|
|
+ }
|
|
|
let constructorTable = {};
|
|
|
constructorTable['torch.ByteStorage'] = function (size) {
|
|
|
this.size = size; this.dataTypeSize = 1; this.dataType = 'uint8';
|
|
|
@@ -153,6 +168,9 @@ torchscript.ModelFactory = class {
|
|
|
constructorTable['torch.DoubleStorage'] = function (size) {
|
|
|
this.size = size; this.dataTypeSize = 8; this.dataType = 'float64';
|
|
|
};
|
|
|
+ constructorTable['torch.QInt8Storage'] = function (size) {
|
|
|
+ this.size = size; this.dataTypeSize = 1; this.dataType = 'qint8';
|
|
|
+ };
|
|
|
let function_call = (name, args) => {
|
|
|
let func = functionTable[name];
|
|
|
if (func) {
|
|
|
@@ -256,6 +274,7 @@ torchscript.Graph = class {
|
|
|
let context = null;
|
|
|
try {
|
|
|
let script = '';
|
|
|
+ let namespaceName = null;
|
|
|
let className = null;
|
|
|
if (container.model && container.model.mainModule) {
|
|
|
mainModule = container.model.mainModule;
|
|
|
@@ -265,9 +284,10 @@ torchscript.Graph = class {
|
|
|
mainModule = container.data;
|
|
|
const typeName = mainModule.__type__.split('.');
|
|
|
className = typeName.pop();
|
|
|
+ namespaceName = typeName.join('.');
|
|
|
script = 'code/' + typeName.join('/') + '.py';
|
|
|
}
|
|
|
- context = new torchscript.GraphContext(container, python, mainModule, script, className);
|
|
|
+ context = new torchscript.GraphContext(container, python, mainModule, script, className, namespaceName);
|
|
|
}
|
|
|
catch (error) {
|
|
|
let message = error && error.message ? error.message : error.toString();
|
|
|
@@ -900,6 +920,20 @@ torchscript.Tensor = class {
|
|
|
context.state = 'Tensor has no data type.';
|
|
|
return context;
|
|
|
}
|
|
|
+ switch (this._type.dataType) {
|
|
|
+ case 'uint8':
|
|
|
+ case 'int8':
|
|
|
+ case 'int16':
|
|
|
+ case 'int32':
|
|
|
+ case 'int64':
|
|
|
+ case 'float16':
|
|
|
+ case 'float32':
|
|
|
+ case 'float64':
|
|
|
+ break;
|
|
|
+ default:
|
|
|
+ context.state = "Tensor data type '" + this._type.dataType + "' is not supported.";
|
|
|
+ return context;
|
|
|
+ }
|
|
|
if (!this._type.shape) {
|
|
|
context.state = 'Tensor has no dimensions.';
|
|
|
return context;
|
|
|
@@ -1155,7 +1189,7 @@ torchscript.Utility = class {
|
|
|
|
|
|
torchscript.GraphContext = class {
|
|
|
|
|
|
- constructor(container, python, mainModule, script, className) {
|
|
|
+ constructor(container, python, mainModule, script, className, namespaceName) {
|
|
|
|
|
|
this._container = container;
|
|
|
this._mainModule = mainModule;
|
|
|
@@ -1165,6 +1199,7 @@ torchscript.GraphContext = class {
|
|
|
this._nodes = [];
|
|
|
|
|
|
this._moduleMap = new Map();
|
|
|
+ this._classMap = new Map();
|
|
|
this._state = {};
|
|
|
|
|
|
if (script) {
|
|
|
@@ -1178,8 +1213,18 @@ torchscript.GraphContext = class {
|
|
|
let program = reader.parse();
|
|
|
let statements = program.body;
|
|
|
if (className) {
|
|
|
- let block = statements.find((statment) => statment.type == 'class' && statment.name == className);
|
|
|
- statements = block.body.statements;
|
|
|
+ let main = null;
|
|
|
+ for (let statement of statements) {
|
|
|
+ if (statement.type == 'class') {
|
|
|
+ if (namespaceName) {
|
|
|
+ this._classMap.set(namespaceName + '.' + statement.name, statement);
|
|
|
+ }
|
|
|
+ if (statement.name == className) {
|
|
|
+ main = statement;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ statements = main.body.statements;
|
|
|
}
|
|
|
let method = statements.find((statement) => statement.type == 'def' && statement.name == 'forward');
|
|
|
if (method) {
|
|
|
@@ -1301,24 +1346,28 @@ torchscript.GraphContext = class {
|
|
|
_returnStatement(statement) {
|
|
|
if (statement.type == 'return') {
|
|
|
let variable = this._variable();
|
|
|
- if (this._nodeExpression(statement.expression, variable)) {
|
|
|
+ let expression = statement.expression;
|
|
|
+ if (this._nodeExpression(expression, variable)) {
|
|
|
this._outputs.push(variable.value);
|
|
|
return true;
|
|
|
}
|
|
|
- if (statement.expression.type == 'id') {
|
|
|
- this._outputs.push(statement.expression.value);
|
|
|
+ if (expression.type == 'id' && this._state[expression.value] && this._state[expression.value].type === 'tuple' ) {
|
|
|
+ expression = this._state[expression.value];
|
|
|
+ }
|
|
|
+ if (expression.type == 'id') {
|
|
|
+ this._outputs.push(expression.value);
|
|
|
return true;
|
|
|
}
|
|
|
- if (statement.expression.type == 'tuple') {
|
|
|
+ if (expression.type == 'tuple') {
|
|
|
let outputs = [];
|
|
|
- for (let expression of statement.expression.value) {
|
|
|
+ for (let item of expression.value) {
|
|
|
variable = this._variable();
|
|
|
- if (this._nodeExpression(expression, variable)) {
|
|
|
+ if (this._nodeExpression(item, variable)) {
|
|
|
outputs.push(variable.value);
|
|
|
continue;
|
|
|
}
|
|
|
- if (expression.type == 'id') {
|
|
|
- outputs.push(expression.value);
|
|
|
+ if (item.type == 'id') {
|
|
|
+ outputs.push(item.value);
|
|
|
continue;
|
|
|
}
|
|
|
return false;
|
|
|
@@ -1341,6 +1390,9 @@ torchscript.GraphContext = class {
|
|
|
while (args.length > 0) {
|
|
|
let argumentExpression = args[0];
|
|
|
argumentExpression = this._moduleTensor(argumentExpression);
|
|
|
+ if (this._isCall(argumentExpression, 'ops.prim.data', [ {} ])) {
|
|
|
+ argumentExpression = argumentExpression.arguments[0];
|
|
|
+ }
|
|
|
if (argumentExpression.type == 'id' &&
|
|
|
this._state[argumentExpression.value]) {
|
|
|
const valueExpression = this._state[argumentExpression.value];
|
|
|
@@ -1384,15 +1436,18 @@ torchscript.GraphContext = class {
|
|
|
if (argumentExpression.type == 'list') {
|
|
|
break;
|
|
|
}
|
|
|
- if (argumentExpression.type == 'number' || argumentExpression.type == 'string' || argumentExpression.type == 'boolean') {
|
|
|
+ if (argumentExpression.type === 'number' || argumentExpression.type == 'string' || argumentExpression.type == 'boolean') {
|
|
|
break;
|
|
|
}
|
|
|
- if (argumentExpression.type == '=') {
|
|
|
+ if (argumentExpression.type === '=') {
|
|
|
break;
|
|
|
}
|
|
|
if (this._isCall(argumentExpression, 'torch.list_with_default', [ {}, {} ])) {
|
|
|
break;
|
|
|
}
|
|
|
+ if (this._isCall(argumentExpression, 'torch.device', [ { type: 'string' } ])) {
|
|
|
+ break;
|
|
|
+ }
|
|
|
const variable = this._variable();
|
|
|
if (this._nodeExpression(argumentExpression, variable)) {
|
|
|
inputs.push([ { id: variable.value } ]);
|
|
|
@@ -1487,13 +1542,7 @@ torchscript.GraphContext = class {
|
|
|
return this._state[expression.value];
|
|
|
}
|
|
|
}
|
|
|
- if (this._isCall(expression, 'int', [ {} ])) {
|
|
|
- let replace = this._attributeExpression(expression.arguments[0]);
|
|
|
- if (replace) {
|
|
|
- return replace;
|
|
|
- }
|
|
|
- }
|
|
|
- return expression;
|
|
|
+ return this._evaluateExpression(expression);
|
|
|
}
|
|
|
|
|
|
_assignStatement(statement) {
|
|
|
@@ -1555,7 +1604,12 @@ torchscript.GraphContext = class {
|
|
|
}
|
|
|
// exponential_average_factor = 0.10000000000000001
|
|
|
if (expression.type === 'number') {
|
|
|
- this._state[target.value] = Number(expression.value);
|
|
|
+ this._state[target.value] = expression.value;
|
|
|
+ return true;
|
|
|
+ }
|
|
|
+ const valueExpression = this._evaluateExpression(expression);
|
|
|
+ if (valueExpression.type === 'number' || this._isBooleanLiteral(valueExpression)) {
|
|
|
+ this._state[target.value] = expression;
|
|
|
return true;
|
|
|
}
|
|
|
// _aux = None
|
|
|
@@ -1576,6 +1630,18 @@ torchscript.GraphContext = class {
|
|
|
this._moduleMap.set(moduleName, module);
|
|
|
return true;
|
|
|
}
|
|
|
+ // _14190 = __torch__.torchvision.models.inception.InceptionOutputs(x219, aux)
|
|
|
+ if (expression.type == 'call') {
|
|
|
+ const className = torchscript.Utility.target(expression.target);
|
|
|
+ if (this._classMap.has(className)) {
|
|
|
+ const tuple = this._classMap.get(className);
|
|
|
+ if (tuple && tuple.base && tuple.base.length > 0 &&
|
|
|
+ tuple.base[0].type === 'id' && tuple.base[0].value === 'NamedTuple') {
|
|
|
+ this._state[target.value] = { type: 'tuple', value: expression.arguments };
|
|
|
+ return true;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
}
|
|
|
return false;
|
|
|
}
|
|
|
@@ -1769,10 +1835,41 @@ torchscript.GraphContext = class {
|
|
|
return expression && expression.type === 'id' && (expression.value === 'True' || expression.value === 'False');
|
|
|
}
|
|
|
|
|
|
+ _evaluateExpression(expression) {
|
|
|
+ // _150.drop_rate
|
|
|
+ if (expression.type === '.') {
|
|
|
+ const module = this._getModule(expression.target);
|
|
|
+ if (module &&
|
|
|
+ expression.member.type === 'id' &&
|
|
|
+ Object.prototype.hasOwnProperty.call(module, expression.member.value)) {
|
|
|
+ const value = module[expression.member.value];
|
|
|
+ if (typeof value === 'number') {
|
|
|
+ return { type: 'number', value: value };
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ // int(x)
|
|
|
+ if (this._isCall(expression, 'int', [ {} ])) {
|
|
|
+ return this._evaluateExpression(expression.arguments[0]);
|
|
|
+ }
|
|
|
+ // float(x)
|
|
|
+ if (this._isCall(expression, 'float', [ {} ])) {
|
|
|
+ return this._evaluateExpression(expression.arguments[0]);
|
|
|
+ }
|
|
|
+ return expression;
|
|
|
+ }
|
|
|
+
|
|
|
_evaluateBooleanExpression(expression) {
|
|
|
// torch.eq("zeros", "circular"):
|
|
|
- if (this._isCall(expression, 'torch.eq', [ { type: 'string' }, { type: 'string' } ])) {
|
|
|
- return this._toBooleanLiteral(expression.arguments[0].value === expression.arguments[1].value);
|
|
|
+ if (this._isCall(expression, 'torch.eq', [ {}, {} ])) {
|
|
|
+ const left = this._evaluateExpression(expression.arguments[0]);
|
|
|
+ const right = this._evaluateExpression(expression.arguments[1]);
|
|
|
+ if (left.type === 'number' && right.type === 'number') {
|
|
|
+ return this._toBooleanLiteral(Number(left.value) === Number(right.value));
|
|
|
+ }
|
|
|
+ if (left.type === 'string' && right.type === 'string') {
|
|
|
+ return this._toBooleanLiteral(left.value === right.value);
|
|
|
+ }
|
|
|
}
|
|
|
// torch.eq(torch.dim(x4), 2):
|
|
|
if (this._isCall(expression, 'torch.eq', [ {}, { type: 'number' } ]) &&
|
|
|
@@ -1780,9 +1877,11 @@ torchscript.GraphContext = class {
|
|
|
return this._toBooleanLiteral(true); // TODO
|
|
|
}
|
|
|
// torch.ne(torch.dim(x4), 4):
|
|
|
- if (this._isCall(expression, 'torch.ne', [ {}, { type: 'number' } ]) &&
|
|
|
- this._isCall(expression.arguments[0], 'torch.dim', [ { type: 'id' } ])) {
|
|
|
- return this._toBooleanLiteral(false); // TODO
|
|
|
+ if (this._isCall(expression, 'torch.ne', [ {}, { type: 'number' } ])) {
|
|
|
+ if (this._isCall(expression.arguments[0], 'torch.dim', [ { type: 'id' } ]) ||
|
|
|
+ this._isCall(expression.arguments[0], 'torch.len', [ {} ])) {
|
|
|
+ return this._toBooleanLiteral(false); // TODO
|
|
|
+ }
|
|
|
}
|
|
|
// torch.__is__(None, None)
|
|
|
if (this._isCall(expression, 'torch.__is__', [ { type: 'id', value: 'None' }, { type: 'id', value: 'None' } ])) {
|
|
|
@@ -1804,7 +1903,10 @@ torchscript.GraphContext = class {
|
|
|
}
|
|
|
// torch.__isnot__(<id>, None)
|
|
|
if (this._isCall(expression, 'torch.__isnot__', [ { type: 'id' }, { type: 'id', value: 'None' } ])) {
|
|
|
- const argumentExpression = this._state[expression.arguments[0].value];
|
|
|
+ let argumentExpression = expression.arguments[0];
|
|
|
+ if (this._state[argumentExpression.value]) {
|
|
|
+ argumentExpression = this._state[argumentExpression.value];
|
|
|
+ }
|
|
|
if (argumentExpression) {
|
|
|
return this._toBooleanLiteral(argumentExpression.value !== 'None');
|
|
|
}
|
|
|
@@ -1818,11 +1920,15 @@ torchscript.GraphContext = class {
|
|
|
}
|
|
|
// torch.lt(0.5, 0.)
|
|
|
if (this._isCall(expression, 'torch.lt', [ { type: 'number' }, { type: 'number' } ])) {
|
|
|
- return this._toBooleanLiteral(Number(expression.arguments[0].value) < Number(expression.arguments[0].value));
|
|
|
+ return this._toBooleanLiteral(Number(expression.arguments[0].value) < Number(expression.arguments[1].value));
|
|
|
}
|
|
|
// torch.gt(0.5, 0.)
|
|
|
- if (this._isCall(expression, 'torch.gt', [ { type: 'number' }, { type: 'number' } ])) {
|
|
|
- return this._toBooleanLiteral(Number(expression.arguments[0].value) > Number(expression.arguments[0].value));
|
|
|
+ if (this._isCall(expression, 'torch.gt', [ {}, {} ])) {
|
|
|
+ const left = this._evaluateExpression(expression.arguments[0]);
|
|
|
+ const right = this._evaluateExpression(expression.arguments[1]);
|
|
|
+ if (left.type === 'number' && right.type === 'number') {
|
|
|
+ return this._toBooleanLiteral(Number(left.value) > Number(right.value));
|
|
|
+ }
|
|
|
}
|
|
|
// torch.__not__(...)
|
|
|
if (this._isCall(expression, 'torch.__not__', [ { type: 'id' } ])) {
|