Lutz Roeder 6 лет назад
Родитель
Сommit
e2b8f5933a
3 измененных файлов с 86 добавлено и 21 удалено
  1. 68 10
      src/pytorch-metadata.json
  2. 17 10
      src/pytorch.js
  3. 1 1
      test/models.json

+ 68 - 10
src/pytorch-metadata.json

@@ -844,6 +844,73 @@
       ]
     }
   },
+  {
+    "name": "ops.quantized.add",
+    "schema": {
+      "attributes": [
+        {
+          "name": "scale"
+        },
+        {
+          "name": "zero_point"
+        }
+      ],
+      "inputs": [
+        {
+          "name": "A"
+        },
+        {
+          "name": "B"
+        }
+      ],
+      "outputs": [
+        {
+          "name": "C"
+        }
+      ]
+    }
+  },
+  {
+    "name": "ops.quantized.linear",
+    "schema": {
+      "attributes": [
+        {
+          "name": "scale"
+        },
+        {
+          "name": "zero_point"
+        }
+      ],
+      "inputs": [
+        {
+          "name": "input"
+        },
+        {
+          "name": "packed_params"
+        }
+      ],
+      "outputs": [
+        {
+          "name": "output"
+        }
+      ]
+    }
+  },
+  {
+    "name": "torch.dequantize",
+    "schema": {
+      "inputs": [
+        {
+          "name": "input"
+        }
+      ],
+      "outputs": [
+        {
+          "name": "output"
+        }
+      ]
+    }
+  },
   {
     "name": "ops.quantized.conv2d",
     "schema": {
@@ -875,10 +942,7 @@
           "name": "input"
         },
         {
-          "name": "weight"
-        },
-        {
-          "name": "bias"
+          "name": "packed_params"
         }
       ],
       "outputs": [
@@ -937,12 +1001,6 @@
       ]
     }
   },
-  {
-    "name": "ops.quantized.linear",
-    "schema": {
-      "category": "Layer"
-    }
-  },
   {
     "name": "torch.max_pool2d",
     "schema": {

+ 17 - 10
src/pytorch.js

@@ -89,7 +89,7 @@ pytorch.Graph = class {
                                 if (pytorch.Utility.isTensor(obj)) {
                                     let parameter = obj;
                                     parameter.__parent__ = module;
-                                    if (!parameter.initializer) {
+                                    if (!parameter.initializer && parameter.storage) {
                                         parameter.initializer = new pytorch.Tensor(parameter.name, parameter, true);
                                     }
                                     if (parameter.__outputs__ && parameter.__outputs__.length == 1) {
@@ -1463,17 +1463,17 @@ pytorch.Execution = class {
         this._registerFunction('ops.prim.NumToTensor', function(value) {
             return { __module__: 'torch', __name__: 'Tensor', value: value }; // TODO
         });
-        this._registerFunction('ops.prim.shape', function(/* value */) {
-            return undefined; // TODO
+        this._registerFunction('ops.prim.shape', function(value) {
+            return value.size;
         });
         this._registerFunction('ops.quantized.conv_prepack', function(/* weight, bias, stride, padding, dilation, groups */) {
-            return { __module__: 'torch', __name__: '__conv_prepack__' }; // TODO
+            return { __module__: 'torch', __name__: 'Tensor', __origin__: 'ops.quantized.conv_prepack' }; // TODO
         });
         this._registerFunction('ops.quantized.conv2d_prepack', function(/* weight, bias, stride, padding, dilation, groups */) {
-            return { __module__: 'torch', __name__: '__conv2d_prepack__' }; // TODO
+            return { __module__: 'torch', __name__: 'Tensor', __origin__: 'ops.quantized.conv2d_prepack' }; // TODO
         });
         this._registerFunction('ops.quantized.linear_prepack', function(/* weight, bias */) {
-            return { __module__: 'torch', __name__: '__linear_prepack__' }; // TODO
+            return { __module__: 'torch', __name__: 'Tensor', __origin__: 'ops.quantized.linear_prepack' }; // TODO
         });
 
         this._registerFunction('ops.prim.RaiseException', function(message) {
@@ -1546,7 +1546,10 @@ pytorch.Execution = class {
         this._registerFunction('torch.jit._pickle.build_tensorlist', function(data) {
             return data;
         });
-        this._registerFunction('torch.len', function(/* value */) {
+        this._registerFunction('torch.len', function(value) {
+            if (value) {
+                return value.length;
+            }
             return undefined;
         });
         this._registerFunction('torch.list_with_default', function(size /*, defaults */) {
@@ -1556,7 +1559,7 @@ pytorch.Execution = class {
             if (typeof left === 'number' && typeof right === 'number') {
                 return left < right;
             }
-            throw new pytorch.Error('Unknown expression type.');
+            throw new pytorch.Error('Unknown torch.lt expression type.');
         });
         this._registerFunction('torch.mul', function(left, right) {
             if (typeof left === 'number' && typeof right === 'number') {
@@ -1565,13 +1568,16 @@ pytorch.Execution = class {
             if (pytorch.Utility.isTensor(left) && pytorch.Utility.isTensor(right)) {
                 return { __module__: 'torch', __name__: 'Tensor', __origin__: 'torch.mul' };
             }
-            throw new pytorch.Error('Unknown expression type.');
+            throw new pytorch.Error('Unknown torch.mul expression type.');
         });
         this._registerFunction('torch.ne', function(left, right) {
             if (typeof left === 'number' && typeof right === 'number') {
                 return left !== right;
             }
-            throw new pytorch.Error('Unknown expression type.');
+            if (left === undefined && typeof right === 'number') {
+                return false;
+            }
+            throw new pytorch.Error('Unknown torch.ne expression type.');
         });
         this._registerFunction('torch.q_scale', function(/* tensor */) {
             return -1; // TODO
@@ -2829,6 +2835,7 @@ pytorch.Container.Zip = class {
                     case 'torch.cat': 
                     case 'torch.conv2d': 
                     case 'torch.flatten':
+                    case 'torch.quantize_per_tensor':
                     case 'torch.relu_':
                     case 'torch.dropout': {
                         parameter.size = [ undefined, undefined, undefined, undefined ];

+ 1 - 1
test/models.json

@@ -3971,7 +3971,7 @@
     "type":   "pytorch",
     "target": "inception_v3.pt",
     "script": "./tools/pytorch sync install zoo",
-    "error":  "Unknown expression type in 'inception_v3.pt'.",
+    "error":  "Unknown torch.mul expression type in 'inception_v3.pt'.",
     "format": "TorchScript v1.4",
     "link":   "https://pytorch.org/docs/stable/torchvision/models.html"
   },