소스 검색

TorchScript 1.3 prototype (#281)

Lutz Roeder 6 년 전
부모
커밋
c5b1e1d054
5개의 변경된 파일236개의 추가작업 그리고 35개의 파일을 삭제
  1. 5 0
      src/python.js
  2. 59 2
      src/torchscript-metadata.json
  3. 138 32
      src/torchscript.js
  4. 30 0
      test/models.json
  5. 4 1
      tools/pytorch-script.py

+ 5 - 0
src/python.js

@@ -682,6 +682,11 @@ python.Parser = class {
                     node.value = [ stack.pop() ];
                     stack.push(node);
                 }
+                // for, bar, = <expr>
+                if (this._tokenizer.peek().value === '=') {
+                    node.value.push({ type: 'id', value: '' });
+                    continue;
+                }
                 if (!this._tokenizer.match('=') && !terminalSet.has(this._tokenizer.peek().value)) {
                     let nextTerminal = terminal.slice(0).concat([ ',', '=' ]);
                     let expression = this._parseExpression(minPrecedence, nextTerminal, tuple);

+ 59 - 2
src/torchscript-metadata.json

@@ -132,7 +132,8 @@
         { "name": "stride", "type": "int64[]" },
         { "name": "padding", "type": "int64[]", "default": 0 },
         { "name": "ceil_mode", "type": "boolean", "default": false },
-        { "name": "count_include_pad", "type": "boolean", "default": true }
+        { "name": "count_include_pad", "type": "boolean", "default": true },
+        { "name": "divisor_override" }
       ],
       "outputs": [
         { "name": "output" }
@@ -151,7 +152,8 @@
         { "name": "stride", "type": "int64[]" },
         { "name": "padding", "type": "int64[]" },
         { "name": "ceil_mode", "type": "boolean", "default": false },
-        { "name": "count_include_pad", "type": "boolean", "default": true }
+        { "name": "count_include_pad", "type": "boolean", "default": true },
+        { "name": "divisor_override" }
       ],
       "outputs": [
         { "name": "output" }
@@ -421,6 +423,18 @@
       ]
     }
   },
+  {
+    "name": "sub",
+    "schema": {
+      "inputs": [
+        { "name": "input" },
+        { "name": "other" }
+      ],
+      "outputs": [
+        { "name": "output" }
+      ]
+    }
+  },
   {
     "name": "div",
     "schema": {
@@ -433,6 +447,18 @@
       ]
     }
   },
+  {
+    "name": "floordiv",
+    "schema": {
+      "inputs": [
+        { "name": "input" },
+        { "name": "other" }
+      ],
+      "outputs": [
+        { "name": "output" }
+      ]
+    }
+  },
   {
     "name": "sum",
     "schema": {
@@ -487,6 +513,21 @@
       ]
     }
   },
+  {
+    "name": "chunk",
+    "schema": {
+      "attributes": [
+        { "name": "chunks", "type": "int64" },
+        { "name": "dim", "type": "int64", "default": 0 }
+      ],
+      "inputs": [
+        { "name": "input" }
+      ],
+      "outputs": [
+        { "name": "output" }
+      ]
+    }
+  },
   {
     "name": "size",
     "schema": {
@@ -789,5 +830,21 @@
         { "name": "output" }
       ]
     }
+  },
+  {
+    "name": "quantize_per_tensor",
+    "schema": {
+      "attributes": [
+        { "name": "scale", "type": "float32" },
+        { "name": "zero_point", "type": "float32" },
+        { "name": "dtype" }
+      ],
+      "inputs": [
+        { "name": "input" }
+      ],
+      "outputs": [
+        { "name": "output" }
+      ]
+    }
   }
 ]

+ 138 - 32
src/torchscript.js

@@ -117,7 +117,7 @@ torchscript.ModelFactory = class {
             }
             return obj;
         };
