Explorar o código

Add TorchScript test files (#281)

Lutz Roeder %!s(int64=6) %!d(string=hai) anos
pai
achega
c6f59e2b5d
Modificáronse 6 ficheiros con 319 adicións e 60 borrados
  1. 53 12
      src/torchscript-metadata.json
  2. 212 44
      src/torchscript.js
  3. 2 2
      src/view.js
  4. 43 0
      test/models.json
  5. 5 1
      test/test.js
  6. 4 1
      tools/pytorch-script.py

+ 53 - 12
src/torchscript-metadata.json

@@ -33,7 +33,7 @@
       ],
       "attributes": [
         { "name": "kernel_size", "type": "int64[]" },
-        { "name": "stride", "type": "int64[]" },
+        { "name": "stride", "type": "int64[]", "default": 2 },
         { "name": "padding", "type": "int64[]", "default": 0 },
         { "name": "dilation", "type": "int64[]", "default": 1 },
         { "name": "ceil_mode", "type": "boolean", "default": false }
@@ -93,7 +93,7 @@
         { "name": "input" }
       ],
       "attributes": [
-        { "name": "output_size", "type": "int64[]" }
+        { "name": "output_size", "type": "int64[]", "visible": false }
       ],
       "outputs": [
         { "name": "output" }
@@ -110,9 +110,9 @@
       "attributes": [
         { "name": "kernel_size", "type": "int64[]" },
         { "name": "stride", "type": "int64[]" },
-        { "name": "padding", "type": "int64[]" },
+        { "name": "padding", "type": "int64[]", "default": 0 },
         { "name": "ceil_mode", "type": "boolean", "default": false },
-        { "name": "count_include_pad", "type": "boolean" }
+        { "name": "count_include_pad", "type": "boolean", "default": true }
       ],
       "outputs": [
         { "name": "output" }
@@ -131,7 +131,7 @@
         { "name": "stride", "type": "int64[]" },
         { "name": "padding", "type": "int64[]" },
         { "name": "ceil_mode", "type": "boolean", "default": false },
-        { "name": "count_include_pad", "type": "boolean" }
+        { "name": "count_include_pad", "type": "boolean", "default": true }
       ],
       "outputs": [
         { "name": "output" }
@@ -150,10 +150,10 @@
         { "name": "running_var" }
       ],
       "attributes": [
-        { "name": "training", "type": "boolean", "visible": "false" },
+        { "name": "training", "type": "boolean", "visible": false },
         { "name": "momentum", "type": "float32", "default": 0.1 },
         { "name": "eps", "type": "float32", "default": 1e-05 },
-        { "name": "cudnn_enabled", "type": "bool", "visible": "false" }
+        { "name": "cudnn_enabled", "type": "boolean", "visible": false }
       ],
       "outputs": [
         { "name": "output" }
@@ -220,8 +220,8 @@
         { "name": "input" }
       ],
       "attributes": [
-        { "name": "threshold", "type": "float64" },
-        { "name": "value", "type": "float64" },
+        { "name": "threshold", "type": "float64", "default": 0 },
+        { "name": "value", "type": "float64", "default": 0 },
         { "name": "inplace", "type": "boolean", "default": false }
       ],
       "outputs": [
@@ -298,8 +298,9 @@
   {
     "name": "addmm",
     "schema": {
+      "category": "Layer",
       "inputs": [
-        { "name": "input" },
+        { "name": "mat" },
         { "name": "mat1" },
         { "name": "mat2" }
       ],
@@ -422,6 +423,20 @@
       ]
     }
   },
+  {
+    "name": "view",
+    "schema": {
+      "inputs": [
+        { "name": "input" }
+      ],
+      "attributes": [
+        { "name": "size", "type": "int64[]", "visible": false }
+      ],
+      "outputs": [
+        { "name": "output" }
+      ]
+    }
+  },
   {
     "name": "to",
     "schema": {
@@ -429,7 +444,7 @@
         { "name": "input" }
       ],
       "attributes": [
-        { "name": "dtype" },
+        { "name": "dtype", "visible": false },
         { "name": "non_blocking", "type": "boolean", "default": false },
         { "name": "copy", "type": "boolean", "default": false }
       ],
@@ -533,6 +548,17 @@
       ]
     }
   },
