Lutz Roeder 7 лет назад
Родитель
Сommit
b0435f9eac
4 измененных файлов с 50 добавлено и 0 удалено
  1. 7 0
      src/sklearn.js
  2. 16 0
      src/torch-metadata.json
  3. 20 0
      src/torch.js
  4. 7 0
      test/models.json

+ 7 - 0
src/sklearn.js

@@ -59,6 +59,10 @@ sklearn.ModelFactory = class {
                                 this.itemsize = Number(obj.substring(1));
                                 this.name = 'object';
                             }
+                            else if (obj.startsWith('S')) {
+                                this.itemsize = Number(obj.substring(1));
+                                this.name = 'string';
+                            }
                             else {
                                 throw new sklearn.Error("Unknown dtype '" + obj.toString() + "'.");
                             }
@@ -191,6 +195,9 @@ sklearn.ModelFactory = class {
                         obj.__type__ = cls;
                         return obj;
                     }
+                    if (base == '__builtin__.tuple') {
+                        return state;
+                    }
                     throw new sklearn.Error("Unknown base type '" + base + "'.");
                 };
                 functionTable['numpy.core.multiarray.scalar'] = function(dtype, rawData) {

+ 16 - 0
src/torch-metadata.json

@@ -224,6 +224,22 @@
       ]
     }
   },
+  {
+    "name": "Normalize",
+    "schema": {
+      "category": "Normalization",
+      "attributes": [
+      ]
+    }
+  },
+  {
+    "name": "SpatialCrossMapLRN",
+    "schema": {
+      "category": "Normalization",
+      "attributes": [
+      ]
+    }
+  },
   {
     "name": "Mean",
     "schema": {

+ 20 - 0
src/torch.js

@@ -248,11 +248,20 @@ torch.Node = class {
         delete module.gradBias;
         delete module.scaleT;
         delete module._input;
+        delete module._output;
+        delete module._gradInput;
         delete module._gradOutput;
+        delete module.buffer;
+        delete module.buffer2;
         switch (type) {
             case 'nn.Linear':
                 delete module.addBuffer;
                 break;
+            case 'nn.Normalize':
+                delete module.addBuffer;
+                delete module.normp;
+                delete module.norm;
+                break;
             case 'cudnn.SpatialConvolution':
             case 'cudnn.SpatialFullConvolution':
             case 'nn.SpatialConvolution':
@@ -276,6 +285,8 @@ torch.Node = class {
                 delete module.save_mean;
                 delete module.save_std;
                 delete module.gradWeight;
+                delete module.normalized;
+                delete module.centered;
                 if (module.running_mean) {
                     module.mean = module.running_mean;
                     delete module.running_mean;
@@ -290,6 +301,9 @@ torch.Node = class {
                 }
                 delete module.bn; // TODO InstanceNormalization
                 break;
+            case 'nn.SpatialCrossMapLRN':
+                delete module.scale;
+                break;
             case 'cudnn.SpatialMaxPooling':
             case 'cudnn.SpatialAveragePooling':
             case 'inn.SpatialMaxPooling':
@@ -538,6 +552,8 @@ torch.T7Reader = class {
         this._registry['nn.LeakyReLU'] = function(reader, version) { reader.nn(this); };
         this._registry['nn.Linear'] = function(reader, version) { reader.nn(this); };
         this._registry['nn.Mean'] = function(reader, version) { reader.nn(this); };
+        this._registry['nn.MulConstant'] = function(reader, version) { reader.nn(this); };
+        this._registry['nn.Normalize'] = function(reader, version) { reader.nn(this); };
         this._registry['nn.Parallel'] = function(reader, version) { reader.nn(this); };
         this._registry['nn.ReLU'] = function(reader, version) { reader.nn(this); };
         this._registry['nn.Reshape'] = function(reader, version) { reader.nn(this); };
@@ -548,11 +564,15 @@ torch.T7Reader = class {
         this._registry['nn.SpatialBatchNormalization'] = function(reader, version) { reader.nn(this); };
         this._registry['nn.SpatialConvolution'] = function(reader, version) { reader.nn(this); };
         this._registry['nn.SpatialConvolutionMM'] = function(reader, version) { reader.nn(this); };
+        this._registry['nn.SpatialCrossMapLRN'] = function(reader, version) { reader.nn(this); };
         this._registry['nn.SpatialDilatedConvolution'] = function(reader, version) { reader.nn(this); };
         this._registry['nn.SpatialFullConvolution'] = function(reader, version) { reader.nn(this); };
+        this._registry['nn.SpatialLPPooling'] = function(reader, version) { reader.nn(this); };
         this._registry['nn.SpatialMaxPooling'] = function(reader, version) { reader.nn(this); };
         this._registry['nn.SpatialReflectionPadding'] = function(reader, version) { reader.nn(this); };
         this._registry['nn.SpatialZeroPadding'] = function(reader, version) { reader.nn(this); };
+        this._registry['nn.Square'] = function(reader, version) { reader.nn(this); };
+        this._registry['nn.Sqrt'] = function(reader, version) { reader.nn(this); };
         this._registry['nn.Tanh'] = function(reader, version) { reader.nn(this); };
         this._registry['nn.View'] = function(reader, version) { reader.nn(this); };
         this._registry['nn.gModule'] = function(reader, version) { reader.nn(this); };

+ 7 - 0
test/models.json

@@ -3236,6 +3236,13 @@
     "format": "Torch v7",
     "link":   "https://github.com/cpra/fer-cnn-sota"
   },
+  {
+    "type":   "torch",
+    "target": "openface.nn4.small2.v1.t7",
+    "source": "https://raw.githubusercontent.com/pyannote/pyannote-data/master/openface.nn4.small2.v1.t7",
+    "format": "Torch v7",
+    "link":   "https://github.com/pyannote/pyannote-data"
+  },
   {
     "type":   "torch",
     "target": "portrait_584_net_D_cpu.t7",