Selaa lähdekoodia

Add Torch prototype (#200)

Lutz Roeder 7 vuotta sitten
vanhempi
sitoutus
b4f44f4329
9 muutettua tiedostoa jossa 1039 lisäystä ja 6 poistoa
  1. 2 0
      electron-builder.yml
  2. 1 0
      setup.py
  3. 1 1
      src/app.js
  4. 4 2
      src/pytorch-metadata.json
  5. 150 0
      src/torch-metadata.json
  6. 793 0
      src/torch.js
  7. 1 1
      src/view-browser.html
  8. 3 2
      src/view.js
  9. 84 0
      test/models.json

+ 2 - 0
electron-builder.yml

@@ -52,6 +52,8 @@ fileAssociations:
     ext: prototxt
   - name: "PyTorch Model"
     ext: pth
+  - name: "Torch Model"
+    ext: t7
 publish:
   - provider: github
     releaseType: release

+ 1 - 0
setup.py

@@ -121,6 +121,7 @@ setuptools.setup(
             'sklearn.js', 'sklearn-metadata.json',
             'tf.js', 'tf-metadata.json', 'tf-proto.js', 
             'tflite.js', 'tflite-metadata.json', 'tflite-schema.js', 
+            'torch.js', 'torch-metadata.json',
             'view-browser.html', 'view-browser.js',
             'view-grapher.css', 'view-grapher.js',
             'view-sidebar.css', 'view-sidebar.js',

+ 1 - 1
src/app.js

@@ -112,7 +112,7 @@ class Application {
         var showOpenDialogOptions = { 
             properties: [ 'openFile' ], 
             filters: [
-                { name: 'All Model Files',  extensions: [ 'onnx', 'pb', 'h5', 'hdf5', 'json', 'keras', 'mlmodel', 'caffemodel', 'model', 'meta', 'tflite', 'lite', 'pt', 'pth', 'pkl', 'joblib', 'pbtxt', 'prototxt', 'xml', 'dot' ] }
+                { name: 'All Model Files',  extensions: [ 'onnx', 'pb', 'h5', 'hdf5', 'json', 'keras', 'mlmodel', 'caffemodel', 'model', 'meta', 'tflite', 'lite', 'pt', 'pth', 't7', 'pkl', 'joblib', 'pbtxt', 'prototxt', 'xml', 'dot' ] }
                 /* 
                 { name: 'ONNX Model', extensions: [ 'onnx', 'pb', 'pbtxt' ] },
                 { name: 'Keras Model', extensions: [ 'h5', 'hdf5', 'json', 'keras' ] },

+ 4 - 2
src/pytorch-metadata.json

@@ -408,7 +408,8 @@
       "attributes": [
         {
           "default": false,
-          "name": "inplace"
+          "name": "inplace",
+          "visible": false
         },
         {
           "default": 0.5,
@@ -425,7 +426,8 @@
       "attributes": [
         {
           "default": false,
-          "name": "inplace"
+          "name": "inplace",
+          "visible": false
         },
         {
           "default": 0.5,

+ 150 - 0
src/torch-metadata.json

@@ -0,0 +1,150 @@
+[
+  {
+    "name": "Linear",
+    "schema": {
+      "category": "Layer"
+    }
+  },
+  {
+    "name": "SpatialConvolution",
+    "schema": {
+      "category": "Layer",
+      "attributes": [
+        { "name": "benchmarked", "visible": false },
+        { "name": "input_offset", "visible": false },
+        { "name": "output_offset", "visible": false },
+        { "name": "weight_offset", "visible": false },
+        { "name": "groups", "default": 1 },
+        { "name": "d", "default": [ 1, 1 ] },
+        { "name": "pad", "default": [ 0, 0 ] },
+        { "name": "padding", "default": 0 },
+        { "name": "nInputPlane", "visible": false },
+        { "name": "nOutputPlane", "visible": false },
+        { "name": "fmode", "visible": false },
+        { "name": "bwmode", "visible": false },
+        { "name": "bdmode", "visible": false }
+      ]
+    }
+  },
+  {
+    "name": "SpatialFullConvolution",
+    "schema": {
+      "category": "Layer",
+      "attributes": [
+        { "name": "d", "default": [ 1, 1 ] },
+        { "name": "dilation", "default": [ 1, 1 ] },
+        { "name": "pad", "default": [ 0, 0 ] },
+        { "name": "nInputPlane", "visible": false },
+        { "name": "nOutputPlane", "visible": false }
+      ]
+    }
+  },
+  {
+    "name": "SpatialDilatedConvolution",
+    "schema": {
+      "category": "Layer",
+      "attributes": [
+        { "name": "d", "default": [ 1, 1 ] },
+        { "name": "dilation", "default": [ 1, 1 ] },
+        { "name": "pad", "default": [ 0, 0 ] },
+        { "name": "nInputPlane", "visible": false },
+        { "name": "nOutputPlane", "visible": false }
+      ]
+    }
+  },
+  {
+    "name": "BatchNormalization",
+    "schema": {
+      "category": "Normalization",
+      "attributes": [
+        { "name": "affine", "default": true },
+        { "name": "momentum", "default": 0.1 },
+        { "name": "eps", "default": 0.00001 }
+      ]
+    }
+  },
+  {
+    "name": "SpatialBatchNormalization",
+    "schema": {
+      "category": "Normalization",
+      "attributes": [
+        { "name": "affine", "default": true },
+        { "name": "momentum", "default": 0.1 },
+        { "name": "eps", "default": 0.00001 },
+        { "name": "mode", "default": "CUDNN_BATCHNORM_SPATIAL" },
+        { "name": "nDim", "default": 4 }
+      ]
+    }
+  },
+  {
+    "name": "SpatialAveragePooling",
+    "schema": {
+      "category": "Pool",
+      "attributes": [ 
+        { "name": "ceil_mode", "default": false },
+        { "name": "mode", "default": "CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING" },
+        { "name": "d", "default": [ 1, 1 ] },
+        { "name": "pad", "default": [ 0, 0 ] }
+      ]
+    }
+  },
+  {
+    "name": "SpatialMaxPooling",
+    "schema": {
+      "category": "Pool",
+      "attributes": [ 
+        { "name": "ceil_mode", "default": false },
+        { "name": "mode", "default": "CUDNN_POOLING_MAX" },
+        { "name": "pad", "default": [ 0, 0 ] }
+      ]
+    }
+  },
+  {
+    "name": "SpatialZeroPadding",
+    "schema": {
+      "category": "Tensor",
+      "attributes": [ 
+      ]
+    }
+  },
+  {
+    "name": "Concat",
+    "schema": {
+      "category": "Tensor"
+    }
+  },
+  {
+    "name": "ReLU",
+    "schema": {
+      "category": "Activation",
+      "attributes": [
+        { "name": "threshold", "default": 0 },
+        { "name": "val", "default": 0 },
+        { "name": "inplace", "default": false, "visible": false },
+        { "name": "mode", "default": "CUDNN_ACTIVATION_RELU" },
+        { "name": "nElem", "visible": false }
+      ]
+    }
+  },
+  {
+    "name": "Sigmoid",
+    "schema": {
+      "category": "Activation"
+    }
+  },
+  {
+    "name": "Reshape",
+    "schema": {
+      "category": "Shape"
+    }
+  },
+  {
+    "name": "Dropout",
+    "schema": {
+      "category": "Dropout",
+      "attributes": [
+        { "name": "v2", "visible": false }
+      ]
+    }
+  }
+]

+ 793 - 0
src/torch.js

@@ -0,0 +1,793 @@
+
+/*jshint esversion: 6 */
+
+var torch = torch || {};
+var base = base || require('./base');
+
+torch.ModelFactory = class {
+
+    match(context, host) {
+        var extension = context.identifier.split('.').pop().toLowerCase();
+        if (extension == 't7') {
+            return true;
+        }
+        return false;
+    }
+
+    open(context, host, callback) {
+        torch.OperatorMetadata.open(host, (err, metadata) => {
+            var identifier = context.identifier;
+            try {
+                var buffer = context.buffer;
+                var reader = new torch.T7Reader(buffer, (name) => {
+                    host.exception(new torch.Error("Unknown type '" + name + "' in '" + identifier + "'."), false);
+                    return null;
+                });
+                var root = reader.read();
+                var model = new torch.Model(metadata, root);
+                callback(null, model);
+                return;
+            }
+            catch (error) {
+                var message = error && error.message ? error.message : error.toString();
+                message = message.endsWith('.') ? message.substring(0, message.length - 1) : message;
+                callback(new torch.Error(message + " in '" + identifier + "'."), null);
+                return;
+            }
+        });
+    }
+};
+
+torch.Model = class {
+    
+    constructor(metadata, root) {
+        this._graphs = [];
+        this._graphs.push(new torch.Graph(metadata, root));
+    }
+
+    get graphs() {
+        return this._graphs;
+    }
+
+    get format() {
+        return 'Torch v7';
+    }
+};
+
+torch.Graph = class {
+
+    constructor(metadata, root) {
+        this._inputs = [];
+        this._outputs = [];
+        this._nodes = [];
+        this._groups = 'false';
+        
+        if (root.hasOwnProperty('model')) {
+            root = root.model;
+        }
+
+        var inputs = [];
+        var outputs = [];
+
+        this._loadModule(metadata, root, [], '', inputs, outputs);
+
+        inputs.forEach((input, index) => {
+            this._inputs.push(new torch.Argument('input' + (index != 0 ? (index + 1).toString() : ''), true, [ input ]));
+        });
+        outputs.forEach((output, index) => {
+            this._outputs.push(new torch.Argument('output' + (index != 0 ? (index + 1).toString() : ''), true, [ output ]));
+        });
+    }
+
+    get inputs() {
+        return this._inputs;
+    }
+
+    get outputs() {
+        return this._outputs;
+    }
+
+    get nodes() {
+        return this._nodes;
+    }
+
+    get groups() {
+        return this._groups;
+    }
+
+    _loadModule(metadata, module, groups, key, inputs, outputs) {
+        if (groups.length > 0) {
+            this._groups = true;
+        }
+        switch (module.__type__) {
+            case 'nn.Sequential':
+                groups.push(key);
+                var subInputs = inputs;
+                var subOutputs = [];
+                var keys = Object.keys(module.modules);
+                keys.sort(); 
+                var last = keys[keys.length - 1];
+                keys.forEach((key, index) => {
+                    if (key == last.toString()) {
+                        subOutputs = outputs;
+                    }
+                    this._loadModule(metadata, module.modules[key], groups, key, subInputs, subOutputs);
+                    subInputs = subOutputs;
+                    subOutputs = [];
+                });
+                groups.pop();
+                break;
+            case 'nn.Parallel':
+                groups.push(key);
+                var keys = Object.keys(module.modules);
+                keys.sort();
+                var newInputs = [];
+                var newOutputs = [];
+                keys.forEach((key, index) => {
+                    var subInputs = inputs.map((input) => input);
+                    var subOutputs = outputs.map((output) => output);
+                    this._loadModule(metadata, module.modules[key], groups, key, subInputs, subOutputs);
+                    if (inputs.length == 0) {
+                        subInputs.forEach((input) => {
+                            newInputs.push(input);
+                        });
+                    }
+                    if (outputs.length == 0) {
+                        subOutputs.forEach((output) => {
+                            newOutputs.push(output);
+                        });
+                    }
+                });
+                newInputs.forEach((input) => {
+                    inputs.push(input);
+                });
+                newOutputs.forEach((output) => {
+                    outputs.push(output);
+                });
+                groups.pop();
+                break;
+            case 'nn.Concat':
+            case 'nn.ConcatTable':
+                groups.push(key);
+                var keys = Object.keys(module.modules);
+                keys.sort();
+                if (inputs.length == 0) {
+                    inputs.push(new torch.Connection(groups.join('/') + '/' + key, null, null));
+                }
+                var concatInputs = [];
+                keys.forEach((key, index) => {
+                    var streamInputs = inputs.map((input) => input);
+                    var streamOutputs = [];
+                    this._loadModule(metadata, module.modules[key], groups, key, streamInputs, streamOutputs);
+                    streamOutputs.forEach((output) => {
+                        concatInputs.push(output);
+                    });
+                });
+                groups.pop();
+                delete module.modules;
+                delete module.dimension;
+                this._createNode(metadata, module, groups, key, concatInputs, outputs);
+                break;
+            case 'nn.Inception':
+                delete module.modules; // TODO
+                delete module.module; // TODO
+                delete module.transfer; // TODO
+                delete module.pool; // TODO
+                this._createNode(metadata, module, groups, key, inputs, outputs);
+                break;
+            default:
+                this._createNode(metadata, module, groups, key, inputs, outputs);
+                break;
+        }
+    }
+
+    _createNode(metadata, module, group, subIndex, inputs, outputs) {
+        this._nodes.push(new torch.Node(metadata, module, group, subIndex, inputs, outputs));
+    }
+};
+
+torch.Argument = class {
+
+    constructor(name, visible, connections) {
+        this._name = name;
+        this._visible = visible;
+        this._connections = connections;
+    }
+
+    get name() {
+        return this._name;
+    }
+
+    get visible() {
+        return this._visible;
+    }
+
+    get connections() {
+        return this._connections;
+    }
+};
+
+torch.Connection = class {
+
+    constructor(id, type, initializer) {
+        this._id = id;
+        this._type = type;
+        this._initializer = initializer;
+    }
+
+    get id() {
+        return this._id;
+    }
+
+    get type() {
+        if (this._initializer) {
+            return this._initializer.type;
+        }
+        return this._type;
+    }
+
+    get initializer() {
+        return this._initializer;
+    }
+};
+
+torch.Node = class {
+
+    constructor(metadata, module, groups, key, inputs, outputs) {
+        this._metadata = metadata;
+        this._group = groups.join('/');
+        this._name = this._group + '/' + key;
+        var type = module.__type__;
+        this._operator = type ? type.split('.').pop() : 'Object';
+        var initializers = [];
+        Object.keys(module).forEach((key) => {
+            var obj = module[key];
+            if (obj.__type__ && obj.__type__ == 'torch.LongStorage') {
+                var array = [];
+                var reader = new torch.T7Reader(obj.data);
+                for (var i = 0; i < obj.size; i++) {
+                    array.push(reader.int64());
+                }
+                module[key] = array;
+            }
+        });
+        delete module.iSize;
+        delete module.gradInput;
+        delete module.finput;
+        delete module.fgradInput;
+        delete module.output;
+        delete module.gradWeight;
+        delete module.gradBias;
+        delete module.scaleT;
+        switch (type) {
+            case 'nn.Linear':
+                delete module.addBuffer;
+                break;
+            case 'nn.Reshape':
+                delete module._input;
+                delete module._gradOutput;
+                break;
+            case 'cudnn.SpatialConvolution':
+            case 'nn.SpatialConvolution':
+            case 'nn.SpatialDilatedConvolution':
+            case 'nn.SpatialFullConvolution':
+                delete module.ones;
+                this._updateWidthHeight(module, 'adj');
+                this._updateWidthHeight(module, 'd');
+                this._updateWidthHeight(module, 'dilation');
+                this._updateWidthHeight(module, 'k');
+                this._updateWidthHeight(module, 'pad');
+                break;
+            case 'cudnn.BatchNormalization':
+            case 'cudnn.SpatialBatchNormalization':
+            case 'nn.BatchNormalization':
+            case 'nn.SpatialBatchNormalization':
+                delete module.save_mean;
+                delete module.save_std;
+                delete module.gradWeight;
+                module.mean = module.running_mean;
+                module.var = module.running_var;
+                delete module.running_mean;
+                delete module.running_var;
+                break;
+            case 'cudnn.SpatialMaxPooling':
+            case 'inn.SpatialMaxPooling':
+            case 'nn.SpatialMaxPooling':
+                delete module.indices;
+                this._updateWidthHeight(module, 'pad');
+                this._updateWidthHeight(module, 'd');
+                this._updateWidthHeight(module, 'k');
+                break;
+            case 'cudnn.SpatialAveragePooling':
+            case 'nn.SpatialAveragePooling':
+                this._updateWidthHeight(module, 'd');
+                this._updateWidthHeight(module, 'k');
+                break;    
+            case 'nn.SpatialFullConvolution':
+                delete module.ones;
+                break;
+            case 'nn.Dropout':
+                delete module.noise;
+                break;
+        }
+        this._attributes = [];
+        Object.keys(module).forEach((key) => {
+            if (key == '__type__' || key == '_type') {
+                return;
+            }
+            var obj = module[key];
+            if (obj.__type__ && obj.__type__.startsWith('torch.') && obj.__type__.endsWith('Tensor')) {
+                if (obj.size.length == 0) {
+                    debugger;
+                    // console.log("  " + type + "::" + key);
+                }
+                initializers.push(new torch.Argument(key, true, [ 
+                    new torch.Connection(key, null, new torch.Tensor(obj))
+                ]));
+                return;
+            }
+            if (key == 'modules' || obj.__type__) {
+                debugger;                
+                // console.log("  " + type + "::" + key);
+                return;
+            }
+            this._attributes.push(new torch.Attribute(this._metadata, this._operator, key, obj));
+        });
+        this._inputs = [];
+        if (inputs.length == 0) {
+            inputs.push(new torch.Connection(this._name + '/in', null, null));
+        }
+        this._inputs.push(new torch.Argument('input', true, inputs));
+        this._outputs = [];
+        if (outputs.length == 0) {
+            outputs.push(new torch.Connection(this._name, null, null));
+        }
+        this._outputs.push(new torch.Argument('output', true, outputs));
+        initializers = initializers.filter((argument) => {
+            if (argument.name == 'weight') {
+                this._inputs.push(argument);
+                return false;
+            }
+            return true;
+        });
+        initializers = initializers.filter((argument) => {
+            if (argument.name == 'bias') {
+                this._inputs.push(argument);
+                return false;
+            }
+            return true;
+        });
+        initializers.forEach((initialier) => {
+            this._inputs.push(initialier);
+        });
+    }
+
+    get name() {
+        return this._name;
+    }
+
+    get operator() {
+        return this._operator;
+    }
+
+    get group() {
+        return this._group;
+    }
+
+    get category() {
+        var schema = this._metadata.getSchema(this._operator);
+        return (schema && schema.category) ? schema.category : null;
+    }
+
+    get attributes() {
+        return this._attributes;
+    }
+
+    get inputs() {
+        return this._inputs;
+    }
+
+    get outputs() {
+        return this._outputs;
+    }
+
+    _updateWidthHeight(module, name) {
+        if (module.hasOwnProperty(name + 'W') && module.hasOwnProperty(name + 'H')) {
+            module[name] = [ module[name + 'W'], module[name + 'H'] ];
+            delete module[name + 'W'];
+            delete module[name + 'H'];
+        }
+    }
+};
+
+torch.Attribute = class {
+
+    constructor(metadata, operator, name, value) {
+        this._name = name;
+        this._value = value;
+        if (name == 'train') {
+            this._visible = false;
+        }
+        var schema = metadata.getAttributeSchema(operator, name);
+        if (schema) {
+            if (schema.hasOwnProperty('visible')) {
+                this._visible = schema.visible;
+            }
+            else if (schema.hasOwnProperty('default')) {
+                if (JSON.stringify(schema.default) == JSON.stringify(this._value)) {
+                    this._visible = false;
+                }
+            }
+        }
+    }
+
+    get name() {
+        return this._name;
+    }
+
+    get value() {
+        return this._value;
+    }
+
+    get visible() {
+        return this._visible == false ? false : true;
+    }
+};
+
+torch.Tensor = class {
+
+    constructor(tensor) {
+        this._type = new torch.TensorType(tensor);
+        this._storage = tensor.storage;
+    }
+
+    get type() {
+        return this._type;
+    }
+
+    get state() {
+        return 'Not implemented.';
+    }
+};
+
+torch.TensorType = class {
+
+    constructor(tensor) {
+        this._dataType = tensor.storage ? tensor.storage.dataType : '?';
+        this._shape = new torch.TensorShape(tensor.size);
+    }
+
+    get dataType() {
+        return this._dataType;
+    }
+
+    get shape() {
+        return this._shape;
+    }
+
+    toString() {
+        return (this.dataType || '?') + this._shape.toString();
+    }
+};
+
+torch.TensorShape = class {
+
+    constructor(dimensions) {
+        this._dimensions = dimensions;
+    }
+
+    get dimensions() {
+        return this._dimensions;
+    }
+
+    toString() {
+        if (this._dimensions) {
+            if (this._dimensions.length == 0) {
+                return '';
+            }
+            return '[' + this._dimensions.map((dimension) => dimension.toString()).join(',') + ']';
+        }
+        return '';
+    }
+};
+
+torch.Error = class extends Error {
+    constructor(message) {
+        super(message);
+        this.name = 'Error loading Torch model.';
+    }
+};
+
+torch.T7Reader = class {
+
+    constructor(buffer, callback) {
+        this._buffer = buffer;
+        this._position = 0;
+        this._dataView = new DataView(this._buffer.buffer, this._buffer.byteOffset, this._buffer.byteLength);
+        this._textDecoder = new TextDecoder('ascii');
+        this._callback = callback; 
+        this._memo = {};
+        this._registry = {};
+        this._registry['cudnn.BatchNormalization'] = function(reader, version) { reader.nn(this); };
+        this._registry['cudnn.SpatialConvolution'] = function(reader, version) { reader.nn(this); };
+        this._registry['cudnn.ReLU'] = function(reader, version) { reader.nn(this); };
+        this._registry['cudnn.SpatialAveragePooling'] = function(reader, version) { reader.nn(this); };
+        this._registry['cudnn.SpatialBatchNormalization'] = function(reader, version) { reader.nn(this); };
+        this._registry['cudnn.SpatialMaxPooling'] = function(reader, version) { reader.nn(this); };
+        this._registry['inn.SpatialMaxPooling'] = function(reader, version) { reader.nn(this); };
+        this._registry['nn.CAddTable'] = function(reader, version) { reader.nn(this); };
+        this._registry['nn.Concat'] = function(reader, version) { reader.nn(this); };
+        this._registry['nn.ConcatTable'] = function(reader, version) { reader.nn(this); };
+        this._registry['nn.DepthConcat'] = function(reader, version) { reader.nn(this); };
+        this._registry['nn.Dropout'] = function(reader, version) { reader.nn(this); };
+        this._registry['nn.Identity'] = function(reader, version) { reader.nn(this); };
+        this._registry['nn.Inception'] = function(reader, version) { reader.nn(this); };
+        this._registry['nn.Linear'] = function(reader, version) { reader.nn(this); };
+        this._registry['nn.Parallel'] = function(reader, version) { reader.nn(this); };
+        this._registry['nn.ReLU'] = function(reader, version) { reader.nn(this); };
+        this._registry['nn.Reshape'] = function(reader, version) { reader.nn(this); };
+        this._registry['nn.Sequential'] = function(reader, version) { reader.nn(this); };
+        this._registry['nn.Sigmoid'] = function(reader, version) { reader.nn(this); };
+        this._registry['nn.SpatialAveragePooling'] = function(reader, version) { reader.nn(this); };
+        this._registry['nn.SpatialBatchNormalization'] = function(reader, version) { reader.nn(this); };
+        this._registry['nn.SpatialConvolution'] = function(reader, version) { reader.nn(this); };
+        this._registry['nn.SpatialDilatedConvolution'] = function(reader, version) { reader.nn(this); };
+        this._registry['nn.SpatialFullConvolution'] = function(reader, version) { reader.nn(this); };
+        this._registry['nn.SpatialMaxPooling'] = function(reader, version) { reader.nn(this); };
+        this._registry['nn.SpatialZeroPadding'] = function(reader, version) { reader.nn(this); };
+        this._registry['nn.View'] = function(reader, version) { reader.nn(this); };
+        this._registry['nn.gModule'] = function(reader, version) { reader.nn(this); };
+        this._registry['nngraph.Node'] = function(reader, version) { reader.nn(this); };
+        this._registry['torch.ByteTensor'] = function(reader, version) { reader.tensor(this); };
+        this._registry['torch.CharTensor'] = function(reader, version) { reader.tensor(this); };
+        this._registry['torch.ShortTensor'] = function(reader, version) { reader.tensor(this); };
+        this._registry['torch.IntTensor'] = function(reader, version) { reader.tensor(this); };
+        this._registry['torch.LongTensor'] = function(reader, version) { reader.tensor(this); };
+        this._registry['torch.FloatTensor'] = function(reader, version) { reader.tensor(this); };
+        this._registry['torch.DoubleTensor'] = function(reader, version) { reader.tensor(this); };
+        this._registry['torch.CudaByteTensor'] = function(reader, version) {reader.tensor(this); };
+        this._registry['torch.CudaCharTensor'] = function(reader, version) {reader.tensor(this); };
+        this._registry['torch.CudaShortTensor'] = function(reader, version) {reader.tensor(this); };
+        this._registry['torch.CudaIntTensor'] = function(reader, version) {reader.tensor(this); };
+        this._registry['torch.CudaLongTensor'] = function(reader, version) {reader.tensor(this); };
+        this._registry['torch.CudaTensor'] = function(reader, version) {reader.tensor(this); };
+        this._registry['torch.CudaDoubleTensor'] = function(reader, version) {reader.tensor(this); };
+        this._registry['torch.ByteStorage'] = function(reader, version) { reader.storage(this, 'uint8', 1); };
+        this._registry['torch.CharStorage'] = function(reader, version) { reader.storage(this, 'int8', 1); };
+        this._registry['torch.ShortStorage'] = function(reader, version) { reader.storage(this, 'int16', 2); };
+        this._registry['torch.IntStorage'] = function(reader, version) { reader.storage(this, 'int32', 4); };
+        this._registry['torch.LongStorage'] = function(reader, version) { reader.storage(this, 'int64', 8); };
+        this._registry['torch.FloatStorage'] = function(reader, version) { reader.storage(this, 'float32', 4); };
+        this._registry['torch.DoubleStorage'] = function(reader, version) { reader.storage(this, 'float64', 8); };
+        this._registry['torch.CudaByteStorage'] = function(reader, version) { reader.storage(this, 'uint8', 1); };
+        this._registry['torch.CudaCharStorage'] = function(reader, version) { reader.storage(this, 'int8', 1); };
+        this._registry['torch.CudaShortStorage'] = function(reader, version) { reader.storage(this, 'int16', 2); };
+        this._registry['torch.CudaIntStorage'] = function(reader, version) { reader.storage(this, 'int32', 4); };
+        this._registry['torch.CudaLongStorage'] = function(reader, version) { reader.storage(this, 'int64', 8); };
+        this._registry['torch.CudaIntStorage'] = function(reader, version) { reader.storage(this, 'int32', 4); };
+        this._registry['torch.CudaStorage'] = function(reader, version) { reader.storage(this, 'float32', 4); };
+        this._registry['torch.CudaFloatStorage'] = function(reader, version) { reader.storage(this, 'float64', 8); };
+    }
+
+    read() {
+        var type = this.int32();
+        switch (type) {
+            case 0:
+                return null;
+            case 1:
+                return  this.float64();
+            case 2:
+                return this.string();
+            case 3:
+                return this.table();
+            case 4:
+                return this.object();
+            case 5:
+                return this.boolean();
+            case 6:
+            case 7:
+            case 8:
+                return this.function();
+            default:
+                throw new torch.Error("File format has invalid type '" + type + "'.");
+        }
+    }
+
+    boolean() {
+        return this.int32() == 1;
+    }
+
+    bytes(size) {
+        var data = this._buffer.subarray(this._position, this._position + size);
+        this._position += size;
+        return data;
+    }
+
+    int32() {
+        var value = this._dataView.getInt32(this._position, true);
+        this._position += 4;
+        return value;
+    }
+
+    int64() {
+        var lo = this.int32();
+        var hi = this.int32();
+        if (lo == -1 && hi == -1) {
+            return -1;
+        }
+        if (hi != 0) {
+            throw new torch.Error('Invalid int64 value.');
+        }
+        return lo;
+    }
+
+    int64s(size) {
+        var array = [];
+        for (var i = 0; i < size; i++) {
+            array.push(this.int64());
+        }
+        return array;
+    }
+
+    float64() {
+        var value = this._dataView.getFloat64(this._position, true);
+        this._position += 8;
+        return value;
+    }
+
+    string() {
+        var size = this.int32();
+        var buffer = this.bytes(size);
+        return this._textDecoder.decode(buffer);
+    }
+
+    object() {
+        var index = this.int32();
+        if (this._memo[index]) {
+            return this._memo[index];
+        }
+
+        var version = this.string();
+        var name = null;
+        if (version.startsWith('V ')) {
+            name = this.string();
+            version = Number(version.split(' ')[1]);
+        }
+        else {
+            name = version;
+            version = 0;
+        }
+
+        var obj = { __type__: name };
+        var constructor = this._registry[name];
+        if (constructor) {
+            constructor.apply(obj, [ this, version ]);
+        }
+        else {
+            constructor = this._callback(name);
+            if (constructor) {
+                constructor.apply(obj, [ this, version ]);
+            }
+            this.nn(obj);
+        }
+        this._memo[index] = obj;
+        return obj;
+    }
+
+    table() {
+        var index = this.int32();
+        if (this._memo[index]) {
+            return this._memo[index];
+        }
+        var size = this.int32();
+        var table = {};
+        for (var i = 0; i < size; i++) {
+            var key = this.read();
+            var value = this.read();
+            table[key] = value;
+        }
+        var keys = Object.keys(table);
+        keys.sort();
+        var list = true;
+        for (var j = 0; j < keys.length; j++) {
+            if (keys[j] != j.toString()) {
+                list = false;
+            }
+        }
+        if (list && keys.length > 0) {
+            debugger;
+        }
+
+        this._memo[index] = table;
+        return table;
+    }
+
+    function() {
+        var size = this.int32();
+        var dumped = this.bytes(size);
+        var upvalues = this.read();
+        return { size: size, dumped: dumped, upvalues: upvalues };
+    }
+
+    nn(obj) {
+        var attributes = this.read();
+        if (attributes != null) {
+            Object.keys(attributes).forEach((key) => {
+                obj[key] = attributes[key];
+            });
+        }
+    }
+
+    tensor(obj) {
+        var dim = this.int32();
+        obj.size = [];
+        for (var i = 0; i < dim; i++) {
+            obj.size.push(this.int64());
+        }
+        obj.stride = [];
+        for (var j = 0; j < dim; j++) {
+            obj.stride.push(this.int64());
+        }
+        obj.storage_offset = this.int64() - 1;
+        obj.storage = this.read();
+    }
+
+    storage(obj, dataType, itemSize) {
+        obj.dataType = dataType;
+        obj.itemSize = itemSize;
+        obj.size = this.int64();
+        obj.data = this.bytes(obj.size * obj.itemSize);
+    }
+};
+
+torch.OperatorMetadata = class {
+
+    static open(host, callback) {
+        if (torch.OperatorMetadata._metadata) {
+            callback(null, torch.OperatorMetadata._metadata);
+            return;
+        }
+        host.request(null, 'torch-metadata.json', 'utf-8', (err, data) => {
+            torch.OperatorMetadata._metadata = new torch.OperatorMetadata(data);
+            callback(null, torch.OperatorMetadata._metadata);
+            return;
+        });
+    }
+
+    constructor(data) {
+        this._map = {};
+        if (data) {
+            var items = JSON.parse(data);
+            if (items) {
+                items.forEach((item) => {
+                    if (item.name && item.schema) {
+                        this._map[item.name] = item.schema;
+                    }
+                });
+            }
+        }
+    }
+
+    getSchema(operator) {
+        return this._map[operator] || null;
+    }
+
+    getAttributeSchema(operator, name) {
+        var schema = this._map[operator];
+        if (schema && schema.attributes && schema.attributes.length > 0) {
+            if (!schema.__attributesMap) {
+                schema.__attributesMap = {};
+                schema.attributes.forEach((attribute) => {
+                    schema.__attributesMap[attribute.name] = attribute;
+                });
+            }
+            return schema.__attributesMap[name];
+        }
+        return null;
+    }
+};
+
+if (typeof module !== 'undefined' && typeof module.exports === 'object') {
+    module.exports.ModelFactory = torch.ModelFactory;
+}
+

+ 1 - 1
src/view-browser.html

@@ -69,7 +69,7 @@
         </svg>
     </a>
     <button id='open-file-button' class='center' style='top: 200px; width: 125px; opacity: 0;'>Open Model...</button>
-    <input type='file' id='open-file-dialog' style='display:none' multiple='false' accept='.onnx, .pb, .meta, .tflite, .keras, .h5, .hdf5, .json, .model, .mlmodel, .caffemodel, .pt, .pth, .pkl, .joblib, .pbtxt, .prototxt, .xml, .dot'>
+    <input type='file' id='open-file-dialog' style='display:none' multiple='false' accept='.onnx, .pb, .meta, .tflite, .keras, .h5, .hdf5, .json, .model, .mlmodel, .caffemodel, .pt, .pth, .t7, .pkl, .joblib, .pbtxt, .prototxt, .xml, .dot'>
     <!-- Preload fonts to workaround Chrome SVG layout issue -->
     <div style='font-weight: normal; color: rgba(0, 0, 0, 0.01); user-select: none;'>.</div>
     <div style='font-weight: 600; color: rgba(0, 0, 0, 0.01); user-select: none;'>.</div>

+ 3 - 2
src/view.js

@@ -774,8 +774,8 @@ view.View = class {
                                 graphElement.setAttribute('height', height / this._zoom);        
                                 if (inputElements && inputElements.length > 0) {
                                     // Center view based on input elements
-                                    for (var i = 0; i < inputElements.length; i++) {
-                                        inputElements[i].scrollIntoView({ behavior: 'instant' });
+                                    for (var j = 0; j < inputElements.length; j++) {
+                                        inputElements[j].scrollIntoView({ behavior: 'instant' });
                                         break;
                                     }
                                 }
@@ -1075,6 +1075,7 @@ view.ModelFactoryService = class {
         this.register('./caffe', [ '.caffemodel', '.pbtxt', '.prototxt' ]);
         this.register('./caffe2', [ 'predict_net.pb', 'predict_net.pbtxt', 'predict_net.prototxt' ]);
         this.register('./pytorch', [ '.pt', '.pth', '.pkl', '.h5', '.model' ]);
+        this.register('./torch', [ '.t7' ]);
         this.register('./tflite', [ '.tflite', '.lite' ]);
         this.register('./tf', [ '.pb', '.meta', '.pbtxt', '.prototxt' ]);
         this.register('./sklearn', [ '.pkl', '.joblib' ]);

+ 84 - 0
test/models.json

@@ -3151,5 +3151,89 @@
     "source": "https://raw.githubusercontent.com/kosslab-kr/Tizen-NN-Runtime/master/Xor/xorGate.lite",
     "format": "TensorFlow Lite v3",
     "link":   "https://github.com/kosslab-kr/Tizen-NN-Runtime"
+  },
+  {
+    "type":   "torch",
+    "target": "2ch_notredame.t7",
+    "source": "https://s3.amazonaws.com/modelzoo-networks/cvpr2015matching_networks.tar.gz[2ch/2ch_notredame.t7]",
+    "format": "Torch v7",
+    "link":   "https://github.com/szagoruyko/cvpr15deepcompare"
+  },
+  {
+    "type":   "torch",
+    "target": "2ch2stream_liberty.t7",
+    "source": "https://s3.amazonaws.com/modelzoo-networks/cvpr2015matching_networks.tar.gz[2ch2stream/2ch2stream_liberty.t7]",
+    "format": "Torch v7",
+    "link":   "https://github.com/szagoruyko/cvpr15deepcompare"
+  },
+  {
+    "type":   "torch",
+    "target": "2chdeep_yosemite.t7",
+    "source": "https://s3.amazonaws.com/modelzoo-networks/cvpr2015matching_networks.tar.gz[2chdeep/2chdeep_yosemite.t7]",
+    "format": "Torch v7",
+    "link":   "https://github.com/szagoruyko/cvpr15deepcompare"
+  },
+  {
+    "type":   "torch",
+    "target": "completionnet_places2.t7",
+    "source": "http://hi.cs.waseda.ac.jp/~iizuka/data/completionnet_places2.t7",
+    "format": "Torch v7",
+    "link":   "https://github.com/yangwangx/inpainting_glcic_pytorch"
+  },
+  {
+    "type":   "torch",
+    "target": "inception.t7",
+    "source": "https://raw.githubusercontent.com/cpra/fer-cnn-sota/master/models/inception.t7",
+    "format": "Torch v7",
+    "link":   "https://github.com/cpra/fer-cnn-sota"
+  },
+  {
+    "type":   "torch",
+    "target": "resnet.t7",
+    "source": "https://raw.githubusercontent.com/cpra/fer-cnn-sota/master/models/resnet.t7",
+    "format": "Torch v7",
+    "link":   "https://github.com/cpra/fer-cnn-sota"
+  },
+  {
+    "type":   "torch",
+    "target": "resnet-18.t7",
+    "source": "https://d2j0dndfm35trm.cloudfront.net/resnet-18.t7",
+    "format": "Torch v7",
+    "link":   "https://github.com/facebook/fb.resnet.torch/tree/master/pretrained"
+  },
+  {
+    "type":   "torch",
+    "target": "resnet-34.t7",
+    "source": "https://d2j0dndfm35trm.cloudfront.net/resnet-34.t7",
+    "format": "Torch v7",
+    "link":   "https://github.com/facebook/fb.resnet.torch/tree/master/pretrained"
+  },
+  {
+    "type":   "torch",
+    "target": "resnet-50.t7",
+    "source": "https://d2j0dndfm35trm.cloudfront.net/resnet-50.t7",
+    "format": "Torch v7",
+    "link":   "https://github.com/facebook/fb.resnet.torch/tree/master/pretrained"
+  },
+  {
+    "type":   "torch",
+    "target": "siam_liberty.t7",
+    "source": "https://s3.amazonaws.com/modelzoo-networks/cvpr2015matching_networks.tar.gz[siam/siam_liberty.t7]",
+    "format": "Torch v7",
+    "link":   "https://github.com/szagoruyko/cvpr15deepcompare"
+  },
+  {
+    "type":   "torch",
+    "target": "siam2stream_notredame.t7",
+    "source": "https://s3.amazonaws.com/modelzoo-networks/cvpr2015matching_networks.tar.gz[siam2stream/siam2stream_notredame.t7]",
+    "format": "Torch v7",
+    "link":   "https://github.com/szagoruyko/cvpr15deepcompare"
+  },
+  {
+    "type":   "torch",
+    "target": "vgg.t7",
+    "source": "https://raw.githubusercontent.com/cpra/fer-cnn-sota/master/models/vgg.t7",
+    "format": "Torch v7",
+    "link":   "https://github.com/cpra/fer-cnn-sota"
   }
 ]