+  {
+    "name": "contiguous",
+    "schema": {
+      "inputs": [
+        { "name": "input" }
+      ],
+      "outputs": [
+        { "name": "output" }
+      ]
+    }
+  },
   {
     "name": "unsqueeze",
     "schema": {
@@ -548,12 +574,27 @@
       ]
     }
   },
+  {
+    "name": "max",
+    "schema": {
+      "inputs": [
+        { "name": "input" }
+      ],
+      "attributes": [
+        { "name": "dim_or_y", "type": "int64" },
+        { "name": "dim", "type": "boolean" }
+      ],
+      "outputs": [
+        { "name": "output" }
+      ]
+    }
+  },
   {
     "name": "cat",
     "schema": {
       "category": "Tensor",
       "inputs": [
-        { "name": "input" }
+        { "name": "inputs" }
       ],
       "attributes": [
         { "name": "dim", "type": "int64" }

+ 212 - 44
src/torchscript.js

@@ -36,7 +36,7 @@ torchscript.ModelFactory = class {
                     var message = error && error.message ? error.message : error.toString();
                     message = message.endsWith('.') ? message.substring(0, message.length - 1) : message;
                     throw new torchscript.Error(message + " in '" + identifier + "'.");
-                }    
+                }
             });
         }
         catch (error) {
@@ -107,6 +107,29 @@ torchscript.Graph = class {
 
         var context = new torchscript.GraphContext(container, mainModule);
 
+        container.parameters = {};
+        var queue = [ mainModule ];
+        while (queue.length > 0) {
+            var module = queue.shift();
+            if (module.parameters) {
+                for (var parameter of module.parameters) {
+                    if (parameter.tensorId) {
+                        var tensorId = parseInt(parameter.tensorId, 10);
+                        parameter.initializer = container.tensors[tensorId];
+                        if (parameter.outputs && parameter.outputs.length == 1) {
+                            container.parameters[parameter.outputs[0]] = parameter;
+                        }
+                    }
+                }
+            }
+            if (module.submodules) {
+                for (var submodule of module.submodules) {
+                    submodule.parent = module;
+                    queue.push(submodule);
+                }
+            }
+        }
+
         for (var input of context.inputs) {
             this._inputs.push(new torchscript.Argument(input, true, [
                 new torchscript.Connection(input, null, null)
@@ -119,21 +142,20 @@ torchscript.Graph = class {
         }
 
         for (var node of context.nodes) {
-            this._nodes.push(new torchscript.Node(metadata, container, '', null, node));
+            this._nodes.push(new torchscript.Node(metadata, container, null, node));
         }
 
-        this._loadModule(metadata, container, '', mainModule);
+        this._loadModule(metadata, container, mainModule);
     }
 
-    _loadModule(metadata, container, group, module) {
-        if (module.parameters && module.parameters.length > 0) {
-            var node = new torchscript.Node(metadata, container, group, module, null);
+    _loadModule(metadata, container, module) {
+        if (module.parameters && module.parameters.length > 0 && !module.hide) {
+            var node = new torchscript.Node(metadata, container, module, null);
             this._nodes.push(node);
         }
         if (module.submodules) {
-            var subgroup = group ? [ group, module.name ].join('.') : module.name;
             for (var submodule of module.submodules) {
-                this._loadModule(metadata, container, subgroup, submodule);
+                this._loadModule(metadata, container, submodule);
             }
         }
     }
@@ -210,20 +232,22 @@ torchscript.Connection = class {
 
 torchscript.Node = class {
 
-    constructor(metadata, container, group, module, node) {
+    constructor(metadata, container, module, node) {
         this._metadata = metadata;
         this._attributes = [];
         this._inputs = [];
         this._outputs = [];
 
+        var input = null;
+        var connection = null;
+        var parameter = null;
+
         if (module) {
             this._operator = 'Module';
-            this._name = group ? [ group, module.name ].join('.') : module.name;
             if (module.parameters) {
-                for (var parameter of module.parameters) {
-                    var tensorId = parseInt(parameter.tensorId, 10);
+                for (parameter of module.parameters) {
                     this._inputs.push(new torchscript.Argument(parameter.name, true, [
-                        new torchscript.Connection('', null, container.tensors[tensorId])
+                        new torchscript.Connection('', null, parameter.initializer || null)
                     ]));
                     if (parameter.outputs) {
                         this._outputs.push(new torchscript.Argument(parameter.name, true,
@@ -240,13 +264,49 @@ torchscript.Node = class {
 
             var schema = metadata.getSchema(this._operator);
 
+            module = null; 
+            var match = true;
+            var count = 0;
+            for (input of node.inputs) {
+                for (connection of input) {
+                    parameter = container.parameters[connection.id];
+                    if (parameter) {
+                        if (parameter.module && (module == null || module == parameter.module)) {
+                            module = parameter.module;
+                            count++;
+                        }
+                        else {
+                            match = false;
+                            break;
+                        }
+                    }
+                }
+                if (!match) {
+                    break;
+                }
+            }
+            if (module && module.parameters.length == count && match) {
+                module.hide = true;
+                for (input of node.inputs) {
+                    for (connection of input) {
+                        parameter = container.parameters[connection.id];
+                        if (parameter && parameter.initializer) {
+                            connection.initializer = parameter.initializer;
+                        }
+                    }
+                }
+            }
+            else {
+                module = null;
+            }
+
             for (var inputIndex = 0; inputIndex < node.inputs.length; inputIndex++) {
                 var inputName = inputIndex.toString(); 
                 if (schema && schema.inputs && schema.inputs.length > inputIndex) {
                     inputName = schema.inputs[inputIndex].name;
                 }
                 this._inputs.push(new torchscript.Argument(inputName, true,
-                    node.inputs[inputIndex].map((input) => new torchscript.Connection(input, null, null))
+                    node.inputs[inputIndex].map((input) => new torchscript.Connection(input.id, null, input.initializer || null))
                 ));
             }
 
@@ -280,6 +340,17 @@ torchscript.Node = class {
                 this._attributes.push(new torchscript.Attribute(this, attributeSchema, attributeName, attributeValue));
             }
         }
+        
+        if (module) {
+            if (module.name) {
+                var current = module;
+                this._name = current.name;
+                while (current.parent != null) {
+                    current = current.parent;
+                    this._name = [ current.name, this._name ].join('.')
+                }
+            }
+        }
     }
 
     get name() {
@@ -399,7 +470,13 @@ torchscript.Attribute = class {
                 case 'int32[]':
                 case 'int64[]':
                     if (this._value.type == 'list' && this._value.value.every((item) => item.type === 'number')) {
-                        this._value = this._value.value.map((item) => parseInt(item.value, 10));
+                        this._value = this._value.value.map((item) => {
+                            var number = parseInt(item.value, 10);
+                            if (!Number.isNaN(item.value - number)) {
+                                return number;
+                            }
+                            return item.value;
+                        });
                     }
                     break;
             }
@@ -411,6 +488,11 @@ torchscript.Attribute = class {
                 if (JSON.stringify(schema.default) == JSON.stringify(this._value)) {
                     this._visible = false;
                 }
+                else if (Array.isArray(this._value) && 
+                    !Array.isArray(schema.default) &&
+                    this.value.every((item) => item == schema.default)) {
+                    this._visible = false;
+                }
             }
         }
     }
@@ -438,6 +520,7 @@ torchscript.Tensor = class {
         this._type = new torchscript.TensorType(tensor.dataType, new torchscript.TensorShape(tensor.dims));
         var key = container.prefix + tensor.data.key;
         var entry = container.entries.find((entry) => entry.name == key);
+        this._name = tensor.data.key;
         this._data = entry.data;
         this._littleEndian = true;
     }
@@ -706,6 +789,7 @@ torchscript.GraphContext = class {
 
     constructor(container, mainModule) {
 
+        this._container = container;
         this._mainModule = mainModule;
 
         this._inputs = [];
@@ -714,6 +798,7 @@ torchscript.GraphContext = class {
 
         this._moduleMap = {};
         this._connectionMap = {};
+        this._numToTensorMap = {};
 
         if (mainModule.torchscriptArena && mainModule.torchscriptArena.key) {
             var codeKey = container.prefix + mainModule.torchscriptArena.key;
@@ -751,6 +836,9 @@ torchscript.GraphContext = class {
 
                     while (this._body.length > 0) {
                         var statement = this._body.shift();
+                        if (this._attributeStatement(statement)) {
+                            continue;
+                        }
                         if (this._moduleStatement(statement)) {
                             continue;
                         }
@@ -763,7 +851,7 @@ torchscript.GraphContext = class {
                         if (this._returnStatement(statement)) {
                             continue;
                         }
-                        debugger;
+                        throw new torchscript.Error("Unknown statement '" + JSON.stringify(statement) + "'.");
                     }
                 }
             }
@@ -839,9 +927,8 @@ torchscript.GraphContext = class {
     _nodeExpression(expression, target) {
         if (expression.type == 'call' && (target.type == 'identifier' || target.type == 'identifier_list')) {
             var name = this._name(expression.target);
-            var namespaces = [ 'torch.', 'ops.prim.' ];
-            var namespace = namespaces.find((n) => name.startsWith(n));
-            if (namespace) {
+            var namespace = 'torch.';
+            if (name.startsWith(namespace)) {
                 var node = {};
                 node.name = name.substring(namespace.length);
                 node.inputs = [];
@@ -856,7 +943,7 @@ torchscript.GraphContext = class {
                         delete this._connectionMap[argument.value];
                     }
                     if (argument.type == 'identifier') {
-                        node.inputs.push([ argument.value ]);
+                        node.inputs.push([ { id: argument.value } ]);
                         args.shift();
                         continue;
                     }
@@ -865,13 +952,13 @@ torchscript.GraphContext = class {
                         for (var input of argument.value) {
                             var variable = this._variable();
                             if (this._nodeExpression(input, variable)) {
-                                connections.push(variable.value);
+                                connections.push({ id: variable.value });
                             }
-                            if (this._connectionExpression(input, variable)) {
-                                connections.push(variable.value);
+                            else if (this._connectionExpression(input, variable)) {
+                                connections.push({ id: variable.value });
                             }
                             else if (input.type == 'identifier') {
-                                connections.push(input.value);
+                                connections.push({ id: input.value });
                             }
                             else {
                                 connections = null;
@@ -893,26 +980,41 @@ torchscript.GraphContext = class {
                     if (argument.type == '=') {
                         break;
                     }
-                    var variable = this._variable();
+                    variable = this._variable();
                     if (this._nodeExpression(argument, variable)) {
-                        node.inputs.push([ variable.value ]);
+                        node.inputs.push([ { id: variable.value } ]);
                         args.shift();
                         continue;
                     }
                     if (this._connectionExpression(argument, variable)) {
-                        node.inputs.push([ variable.value ]);
+                        node.inputs.push([ { id: variable.value } ]);
                         args.shift();
                         continue;
                     }
-                    // TODO CONSTANTS.cx
-                    if (argument.type == '.' && argument.target.type == 'identifier' && argument.target.value == 'CONSTANTS') {
-                        node.inputs.push([ JSON.stringify(args[0]) ]);
+                    if (argument.type == '.' &&
+                        argument.target.type == 'identifier' &&
+                        argument.target.value == 'CONSTANTS' &&
+                        argument.member.type == 'identifier' &&
+                        argument.member.value.startsWith('c')) {
+                        var constantId = [ argument.target.value, argument.member.value ].join('.');
+                        var constantIndex = parseInt(argument.member.value.substring(1), 10);
+                        var constantTensor = this._container.tensors[constantIndex];
+                        node.inputs.push([ { id: constantId, initializer: constantTensor } ]);
                         args.shift();
                         continue;
                     }
                     throw new torchscript.Error('Unknown function argument.');
                 }
                 while (args.length > 0) {
+                    if (args[0].type == 'list') {
+                        for (var i = 0; i < args[0].value.length; i++) {
+                            args[0].value[i] = this._attributeExpression(args[0].value[i]);
+                        }
+                    }
+                    var intExpression = this._attributeExpression(args[0]);
+                    if (intExpression) {
+                        args[0] = intExpression;
+                    }
                     node.attributes.push(args[0]);
                     args.shift();
                 }
@@ -940,6 +1042,71 @@ torchscript.GraphContext = class {
         return false;
     }
 
+    _attributeExpression(expression) {
+        if (expression.type == 'identifier') {
+            if (this._numToTensorMap[expression.value]) {
+                return { type: 'number', value: this._numToTensorMap[expression.value] };
+            }
+        }
+        if (expression.type == 'call' && 
+            expression.target.type == 'identifier' &&
+            expression.target.value == 'int' &&
+            expression.arguments.length == 1) 
+        {
+            var replace = this._attributeExpression(expression.arguments[0]);
+            if (replace) {
+                return replace;
+            }
+        }
+        return expression;
+    }
+
+    _attributeStatement(statement) {
+        if (statement.type == '=' &&
+            statement.target.type == 'identifier') { 
+            if (statement.expression.type == 'call' &&
+                this._name(statement.expression.target) == 'ops.prim.NumToTensor' && 
+                statement.expression.arguments.length == 1) {
+                var size = statement.expression.arguments[0];
+                if (size.type == 'call' &&
+                    size.arguments.length == 2 &&
+                    this._name(size.target) == 'torch.size' &&
+                    size.arguments[0].type == 'identifier' &&
+                    size.arguments[1].type == 'number') {
+                    this._numToTensorMap[statement.target.value] = this._name(size.target) + '(' + size.arguments.map((a) => a.value.toString()).join(',') + ')';;
+                    return true;
+                }
+                if (size.type == 'identifier') {
+                    var duplicate1 = this._numToTensorMap[size.value];
+                    if (duplicate1) {
+                        this._numToTensorMap[statement.target.value] = duplicate1;
+                        return true;
+                    }
+                }
+            }
+            if (statement.expression.type == 'call' &&
+                statement.expression.arguments.length == 2 &&
+                this._name(statement.expression.target) == 'torch.size' &&
+                statement.expression.arguments[0].type == 'identifier' &&
+                statement.expression.arguments[1].type == 'number') {
+                this._numToTensorMap[statement.target.value] = this._name(statement.expression.target) + '(' + statement.expression.arguments.map((a) => a.value.toString()).join(',') + ')';;
+                return true;
+            }
+            if (statement.expression.type == 'call' &&
+                statement.expression.target.type == 'identifier' &&
+                statement.expression.target.value == 'int' &&
+                statement.expression.arguments.length == 1 &&
+                statement.expression.arguments[0].type == 'identifier') {
+                var duplicate2 = this._numToTensorMap[statement.expression.arguments[0].value];
+                if (duplicate2) {
+                    this._numToTensorMap[statement.target.value] = duplicate2;
+                    return true;
+                }
+            }
+        }
+        return false;
+    }
+
     _module(expression) {
         var module;
         var submodule;
@@ -999,24 +1166,27 @@ torchscript.GraphContext = class {
     _connectionExpression(expression, target) {
         expression = this._moduleTensor(expression);
         if (expression.type === '.' && expression.member.type == 'identifier') {
-            var module = this._module(expression.target);
-            if (module && module.parameters) {
-                for (var parameter of module.parameters) {
+            var targetModule = this._module(expression.target);
+            if (targetModule && targetModule.parameters) {
+                for (var parameter of targetModule.parameters) {
+                    parameter.module = targetModule;
                     if (parameter.name === expression.member.value) {
                         parameter.outputs = parameter.outputs || [];
                         parameter.outputs.push(target.value);
                         return true;
                     }
                 }
-                module.unresolvedParameters = module.unresolvedParameters || [];
-                for (var unresolvedParameter of module.unresolvedParameters) {
+                targetModule.unresolvedParameters = targetModule.unresolvedParameters || [];
+                for (var unresolvedParameter of targetModule.unresolvedParameters) {
+                    unresolvedParameter.module = targetModule;
                     if (unresolvedParameter.name === expression.member.value) {
                         unresolvedParameter.outputs = unresolvedParameter.outputs || [];
                         unresolvedParameter.outputs.push(target.value);
                         return true;
                     }
                 }
-                module.unresolvedParameters.push({
+                targetModule.unresolvedParameters.push({
+                    module: targetModule,
                     name: expression.member.value,
                     outputs: [ target.value ]
                 });
@@ -1041,8 +1211,7 @@ torchscript.GraphContext = class {
     }
 
     _variable() {
-        var value = '_gen' + Math.random().toString(36).substring(7);
-        return { type: 'identifier', value: value };
+        return { type: 'identifier', value: '_gen' + Math.random().toString(36).substring(7) };
     }
 
     _name(expression) {
@@ -1052,15 +1221,14 @@ torchscript.GraphContext = class {
         if (expression.type == '.') {
             return [ this._name(expression.target), this._name(expression.member) ].join('.');
         }
-        throw new torchscript.Error('Failed to resolve name.');
+        throw new torchscript.Error("Failed to resolve name '" + JSON.stringify(expression) + "'.");
     }
 
     _moduleTensor(expression) {
-        if (expression.type == 'call' && expression.arguments.length == 1) {
-            var name = this._name(expression.target);
-            if (name == 'torch.t') {
-                return expression.arguments[0];
-            }
+        if (expression.type == 'call' && 
+            expression.arguments.length == 1 &&
+            this._name(expression.target) == 'torch.t') {
+            return expression.arguments[0];
         }
         return expression;
     }

+ 2 - 2
src/view.js

@@ -509,8 +509,8 @@ view.View = class {
                                     shape = '\u3008' + type.shape.dimensions.join('\u00D7') + '\u3009';
                                     if (type.shape.dimensions.length == 0 && connection.initializer && !connection.initializer.state) {
                                         shape = connection.initializer.toString();
-                                        if (shape && shape.length > 25) {
-                                            shape = shape.substring(0, 25) + '...';
+                                        if (shape && shape.length > 10) {
+                                            shape = shape.substring(0, 10) + '...';
                                         }
                                         separator = ' = ';
                                     }

+ 43 - 0
test/models.json

@@ -4528,6 +4528,14 @@
     "format": "TorchScript v1",
     "link":   "https://github.com/ApolloAuto/apollo"
   },
+  {
+    "type":   "torchscript",
+    "target": "densenet121.pt",
+    "link":   "https://pytorch.org/docs/stable/torchvision/models.html",
+    "render": "skip",
+    "script": [ "${root}/tools/pytorch", "sync install zoo" ],
+    "status": "script"
+  },
   {
     "type":   "torchscript",
     "target": "mnist_linear_torchscript.pt",
@@ -4604,5 +4612,40 @@
     "script": [ "${root}/tools/pytorch", "sync install zoo" ],
     "format": "TorchScript v1",
     "link":   "https://pytorch.org/docs/stable/torchvision/models.html"
+  },
+  {
+    "type":   "torchscript",
+    "target": "squeezenet1_1.pt",
+    "link":   "https://pytorch.org/docs/stable/torchvision/models.html",
+    "script": [ "${root}/tools/pytorch", "sync install zoo" ],
+    "status": "script"
+  },
+  {
+    "type":   "torchscript",
+    "target": "traced_online_lane_enc.pt",
+    "source": "https://github.com/ApolloAuto/apollo/blob/master/modules/prediction/data/traced_online_lane_enc.pt?raw=true",
+    "format": "TorchScript v1",
+    "link":   "https://github.com/ApolloAuto/apollo"
+  },
+  {
+    "type":   "torchscript",
+    "target": "traced_online_obs_enc.pt",
+    "source": "https://github.com/ApolloAuto/apollo/blob/master/modules/prediction/data/traced_online_obs_enc.pt?raw=true",
+    "format": "TorchScript v1",
+    "link":   "https://github.com/ApolloAuto/apollo"
+  },
+  {
+    "type":   "torchscript",
+    "target": "traced_online_pred_layer.pt",
+    "source": "https://github.com/ApolloAuto/apollo/blob/master/modules/prediction/data/traced_online_pred_layer.pt?raw=true",
+    "format": "TorchScript v1",
+    "link":   "https://github.com/ApolloAuto/apollo"
+  },
+  {
+    "type":   "torchscript",
+    "target": "vgg16.pt",
+    "link":   "https://pytorch.org/docs/stable/torchvision/models.html",
+    "script": [ "${root}/tools/pytorch", "sync install zoo" ],
+    "status": "script"
   }
 ]

+ 5 - 1
test/test.js

@@ -574,9 +574,13 @@ function next() {
         loadModel(folder + '/' + completed[0], item).then((model) => {
             if (item.render != 'skip') {
                 render(model).then(() => {
+                    if (item.error) {
+                       console.error('Expected error.');
+                       return;
+                    }
                     next();
                 }).catch((error) => {
-                    if (!item.error && item.error != error.message) {
+                    if (!item.error || item.error != error.message) {
                         console.error(err);
                     }
                     next();

+ 4 - 1
tools/pytorch-script.py

@@ -74,17 +74,20 @@ def zoo():
     download_pytorch_model('torchvision.models.densenet161', '${test}/data/pytorch/densenet161.pth')
     download_pytorch_model('torchvision.models.inception_v3', '${test}/data/pytorch/inception_v3.pth')
     download_pytorch_model('torchvision.models.mobilenet_v2', '${test}/data/pytorch/mobilenet_v2.pth')
-    download_pytorch_model('torchvision.models.resnet101', '${test}/data/pytorch/resnet101.pth')
     download_pytorch_model('torchvision.models.resnet18', '${test}/data/pytorch/resnet18.pth')
     download_pytorch_model('torchvision.models.resnet50', '${test}/data/pytorch/resnet50.pth')
+    download_pytorch_model('torchvision.models.resnet101', '${test}/data/pytorch/resnet101.pth')
     download_pytorch_model('torchvision.models.squeezenet1_0', '${test}/data/pytorch/squeezenet1_0.pth')
     download_pytorch_model('torchvision.models.vgg11_bn', '${test}/data/pytorch/vgg11_bn.pth')
     download_pytorch_model('torchvision.models.vgg16', '${test}/data/pytorch/vgg16.pth')
     download_torchscript_model('torchvision.models.alexnet', '${test}/data/torchscript/alexnet.pt', [ 1, 3, 299, 299 ])
+    download_torchscript_model('torchvision.models.densenet121', '${test}/data/torchscript/densenet121.pt', [ 1, 3, 224, 224 ])
     download_torchscript_model('torchvision.models.inception_v3', '${test}/data/torchscript/inception_v3.pt', [ 1, 3, 299, 299 ])
     download_torchscript_model('torchvision.models.mobilenet_v2', '${test}/data/torchscript/mobilenet_v2.pt', [ 1, 3, 224, 224 ])
     download_torchscript_model('torchvision.models.resnet18', '${test}/data/torchscript/resnet18.pt', [ 1, 3, 224, 224 ])
     download_torchscript_model('torchvision.models.resnet50', '${test}/data/torchscript/resnet50.pt', [ 1, 3, 224, 224 ])
+    download_torchscript_model('torchvision.models.squeezenet1_1', '${test}/data/torchscript/squeezenet1_1.pt', [ 1, 3, 224, 224 ])
+    download_torchscript_model('torchvision.models.vgg16', '${test}/data/torchscript/vgg16.pt', [ 1, 3, 224, 224 ])
 
 if __name__ == '__main__':
     command_table = { 'metadata': metadata, 'zoo': zoo }