Quellcode durchsuchen

TorchScript 1.4 prototype (#281)

Lutz Roeder vor 6 Jahren
Ursprung
Commit
e85ce03482
3 geänderte Dateien mit 529 neuen und 108 gelöschten Zeilen
  1. 387 24
      src/pytorch-metadata.json
  2. 120 83
      src/pytorch.js
  3. 22 1
      test/models.json

+ 387 - 24
src/pytorch-metadata.json

@@ -834,7 +834,8 @@
           "name": "weight"
         },
         {
-          "name": "bias"
+          "name": "bias",
+          "option": "optional"
         }
       ],
       "outputs": [
@@ -1552,6 +1553,32 @@
       ]
     }
   },
+  {
+    "name": "torch.feature_dropout",
+    "schema": {
+      "attributes": [
+        {
+          "name": "p",
+          "type": "float64"
+        },
+        {
+          "name": "train",
+          "type": "boolean"
+        }
+      ],
+      "category": "Dropout",
+      "inputs": [
+        {
+          "name": "input"
+        }
+      ],
+      "outputs": [
+        {
+          "name": "output"
+        }
+      ]
+    }
+  },
   {
     "name": "torch.dropout_",
     "schema": {
@@ -1661,8 +1688,15 @@
     }
   },
   {
-    "name": "torch.mul_",
+    "name": "torch.add:Tensor",
     "schema": {
+      "attributes": [
+        {
+          "default": 1,
+          "name": "alpha",
+          "type": "float64"
+        }
+      ],
       "inputs": [
         {
           "name": "input"
@@ -1679,9 +1713,12 @@
     }
   },
   {
-    "name": "torch.add",
+    "name": "torch.add:Scalar",
     "schema": {
       "attributes": [
+        {
+          "name": "other"
+        },
         {
           "default": 1,
           "name": "alpha",
@@ -1691,9 +1728,6 @@
       "inputs": [
         {
           "name": "input"
-        },
-        {
-          "name": "other"
         }
       ],
       "outputs": [
@@ -1782,24 +1816,6 @@
       ]
     }
   },
-  {
-    "name": "torch.mul",
-    "schema": {
-      "inputs": [
-        {
-          "name": "input"
-        },
-        {
-          "name": "other"
-        }
-      ],
-      "outputs": [
-        {
-          "name": "output"
-        }
-      ]
-    }
-  },
   {
     "name": "torch.matmul",
     "schema": {
@@ -1925,6 +1941,68 @@
       ]
     }
   },
+  {
+    "name": "torch.clone",
+    "schema": {
+      "inputs": [
+        {
+          "name": "input"
+        }
+      ],
+      "outputs": [
+        {
+          "name": "output"
+        }
+      ]
+    }
+  },
+  {
+    "name": "torch.expand",
+    "schema": {
+      "attributes": [
+        {
+          "default": false,
+          "name": "non_blocking",
+          "type": "boolean"
+        }
+      ],
+      "inputs": [
+        {
+          "name": "self"
+        }
+      ],
+      "outputs": [
+        {
+          "name": "output"
+        }
+      ]
+    }
+  },
+  {
+    "name": "torch.copy_",
+    "schema": {
+      "attributes": [
+        {
+          "default": false,
+          "name": "non_blocking",
+          "type": "boolean"
+        }
+      ],
+      "inputs": [
+        {
+          "name": "self"
+        },
+        {
+          "name": "src"
+        }
+      ],
+      "outputs": [
+        {
+          "name": "output"
+        }
+      ]
+    }
+  },
   {
     "name": "torch.reshape",
     "schema": {
@@ -2427,5 +2505,290 @@
         }
       ]
     }
+  },
+  {
+    "name": "torch.norm",
+    "schema": {
+      "attributes": [
+        {
+          "name": "dim"
+        },
+        {
+          "name": "p"
+        },
+        {
+          "default": false,
+          "name": "keepdim",
+          "type": "boolean"
+        }
+      ],
+      "inputs": [
+        {
+          "name": "input"
+        }
+      ],
+      "outputs": [
+        {
+          "name": "output"
+        }
+      ]
+    }
+  },
+  {
+    "name": "torch.elu_",
+    "schema": {
+      "attributes": [
+        {
+          "default": 1,
+          "name": "alpha"
+        },
+        {
+          "default": 1,
+          "name": "scale"
+        },
+        {
+          "default": 1,
+          "name": "input_scale"
+        }
+      ],
+      "category": "Activation",
+      "inputs": [
+        {
+          "name": "self"
+        }
+      ],
+      "outputs": [
+        {
+          "name": "output"
+        }
+      ]
+    }
+  },
+  {
+    "name": "torch.pixel_shuffle",
+    "schema": {
+      "attributes": [
+        {
+          "name": "upscale_factor",
+          "type": "int64"
+        }
+      ],
+      "inputs": [
+        {
+          "name": "input"
+        }
+      ],
+      "outputs": [
+        {
+          "name": "output"
+        }
+      ]
+    }
+  },
+  {
+    "name": "torch.nonzero",
+    "schema": {
+      "inputs": [
+        {
+          "name": "input"
+        }
+      ],
+      "outputs": [
+        {
+          "name": "output"
+        }
+      ]
+    }
+  },
+  {
+    "name": "torch.neg",
+    "schema": {
+      "inputs": [
+        {
+          "name": "input"
+        }
+      ],
+      "outputs": [
+        {
+          "name": "output"
+        }
+      ]
+    }
+  },
+  {
+    "name": "torch.grid_sampler",
+    "schema": {
+      "attributes": [
+        {
+          "name": "interpolation_mode",
+          "type": "int64"
+        },
+        {
+          "name": "padding_mode",
+          "type": "int64"
+        },
+        {
+          "name": "align_corners",
+          "type": "boolean"
+        }
+      ],
+      "inputs": [
+        {
+          "name": "input"
+        },
+        {
+          "name": "grid"
+        }
+      ],
+      "outputs": [
+        {
+          "name": "output"
+        }
+      ]
+    }
+  },
+  {
+    "name": "torch.sort",
+    "schema": {
+      "attributes": [
+        {
+          "default": -1,
+          "name": "dim",
+          "type": "int64"
+        },
+        {
+          "default": false,
+          "name": "descending",
+          "type": "boolean"
+        }
+      ],
+      "inputs": [
+        {
+          "name": "input"
+        }
+      ],
+      "outputs": [
+        {
+          "name": "values"
+        },
+        {
+          "name": "indices"
+        }
+      ]
+    }
+  },
+  {
+    "name": "torch.gt:Tensor",
+    "schema": {
+      "inputs": [
+        {
+          "name": "self"
+        },
+        {
+          "name": "other"
+        }
+      ],
+      "outputs": [
+        {
+          "name": "output"
+        }
+      ]
+    }
+  },
+  {
+    "name": "torch.gt:Scalar",
+    "schema": {
+      "attributes": [
+        {
+          "name": "other"
+        }
+      ],
+      "inputs": [
+        {
+          "name": "self"
+        }
+      ],
+      "outputs": [
+        {
+          "name": "output"
+        }
+      ]
+    }
+  },
+  {
+    "name": "torch.mul:Tensor",
+    "schema": {
+      "inputs": [
+        {
+          "name": "self"
+        },
+        {
+          "name": "other"
+        }
+      ],
+      "outputs": [
+        {
+          "name": "output"
+        }
+      ]
+    }
+  },
+  {
+    "name": "torch.mul:Scalar",
+    "schema": {
+      "attributes": [
+        {
+          "name": "other"
+        }
+      ],
+      "inputs": [
+        {
+          "name": "self"
+        }
+      ],
+      "outputs": [
+        {
+          "name": "output"
+        }
+      ]
+    }
+  },
+  {
+    "name": "torch.mul_:Tensor",
+    "schema": {
+      "inputs": [
+        {
+          "name": "self"
+        },
+        {
+          "name": "other"
+        }
+      ],
+      "outputs": [
+        {
+          "name": "output"
+        }
+      ]
+    }
+  },
+  {
+    "name": "torch.mul_:Scalar",
+    "schema": {
+      "attributes": [
+        {
+          "name": "other"
+        }
+      ],
+      "inputs": [
+        {
+          "name": "self"
+        }
+      ],
+      "outputs": [
+        {
+          "name": "output"
+        }
+      ]
+    }
   }
 ]

