Lutz Roeder 6 лет назад
Родитель
Сommit
39c127f141
3 измененных файлов с 55 добавлено и 9 удалено
  1. 31 5
      src/pytorch.js
  2. 14 0
      test/models.json
  3. 10 4
      test/test.js

+ 31 - 5
src/pytorch.js

@@ -56,6 +56,9 @@ pytorch.ModelFactory = class {
                 constructorTable['torch.autograd.variable.Variable'] = function() {};
                 constructorTable['torch.backends.cudnn.rnn.Unserializable'] = function() {};
                 constructorTable['torch.nn.backends.thnn._get_thnn_function_backend'] = function() {};
+                constructorTable['torch.nn.quantized.modules.functional_modules.FloatFunctional'] = function() {};
+                constructorTable['torch.quantization.stubs.DeQuantStub'] = function() {};
+                constructorTable['torch.quantization.stubs.QuantStub'] = function() {};
                 constructorTable['torch.nn.modules.activation.ELU'] = function() {};
                 constructorTable['torch.nn.modules.activation.GLU'] = function() {};
                 constructorTable['torch.nn.modules.activation.Hardtanh'] = function() {};
@@ -94,6 +97,7 @@ pytorch.ModelFactory = class {
                 constructorTable['torch.nn.modules.instancenorm.InstanceNorm2d'] = function() {};
                 constructorTable['torch.nn.modules.instancenorm.InstanceNorm3d'] = function() {};
                 constructorTable['torch.nn.modules.linear.Linear'] = function() {};
+                constructorTable['torch.nn.modules.linear.Identity'] = function() {};
                 constructorTable['torch.nn.modules.loss.BCELoss'] = function() {};
                 constructorTable['torch.nn.modules.loss.BCEWithLogitsLoss'] = function() {}; 
                 constructorTable['torch.nn.modules.loss.CrossEntropyLoss'] = function() {};
@@ -232,6 +236,22 @@ pytorch.ModelFactory = class {
                         this.stride = state[3];
                     };
                 };
+                constructorTable['torch.cuda.FloatTensor'] = function () {
+                    this.__setstate__ = function(state) {
+                        this.storage = state[0];
+                        this.storage_offset = state[1];
+                        this.size = state[2];
+                        this.stride = state[3];
+                    };
+                };
+                constructorTable['torch.cuda.DoubleTensor'] = function () {
+                    this.__setstate__ = function(state) {
+                        this.storage = state[0];
+                        this.storage_offset = state[1];
+                        this.size = state[2];
+                        this.stride = state[3];
+                    };
+                };
                 constructorTable['numpy.dtype'] = function(obj, align, copy) { 
                     switch (obj) {
                         case 'i1': this.name = 'int8'; this.itemsize = 1; break;
@@ -665,10 +685,10 @@ pytorch.ModelFactory = class {
             root.state_dict_stylepredictor, root.state_dict_ghiasi
         ];
         for (let dict of candidates) {
-            const state_dict =
-                pytorch.ModelFactory._convertStateDictList(dict) ||
-                pytorch.ModelFactory._convertStateDictMap(dict) || 
-                pytorch.ModelFactory._convertStateDictGroupMap(dict);
+            let state_dict = null;
+            state_dict = state_dict || pytorch.ModelFactory._convertStateDictList(dict);
+            state_dict = state_dict || pytorch.ModelFactory._convertStateDictMap(dict);
+            state_dict = state_dict || pytorch.ModelFactory._convertStateDictGroupMap(dict);
             if (state_dict) {
                 return state_dict;
             }
@@ -777,21 +797,27 @@ pytorch.ModelFactory = class {
                 }
             }
             else if (Object(item) === item) {
+                let hasTensors = false;
                 for (let key in item) {
                     const value = item[key];
                     if (pytorch.ModelFactory._isTensor(value)) {
                         state_group.states.push({ name: key, value: value, id: state_group_name + '.' + key });
+                        hasTensors = true;
                     }
                     else if (value !== Object(value)) {
                         state_group.attributes.push({ name: key, value: value });
                     }
                     else if (value && value.data && value.__module__ === 'torch.nn.parameter' && value.__name__ === 'Parameter') {
                         state_group.states.push({ name: key, value: value.data, id: state_group_name + '.' + key });
+                        hasTensors = true;
                     }
                     else {
                         return null;
                     }
                 }
+                if (!hasTensors) {
+                    return null;
+                }
             }
             else {
                 return null;
@@ -852,7 +878,7 @@ pytorch.ModelFactory = class {
     }
 
     static _isTensor(obj) {
-        return obj && obj.__module__ === 'torch' && obj.__name__ && obj.__name__.endsWith('Tensor');
+        return obj && (obj.__module__ === 'torch' || obj.__module__ === 'torch.cuda') && obj.__name__ && obj.__name__.endsWith('Tensor');
     }
 };
 

+ 14 - 0
test/models.json

@@ -3759,6 +3759,13 @@
     "script": [ "${root}/tools/pytorch", "sync install zoo" ],
     "status": "script"
   },
+  {
+    "type":   "pytorch",
+    "target": "mask_rcnn_r50_fpn_1x_20181010-069fa190.pth",
+    "source": "https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/mask_rcnn_r50_fpn_1x_20181010-069fa190.pth",
+    "format": "PyTorch",
+    "link":   "https://github.com/open-mmlab/mmdetection/blob/master/docs/MODEL_ZOO.md"
+  },
   {
     "type":   "pytorch",
     "target": "mobilenet_v2.pth",
@@ -3968,6 +3975,13 @@
     "format": "PyTorch",
     "link":   "https://github.com/lutzroeder/netron/issues/133"
   },
+  {
+    "type":   "pytorch",
+    "target": "vgg-cifar10.pth.tar",
+    "source": "http://www.cs.unc.edu/~cyfu/cifar10/model_best.pth.tar",
+    "format": "PyTorch",
+    "link":   "https://github.com/chengyangfu/pytorch-vgg-cifar10"
+  },
   {
     "type":   "pytorch",
     "target": "vgg11-bbd30ac9.pth",

+ 10 - 4
test/test.js

@@ -52,7 +52,8 @@ global.TextDecoder = class {
     }
 };
 
-const type = process.argv.length > 2 ? process.argv[2] : null;
+let filter = process.argv.length > 2 ? process.argv[2] : null;
+const type = filter ? filter.split('/').shift() : '';
 const dataFolder = __dirname + '/data';
 let items = JSON.parse(fs.readFileSync(__dirname + '/models.json', 'utf-8'));
 
@@ -562,12 +563,17 @@ function next() {
         next();
         return;
     }
-    if (process.stdout.clearLine) {
-        process.stdout.clearLine();
-    }
     const targets = item.target.split(',');
     const target = targets[0];
     const folder = dataFolder + '/' + item.type;
+    const name = item.type + '/' + target;
+    if (filter && !name.startsWith(filter)) {
+        next();
+        return;
+    }
+    if (process.stdout.clearLine) {
+        process.stdout.clearLine();
+    }
     process.stdout.write(item.type + '/' + target + '\n');
 
     let promise = null;