-        functionTable['torch._utils._rebuild_tensor_v2'] = function (storage, storage_offset, size, stride, requires_grad, backward_hooks) {
+        functionTable['torch._utils._rebuild_tensor_v2'] = function(storage, storage_offset, size, stride, requires_grad, backward_hooks) {
             return {
                 __type__: storage.__type__.replace('Storage', 'Tensor'),
                 storage: storage,
@@ -128,6 +128,21 @@ torchscript.ModelFactory = class {
                 backward_hooks: backward_hooks
             };
         };
+        functionTable['torch._utils._rebuild_qtensor'] = function(storage, storage_offset, size, stride, quantizer_params, requires_grad, backward_hooks) {
+            return {
+                __type__: storage.__type__.replace('Storage', 'Tensor'),
+                storage: storage,
+                storage_offset: storage_offset,
+                size: size,
+                stride: stride,
+                quantizer_params: quantizer_params,
+                requires_grad:requires_grad,
+                backward_hooks: backward_hooks
+            };
+        }
+        functionTable['torch.jit._pickle.build_intlist'] = function(data) {
+            return data;
+        }
         let constructorTable = {};
         constructorTable['torch.ByteStorage'] = function (size) { 
             this.size = size; this.dataTypeSize = 1; this.dataType = 'uint8'; 
@@ -153,6 +168,9 @@ torchscript.ModelFactory = class {
         constructorTable['torch.DoubleStorage'] = function (size) { 
             this.size = size; this.dataTypeSize = 8; this.dataType = 'float64';
         };
+        constructorTable['torch.QInt8Storage'] = function (size) { 
+            this.size = size; this.dataTypeSize = 1; this.dataType = 'qint8';
+        };
         let function_call = (name, args) => {
             let func = functionTable[name];
             if (func) {
@@ -256,6 +274,7 @@ torchscript.Graph = class {
         let context = null;
         try {
             let script = '';
+            let namespaceName = null;
             let className = null;
             if (container.model && container.model.mainModule) {
                 mainModule = container.model.mainModule;
@@ -265,9 +284,10 @@ torchscript.Graph = class {
                 mainModule = container.data;
                 const typeName = mainModule.__type__.split('.');
                 className = typeName.pop();
+                namespaceName = typeName.join('.');
                 script = 'code/' + typeName.join('/') + '.py';
             }
-            context = new torchscript.GraphContext(container, python, mainModule, script, className);
+            context = new torchscript.GraphContext(container, python, mainModule, script, className, namespaceName);
         }
         catch (error) {
             let message = error && error.message ? error.message : error.toString();
@@ -900,6 +920,20 @@ torchscript.Tensor = class {
             context.state = 'Tensor has no data type.';
             return context;
         }
+        switch (this._type.dataType) {
+            case 'uint8':
+            case 'int8':
+            case 'int16':
+            case 'int32':
+            case 'int64':
+            case 'float16':
+            case 'float32':
+            case 'float64':
+                break;
+            default:
+                context.state = "Tensor data type '" + this._type.dataType + "' is not supported.";
+                return context;
+        }
         if (!this._type.shape) {
             context.state = 'Tensor has no dimensions.';
             return context;
@@ -1155,7 +1189,7 @@ torchscript.Utility = class {
 
 torchscript.GraphContext = class {
 
-    constructor(container, python, mainModule, script, className) {
+    constructor(container, python, mainModule, script, className, namespaceName) {
 
         this._container = container;
         this._mainModule = mainModule;
@@ -1165,6 +1199,7 @@ torchscript.GraphContext = class {
         this._nodes = [];
 
         this._moduleMap = new Map();
+        this._classMap = new Map();
         this._state = {};
 
         if (script) {
@@ -1178,8 +1213,18 @@ torchscript.GraphContext = class {
                 let program = reader.parse();
                 let statements = program.body;
                 if (className) {
-                    let block = statements.find((statment) => statment.type == 'class' && statment.name == className);
-                    statements = block.body.statements;
+                    let main = null;
+                    for (let statement of statements) {
+                        if (statement.type == 'class') {
+                            if (namespaceName) {
+                                this._classMap.set(namespaceName + '.' + statement.name, statement);
+                            }
+                            if (statement.name == className) {
+                                main = statement;
+                            }
+                        }
+                    }
+                    statements = main.body.statements;
                 }
                 let method = statements.find((statement) => statement.type == 'def' && statement.name == 'forward');
                 if (method) {
@@ -1301,24 +1346,28 @@ torchscript.GraphContext = class {
     _returnStatement(statement) {
         if (statement.type == 'return') {
             let variable = this._variable();
-            if (this._nodeExpression(statement.expression, variable)) {
+            let expression = statement.expression;
+            if (this._nodeExpression(expression, variable)) {
                 this._outputs.push(variable.value);
                 return true;
             }
-            if (statement.expression.type == 'id') {
-                this._outputs.push(statement.expression.value);
+            if (expression.type == 'id' && this._state[expression.value] && this._state[expression.value].type === 'tuple' ) {
+                expression = this._state[expression.value];
+            }
+            if (expression.type == 'id') {
+                this._outputs.push(expression.value);
                 return true;
             }
-            if (statement.expression.type == 'tuple') {
+            if (expression.type == 'tuple') {
                 let outputs = [];
-                for (let expression of statement.expression.value) {
+                for (let item of expression.value) {
                     variable = this._variable();
-                    if (this._nodeExpression(expression, variable)) {
+                    if (this._nodeExpression(item, variable)) {
                         outputs.push(variable.value);
                         continue;
                     }
-                    if (expression.type == 'id') {
-                        outputs.push(expression.value);
+                    if (item.type == 'id') {
+                        outputs.push(item.value);
                         continue;
                     }
                     return false;
@@ -1341,6 +1390,9 @@ torchscript.GraphContext = class {
                 while (args.length > 0) {
                     let argumentExpression = args[0];
                     argumentExpression = this._moduleTensor(argumentExpression);
+                    if (this._isCall(argumentExpression, 'ops.prim.data', [ {} ])) {
+                        argumentExpression = argumentExpression.arguments[0];
+                    }
                     if (argumentExpression.type == 'id' &&
                         this._state[argumentExpression.value]) {
                         const valueExpression = this._state[argumentExpression.value];
@@ -1384,15 +1436,18 @@ torchscript.GraphContext = class {
                     if (argumentExpression.type == 'list') {
                         break;
                     }
-                    if (argumentExpression.type == 'number' || argumentExpression.type == 'string' || argumentExpression.type == 'boolean') {
+                    if (argumentExpression.type === 'number' || argumentExpression.type == 'string' || argumentExpression.type == 'boolean') {
                         break;
                     }
-                    if (argumentExpression.type == '=') {
+                    if (argumentExpression.type === '=') {
                         break;
                     }
                     if (this._isCall(argumentExpression, 'torch.list_with_default', [ {}, {} ])) {
                         break;
                     }
+                    if (this._isCall(argumentExpression, 'torch.device', [ { type: 'string' } ])) {
+                        break;
+                    }
                     const variable = this._variable();
                     if (this._nodeExpression(argumentExpression, variable)) {
                         inputs.push([ { id: variable.value } ]);
@@ -1487,13 +1542,7 @@ torchscript.GraphContext = class {
                 return this._state[expression.value];
             }
         }
-        if (this._isCall(expression, 'int', [ {} ])) {
-            let replace = this._attributeExpression(expression.arguments[0]);
-            if (replace) {
-                return replace;
-            }
-        }
-        return expression;
+        return this._evaluateExpression(expression);
     }
 
     _assignStatement(statement) {
@@ -1555,7 +1604,12 @@ torchscript.GraphContext = class {
             }
             // exponential_average_factor = 0.10000000000000001
             if (expression.type === 'number') {
-                this._state[target.value] = Number(expression.value);
+                this._state[target.value] = expression.value;
+                return true;
+            }
+            const valueExpression = this._evaluateExpression(expression);
+            if (valueExpression.type === 'number' || this._isBooleanLiteral(valueExpression)) {
+                this._state[target.value] = expression;
                 return true;
             }
             // _aux = None
@@ -1576,6 +1630,18 @@ torchscript.GraphContext = class {
                 this._moduleMap.set(moduleName, module);
                 return true;
             }
+            // _14190 = __torch__.torchvision.models.inception.InceptionOutputs(x219, aux)
+            if (expression.type == 'call') {
+                const className = torchscript.Utility.target(expression.target);
+                if (this._classMap.has(className)) {
+                    const tuple = this._classMap.get(className);
+                    if (tuple && tuple.base && tuple.base.length > 0 &&
+                        tuple.base[0].type === 'id' && tuple.base[0].value === 'NamedTuple') {
+                        this._state[target.value] = { type: 'tuple', value: expression.arguments };
+                        return true;
+                    }
+                }
+            }
         }
         return false;
     }
@@ -1769,10 +1835,41 @@ torchscript.GraphContext = class {
         return expression && expression.type === 'id' && (expression.value === 'True' || expression.value === 'False');
     }
 
+    _evaluateExpression(expression) {
+        // _150.drop_rate
+        if (expression.type === '.') {
+            const module = this._getModule(expression.target);
+            if (module &&
+                expression.member.type === 'id' &&
+                Object.prototype.hasOwnProperty.call(module, expression.member.value)) {
+                const value = module[expression.member.value];
+                if (typeof value === 'number') {
+                    return { type: 'number', value: value };
+                }
+            }
+        }
+        // int(x)
+        if (this._isCall(expression, 'int', [ {} ])) {
+            return this._evaluateExpression(expression.arguments[0]);
+        }
+        // float(x)
+        if (this._isCall(expression, 'float', [ {} ])) {
+            return this._evaluateExpression(expression.arguments[0]);
+        }
+        return expression;
+    }
+
     _evaluateBooleanExpression(expression) {
         // torch.eq("zeros", "circular"):
-        if (this._isCall(expression, 'torch.eq', [ { type: 'string' }, { type: 'string' } ])) {
-            return this._toBooleanLiteral(expression.arguments[0].value === expression.arguments[1].value);
+        if (this._isCall(expression, 'torch.eq', [ {}, {} ])) {
+            const left = this._evaluateExpression(expression.arguments[0]);
+            const right = this._evaluateExpression(expression.arguments[1]);
+            if (left.type === 'number' && right.type === 'number') {
+                return this._toBooleanLiteral(Number(left.value) === Number(right.value));
+            }
+            if (left.type === 'string' && right.type === 'string') {
+                return this._toBooleanLiteral(left.value === right.value);
+            }
         }
         // torch.eq(torch.dim(x4), 2):
         if (this._isCall(expression, 'torch.eq', [ {}, { type: 'number' } ]) &&
@@ -1780,9 +1877,11 @@ torchscript.GraphContext = class {
             return this._toBooleanLiteral(true); // TODO
         }
         // torch.ne(torch.dim(x4), 4):
-        if (this._isCall(expression, 'torch.ne', [ {}, { type: 'number' } ]) &&
-            this._isCall(expression.arguments[0], 'torch.dim', [ { type: 'id' } ])) {
-            return this._toBooleanLiteral(false); // TODO
+        if (this._isCall(expression, 'torch.ne', [ {}, { type: 'number' } ])) {
+            if (this._isCall(expression.arguments[0], 'torch.dim', [ { type: 'id' } ]) ||
+                this._isCall(expression.arguments[0], 'torch.len', [ {} ])) {
+                return this._toBooleanLiteral(false); // TODO
+            }
         }
         // torch.__is__(None, None)
         if (this._isCall(expression, 'torch.__is__', [ { type: 'id', value: 'None' }, { type: 'id', value: 'None' } ])) {
@@ -1804,7 +1903,10 @@ torchscript.GraphContext = class {
         }
         // torch.__isnot__(<id>, None)
         if (this._isCall(expression, 'torch.__isnot__', [ { type: 'id' }, { type: 'id', value: 'None' } ])) {
-            const argumentExpression = this._state[expression.arguments[0].value];
+            let argumentExpression = expression.arguments[0];
+            if (this._state[argumentExpression.value]) {
+                argumentExpression = this._state[argumentExpression.value];
+            }
             if (argumentExpression) {
                 return this._toBooleanLiteral(argumentExpression.value !== 'None');
             }
@@ -1818,11 +1920,15 @@ torchscript.GraphContext = class {
         }
         // torch.lt(0.5, 0.)
         if (this._isCall(expression, 'torch.lt', [ { type: 'number' }, { type: 'number' } ])) {
-            return this._toBooleanLiteral(Number(expression.arguments[0].value) < Number(expression.arguments[0].value));
+            return this._toBooleanLiteral(Number(expression.arguments[0].value) < Number(expression.arguments[1].value));
         }
         // torch.gt(0.5, 0.)
-        if (this._isCall(expression, 'torch.gt', [ { type: 'number' }, { type: 'number' } ])) {
-            return this._toBooleanLiteral(Number(expression.arguments[0].value) > Number(expression.arguments[0].value));
+        if (this._isCall(expression, 'torch.gt', [ {}, {} ])) {
+            const left = this._evaluateExpression(expression.arguments[0]);
+            const right = this._evaluateExpression(expression.arguments[1]);
+            if (left.type === 'number' && right.type === 'number') {
+                return this._toBooleanLiteral(Number(left.value) > Number(right.value));
+            }
         }
         // torch.__not__(...)
         if (this._isCall(expression, 'torch.__not__', [ { type: 'id' } ])) {

+ 30 - 0
test/models.json

@@ -5046,6 +5046,14 @@
     "format": "TorchScript v1",
     "link":   "https://github.com/ApolloAuto/apollo"
   },
+  {
+    "type":   "torchscript",
+    "target": "densenet121.pt",
+    "link":   "https://pytorch.org/docs/stable/torchvision/models.html",
+    "render": "skip",
+    "script": [ "${root}/tools/pytorch", "sync install zoo" ],
+    "status": "script"
+  },
   {
     "type":   "torchscript",
     "target": "densenet121_traced.pt",
@@ -5068,6 +5076,13 @@
     "format": "TorchScript v1",
     "link":   "https://github.com/lutzroeder/netron/issues/281"
   },
+  {
+    "type":   "torchscript",
+    "target": "inception_v3.pt",
+    "script": [ "${root}/tools/pytorch", "sync install zoo" ],
+    "format": "TorchScript v1",
+    "link":   "https://pytorch.org/docs/stable/torchvision/models.html"
+  },
   {
     "type":   "torchscript",
     "target": "inception_v3_traced.pt",
@@ -5173,6 +5188,21 @@
     "format": "TorchScript v1",
     "link":   "https://pytorch.org/docs/stable/torchvision/models.html"
   },
+  {
+    "type":   "torchscript",
+    "target": "shufflenet_v2_x1_0.pt",
+    "link":   "https://pytorch.org/docs/stable/torchvision/models.html",
+    "render": "skip",
+    "script": [ "${root}/tools/pytorch", "sync install zoo" ],
+    "status": "script"
+  },
+  {
+    "type":   "torchscript",
+    "target": "squeezenet1_1.pt",
+    "link":   "https://pytorch.org/docs/stable/torchvision/models.html",
+    "script": [ "${root}/tools/pytorch", "sync install zoo" ],
+    "status": "script"
+  },
   {
     "type":   "torchscript",
     "target": "squeezenet1_1_traced.pt",

+ 4 - 1
tools/pytorch-script.py

@@ -92,11 +92,14 @@ def zoo():
     download_pytorch_model('torchvision.models.vgg11_bn', '${test}/data/pytorch/vgg11_bn.pth')
     download_pytorch_model('torchvision.models.vgg16', '${test}/data/pytorch/vgg16.pth')
     download_torchscript_model('torchvision.models.alexnet', '${test}/data/torchscript/alexnet.pt')
+    download_torchscript_model('torchvision.models.densenet121', '${test}/data/torchscript/densenet121.pt')
     download_torchscript_model('torchvision.models.inception_v3', '${test}/data/torchscript/inception_v3.pt')
+    download_torchscript_model('torchvision.models.mobilenet_v2', '${test}/data/torchscript/mobilenet_v2.pt')
+    download_torchscript_model('torchvision.models.mnasnet1_0', '${test}/data/torchscript/mnasnet1_0.pt')
     download_torchscript_model('torchvision.models.resnet18', '${test}/data/torchscript/resnet18.pt')
     download_torchscript_model('torchvision.models.resnet50', '${test}/data/torchscript/resnet50.pt')
+    download_torchscript_model('torchvision.models.shufflenet_v2_x1_0', '${test}/data/torchscript/shufflenet_v2_x1_0.pt')
     download_torchscript_model('torchvision.models.squeezenet1_1', '${test}/data/torchscript/squeezenet1_1.pt')
-    download_torchscript_model('torchvision.models.mobilenet_v2', '${test}/data/torchscript/mobilenet_v2.pt')
     download_torchscript_model('torchvision.models.vgg16', '${test}/data/torchscript/vgg16.pt')
     download_torchscript_traced_model('torchvision.models.alexnet', '${test}/data/torchscript/alexnet_traced.pt', [ 1, 3, 299, 299 ])
     download_torchscript_traced_model('torchvision.models.densenet121', '${test}/data/torchscript/densenet121_traced.pt', [ 1, 3, 224, 224 ])