Explorar el Código

Update bigdl.js

Lutz Roeder hace 4 años
padre
commit
dc4165ac43
Se han modificado 2 ficheros con 24 adiciones y 48 borrados
  1. 14 14
      source/bigdl-metadata.json
  2. 10 34
      source/bigdl.js

+ 14 - 14
source/bigdl-metadata.json

@@ -1,14 +1,14 @@
 [
   {
-    "name": "Dropout",
+    "name": "com.intel.analytics.bigdl.nn.Dropout",
     "category": "Dropout"
   },
   {
-    "name": "InferReshape",
+    "name": "com.intel.analytics.bigdl.nn.InferReshape",
     "category": "Shape"
   },
   {
-    "name": "JoinTable",
+    "name": "com.intel.analytics.bigdl.nn.JoinTable",
     "category": "Tensor",
     "inputs": [
       { "name": "inputs", "list": true }
@@ -18,7 +18,7 @@
     ]
   },
   {
-    "name": "Linear",
+    "name": "com.intel.analytics.bigdl.nn.Linear",
     "category": "Layer",
     "inputs": [
       { "name": "inputs" },
@@ -27,7 +27,7 @@
     ]
   },
   {
-    "name": "NormalizeScale",
+    "name": "com.intel.analytics.bigdl.nn.NormalizeScale",
     "category": "Normalization",
     "inputs": [
       { "name": "inputs" },
@@ -35,7 +35,7 @@
     ]
   },
   {
-    "name": "ReLU",
+    "name": "com.intel.analytics.bigdl.nn.ReLU",
     "category": "Activation"
   },
   {
@@ -52,15 +52,15 @@
     "category": "Activation"
   },
   {
-    "name": "SpatialAveragePooling",
+    "name": "com.intel.analytics.bigdl.nn.SpatialAveragePooling",
     "category": "Pool"
   },
   {
-    "name": "SpatialBatchNormalization",
+    "name": "com.intel.analytics.bigdl.nn.SpatialBatchNormalization",
     "category": "Normalization"
   },
   {
-    "name": "SpatialConvolution",
+    "name": "com.intel.analytics.bigdl.nn.quantized.SpatialConvolution",
     "category": "Layer",
     "inputs": [
       { "name": "inputs" },
@@ -69,11 +69,11 @@
     ]
   },
   {
-    "name": "SpatialCrossMapLRN",
+    "name": "com.intel.analytics.bigdl.nn.SpatialCrossMapLRN",
     "category": "Normalization"
   },
   {
-    "name": "SpatialDilatedConvolution",
+    "name": "com.intel.analytics.bigdl.nn.SpatialDilatedConvolution",
     "category": "Layer",
     "inputs": [
       { "name": "inputs" },
@@ -82,14 +82,14 @@
     ]
   },
   {
-    "name": "SpatialMaxPooling",
+    "name": "com.intel.analytics.bigdl.nn.SpatialMaxPooling",
     "category": "Pool"
   },
   {
-    "name": "Transpose",
+    "name": "com.intel.analytics.bigdl.nn.Transpose",
     "category": "Shape"
   },
   {
-    "name": "View"
+    "name": "com.intel.analytics.bigdl.nn.View"
   }
 ]

+ 10 - 34
source/bigdl.js

@@ -58,17 +58,16 @@ bigdl.Graph = class {
         this._inputs = [];
         this._outputs = [];
         this._nodes = [];
-        this._loadModule(metadata, '', module);
+        this._loadModule(metadata, module);
     }
 
-    _loadModule(metadata, group, module) {
+    _loadModule(metadata, module) {
         switch (module.moduleType) {
-            case 'com.intel.analytics.bigdl.nn.StaticGraph': {
-                this._loadStaticGraph(metadata, group, module);
-                break;
-            }
+            case 'com.intel.analytics.bigdl.nn.StaticGraph':
             case 'com.intel.analytics.bigdl.nn.Sequential': {
-                this._loadSequential(metadata, group, module);
+                for (const submodule of module.subModules) {
+                    this._loadModule(metadata, submodule);
+                }
                 break;
             }
             case 'com.intel.analytics.bigdl.nn.Input': {
@@ -78,30 +77,12 @@ bigdl.Graph = class {
                 break;
             }
             default: {
-                this._nodes.push(new bigdl.Node(metadata, group, module));
+                this._nodes.push(new bigdl.Node(metadata, module));
                 break;
             }
         }
     }
 
-    _loadSequential(metadata, group, module) {
-        group = group.length > 0 ?  group + '.' + module.namePostfix : module.namePostfix;
-        for (const submodule of module.subModules) {
-            this._loadModule(metadata, group, submodule);
-        }
-    }
-
-    _loadStaticGraph(metadata, group, module) {
-        group = group.length > 0 ?  group + '.' + module.namePostfix : module.namePostfix;
-        for (const submodule of module.subModules) {
-            this._loadModule(metadata, group, submodule);
-        }
-    }
-
-    get groups() {
-        return this._groups || false;
-    }
-
     get type() {
         return this._type;
     }
@@ -168,15 +149,14 @@ bigdl.Argument = class {
 
 bigdl.Node = class {
 
-    constructor(metadata, group, module) {
-        this._group = group;
-        const type = module.moduleType.split('.').pop();
+    constructor(metadata, module) {
+        const type = module.moduleType;
         this._name = module.name;
         this._attributes = [];
         this._inputs = [];
         this._outputs = [];
         this._inputs.push(new bigdl.Parameter('input', module.preModules.map((id) => new bigdl.Argument(id, null, null))));
-        this._type =  metadata.type(type);
+        this._type =  metadata.type(type) || { name: type };
         const inputs = (this._type && this._type.inputs) ? this._type.inputs.slice() : [];
         inputs.shift();
         if (module.weight) {
@@ -226,10 +206,6 @@ bigdl.Node = class {
         ]));
     }
 
-    get group() {
-        return this._group;
-    }
-
     get type() {
         return this._type;
     }