+ 120 - 83
src/pytorch.js

@@ -538,7 +538,8 @@ pytorch.Node = class {
     }
 
     get operator() {
-        return this._type;
+        const index = this._type.indexOf(':');
+        return index === -1 ? this._type : this._type.substring(0, index);
     }
 
     get category() {
@@ -960,15 +961,28 @@ pytorch.Metadata = class {
             if (items) {
                 for (const item of items) {
                     if (item.name && item.schema) {
+                        item.schema.name = item.name;
                         this._map.set(item.name, item.schema);
                     }
+                    const index = item.name.indexOf(':');
+                    if (index !== -1) {
+                        const name = item.name.substring(0, index);
+                        if (!this._map.has(name)) {
+                            this._map.set(name, [])
+                        }
+                        this._map.get(name).push(item.name);
+                    }
                 }
             }
         }
     }
 
     type(operator) {
-        return this._map.get(operator) || null;
+        const schema = this._map.get(operator);
+        if (schema) {
+            return Array.isArray(schema) ? schema.map((name) => this._map.get(name)) : schema;
+        }
+        return null;
     }
 
     attribute(operator, name) {
@@ -1466,6 +1480,12 @@ pytorch.Execution = class {
         this._registerFunction('ops.prim.RaiseException', function(message) {
             throw new pytorch.Error(message);
         });
+        this._registerFunction('torch.add', function(left, right) {
+            if (typeof left === 'number' && typeof right === 'number') {
+                return left * right;
+            }
+            throw new pytorch.Error('Unknown torch.add expression type.');
+        });
         this._registerFunction('torch.__is__', function(left, right) {
             if (left === null && right === null) {
                 return true;
@@ -1552,9 +1572,6 @@ pytorch.Execution = class {
             if (typeof left === 'number' && typeof right === 'number') {
                 return left * right;
             }
-            if (pytorch.Utility.isTensor(left) && pytorch.Utility.isTensor(right)) {
-                return { __module__: 'torch', __name__: 'Tensor', __origin__: 'torch.mul' };
-            }
             throw new pytorch.Error('Unknown torch.mul expression type.');
         });
         this._registerFunction('torch.ne', function(left, right) {
@@ -2863,12 +2880,18 @@ pytorch.Container.Zip = class {
             let args = [ this.data ]; // self
             if (this.data.forward.__code__ && this.data.forward.__code__.parameters) {
                 for (const parameter of this.data.forward.__code__.parameters) {
-                    if (parameter.name !== 'self' && 
-                        parameter.parameterType.type === 'type' &&
-                        parameter.parameterType.name.type === 'id' &&
-                        parameter.parameterType.name.value === 'Tensor') {
-                        this._inputs.push(parameter.name);
-                        args.push({ __module__: 'torch', __name__: 'Tensor', __variable__: parameter.name, __origin__: 'trace-input' });
+                    if (parameter.name !== 'self') {
+                        const type = parameter.parameterType;
+                        if (type.type === 'type' && type.name.type) {
+                            if (type.name.value === 'Tensor') {
+                                this._inputs.push(parameter.name);
+                                args.push({ __module__: 'torch', __name__: 'Tensor', __variable__: parameter.name, __origin__: 'trace-input' });
+                            }
+                            if (type.name.value === 'Tuple' && type.arguments.every((item) => item.type === 'type' && item.name.type === 'id' && item.name.value === 'Tensor')) {
+                                this._inputs.push(parameter.name);
+                                args.push(type.arguments.map(() => { return { __module__: 'torch', __name__: 'Tensor', __variable__: parameter.name, __origin__: 'trace-input' } }));
+                            }
+                        }
                     }
                 }
             }
@@ -2919,87 +2942,101 @@ pytorch.Container.Zip.Execution = class extends pytorch.Execution {
         let callArgs = Array.prototype.slice.call(args);
         if (callTarget) {
             const type = callTarget + '.' + name;
-            const schema = this._metadata.type(type);
-            if (schema) {
-                args = Array.prototype.slice.call(args);
-                let node = {
-                    type: type,
-                    inputs: [],
-                    attributes: [],
-                    outputs: []
-                };
-                const inputSchemas = Array.prototype.slice.call(schema.inputs);
-                while (inputSchemas.length > 0) {
-                    let inputSchema = inputSchemas.shift();
-                    const argument = this.expression(callArgs.shift(), context);
-                    while (inputSchema.option === 'optional' && Array.isArray(argument) && inputSchema.type !== 'T[]' && inputSchemas.length > 0) {
-                        inputSchema = inputSchemas.shift();
-                    }
-                    const parameters = Array.isArray(argument) ? argument : [ argument ];
-                    let inputs = [];
-                    for (let parameter of parameters) {
-                        if (parameter) {
-                            if (!pytorch.Utility.isTensor(parameter)) {
-                                if (typeof parameter !== 'number' && isNaN(parameter)) {
-                                    return super.call(target, name, args, context);
+            // ./aten/src/ATen/native/native_functions.yaml
+            let schemas = this._metadata.type(type);
+            if (schemas) {
+                if (!Array.isArray(schemas)) {
+                    schemas = [ schemas ]
+                }
+                for (const schema of schemas) {
+                    let node = {
+                        type: schema.name,
+                        inputs: [],
+                        attributes: [],
+                        outputs: []
+                    };
+                    let next = false;
+                    const inputSchemas = Array.prototype.slice.call(schema.inputs);
+                    while (inputSchemas.length > 0) {
+                        let inputSchema = inputSchemas.shift();
+                        const argument = this.expression(callArgs.shift(), context);
+                        while (inputSchema.option === 'optional' && Array.isArray(argument) && inputSchema.type !== 'T[]' && inputSchemas.length > 0) {
+                            inputSchema = inputSchemas.shift();
+                        }
+                        const parameters = Array.isArray(argument) ? argument : [ argument ];
+                        let inputs = [];
+                        for (let parameter of parameters) {
+                            if (parameter !== undefined) {
+                                if (!pytorch.Utility.isTensor(parameter) && parameter !== null) {
+                                    next = true;
+                                    break;
+                                }
+                                if (parameter === null) {
+                                    parameter = {};
+                                }
+                                if (parameter.__variable__) {
+                                    inputs.push({ id: parameter.__variable__ });
+                                }
+                                else {
+                                    const id = this._variable().value;
+                                    parameter.__variable__ = id;
+                                    parameter.__outputs__ = parameter.__outputs__ || [];
+                                    parameter.__outputs__.push(id);
+                                    inputs.push({ id: id });
                                 }
-                                parameter = {};
-                            }
-                            if (parameter.__variable__) {
-                                inputs.push({ id: parameter.__variable__ });
-                            }
-                            else {
-                                const id = this._variable().value;
-                                parameter.__variable__ = id;
-                                parameter.__outputs__ = parameter.__outputs__ || [];
-                                parameter.__outputs__.push(id);
-                                inputs.push({ id: id });
                             }
                         }
+                        if (next) {
+                            break;
+                        }
+                        node.inputs.push(inputs);
                     }
-                    node.inputs.push(inputs);
-                }
-                while (callArgs.length > 0 && callArgs[0].type !== '=') {
-                    const value = this.expression(callArgs.shift(), context);
-                    node.attributes.push(value);
-                }
-                while (callArgs.length > 0) {
-                    const arg = callArgs.shift();
-                    if (arg.type === '=' && arg.target && arg.target.type === 'id') {
-                        const value = this.expression(arg.expression, context);
-                        node.attributes.push({ type: '=', target: arg.target, expression: value });
+                    if (next) {
+                        callArgs = Array.prototype.slice.call(args);
+                        continue;
                     }
-                    else {
-                        throw new pytorch.Attribute('Expected named argument.');
-                    }
-                }
-                let outputs = []
-                for (let i = 0; i < schema.outputs.length; i++) {
-                    let parameter = { __module__: 'torch', __name__: 'Tensor', __origin__: 'invoke-output-' + type };
-                    switch (type) {
-                        case 'torch.cat': 
-                        case 'torch.conv2d': 
-                        case 'torch.flatten':
-                        case 'torch.quantize_per_tensor':
-                        case 'torch.relu_':
-                        case 'torch.dropout': {
-                            parameter.size = [ undefined, undefined, undefined, undefined ];
-                            break;
+                    while (callArgs.length > 0 && callArgs[0].type !== '=') {
+                        const value = this.expression(callArgs.shift(), context);
+                        node.attributes.push(value);
+                    }
+                    while (callArgs.length > 0) {
+                        const arg = callArgs.shift();
+                        if (arg.type === '=' && arg.target && arg.target.type === 'id') {
+                            const value = this.expression(arg.expression, context);
+                            node.attributes.push({ type: '=', target: arg.target, expression: value });
                         }
-                        case 'torch.embedding': {
-                            parameter.size = [ undefined, undefined, undefined ];
-                            break;
+                        else {
+                            throw new pytorch.Attribute('Expected named argument.');
                         }
                     }
-                    parameter.__variable__ = this._variable().value;
-                    outputs.push(parameter)
-                    node.outputs.push(parameter.__variable__);
-                }
-                this._nodes.push(node);
-                if (outputs.length > 1) {
-                    return outputs;
+                    let outputs = []
+                    for (let i = 0; i < schema.outputs.length; i++) {
+                        let parameter = { __module__: 'torch', __name__: 'Tensor', __origin__: 'invoke-output-' + type };
+                        switch (type) {
+                            case 'torch.cat': 
+                            case 'torch.conv2d': 
+                            case 'torch.flatten':
+                            case 'torch.quantize_per_tensor':
+                            case 'torch.relu_':
+                            case 'torch.dropout': {
+                                parameter.size = [ undefined, undefined, undefined, undefined ];
+                                break;
+                            }
+                            case 'torch.embedding': {
+                                parameter.size = [ undefined, undefined, undefined ];
+                                break;
+                            }
+                        }
+                        parameter.__variable__ = this._variable().value;
+                        outputs.push(parameter)
+                        node.outputs.push(parameter.__variable__);
+                    }
+                    this._nodes.push(node);
+                    if (outputs.length > 1) {
+                        return outputs;
+                    }
+                    return outputs[0];
                 }
-                return outputs[0];
             }
         }
         return super.call(target, name, args, context);

+ 22 - 1
test/models.json

@@ -3937,6 +3937,13 @@
     "format": "PyTorch v0.1.10",
     "source": "https://download.pytorch.org/models/densenet161-8d451a50.pth"
   },
+  {
+    "type":   "pytorch",
+    "target": "gcn2_tiny_320x240.pt",
+    "source": "https://github.com/jiexiong2016/GCNv2_SLAM/blob/master/GCN2/gcn2_tiny_320x240.pt?raw=true",
+    "format": "TorchScript v1.0",
+    "link":   "https://github.com/jiexiong2016/GCNv2_SLAM"
+  },
   {
     "type":   "pytorch",
     "target": "inception_v3.pkl.pth",
@@ -4230,6 +4237,13 @@
     "format": "PyTorch v0.1.1",
     "link":   "https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py"
   },
+  {
+    "type":   "pytorch",
+    "target": "semantic_lstm_vehicle_model.pt",
+    "source": "https://github.com/ApolloAuto/apollo/blob/master/modules/prediction/data/semantic_lstm_vehicle_model.pt?raw=true",
+    "format": "TorchScript v1.1",
+    "link":   "https://github.com/ApolloAuto/apollo"
+  },
   {
     "type":   "pytorch",
     "target": "SiamFC_50_model.pth",
@@ -4302,10 +4316,17 @@
   {
     "type":   "pytorch",
     "target": "superpoint_v1.pth",
-    "source": "https://raw.githubusercontent.com/MagicLeapResearch/SuperPointPretrainedNetwork/master/superpoint_v1.pth",
+    "source": "https://github.com/MagicLeapResearch/SuperPointPretrainedNetwork/blob/master/superpoint_v1.pth?raw=true",
     "format": "PyTorch v0.1.10",
     "link":   "https://github.com/MagicLeapResearch/SuperPointPretrainedNetwork"
   },
+  {
+    "type":   "pytorch",
+    "target": "superpoint.pt",
+    "source": "https://github.com/KinglittleQ/SuperPoint_SLAM/blob/master/superpoint.pt?raw=true",
+    "format": "TorchScript v1.0",
+    "link":   "https://github.com/KinglittleQ/SuperPoint_SLAM"
+  },
   {
     "type":   "pytorch",
     "target": "traced_online_lane_enc.pt",