Lutz Roeder 4 лет назад
Родитель
Сommit
d15a024c9f
1 измененных файлов с 78 добавлено и 52 удалено
  1. 78 52
      source/om.js

+ 78 - 52
source/om.js

@@ -67,7 +67,8 @@ om.Graph = class {
             if (op.type === 'Const') {
                 continue;
             }
-            this._nodes.push(new om.Node(metadata, op, graph, weights, model));
+            const node = new om.Node(metadata, op, graph, weights, model);
+            this._nodes.push(node);
         }
     }
 
@@ -193,7 +194,8 @@ om.Node = class {
                     this._chain.push(new om.Node(metadata, { type: 'ReLU' }, graph, weights));
                     continue;
                 }
-                this._attributes.push(new om.Attribute(metadata.attribute(this._type, name), name, value, true));
+                const attribute = new om.Attribute(metadata.attribute(this._type, name), name, value, true);
+                this._attributes.push(attribute);
             }
         }
     }
@@ -237,63 +239,83 @@ om.Attribute = class {
         this._name = name;
         this._value = value;
         this._visible = visible;
-        if (Object.prototype.hasOwnProperty.call(value, 'i')) {
-            this._value = value.i;
-            this._type = 'int64';
-        }
-        else if (Object.prototype.hasOwnProperty.call(value, 'f')) {
-            this._value = value.f;
-            this._type = 'float32';
-        }
-        else if (Object.prototype.hasOwnProperty.call(value, 'b')) {
-            this._value = value.b;
-            this._type = 'boolean';
-        }
-        else if (Object.prototype.hasOwnProperty.call(value, 'bt')) {
-            this._value = null;
-            if (value.bt.length !== 0) {
-                this._type = 'tensor';
-                this._value = new om.Tensor('Constant', new om.TensorType('float32', [ value.bt.length / 4 ], null), value.bt);
-            }
-        }
-        else if (Object.prototype.hasOwnProperty.call(value, 's')) {
-            if (typeof value.s === 'string') {
-                this._value = value.s;
+        switch (value.value) {
+            case 'i': {
+                this._value = value.i;
+                this._type = 'int64';
+                break;
             }
-            else if (value.s.filter(c => c <= 32 && c >= 128).length === 0) {
-                this._value = om.Metadata.textDecoder.decode(value.s);
+            case 'f': {
+                this._value = value.f;
+                this._type = 'float32';
+                break;
             }
-            else {
-                this._value = value.s;
+            case 'b': {
+                this._value = value.b;
+                this._type = 'boolean';
+                break;
             }
-            this._type = 'string';
-        }
-        else if (Object.prototype.hasOwnProperty.call(value, 'list')) {
-            const list = value.list;
-            this._value = [];
-            if (list.s && list.s.length > 0) {
-                this._value = list.s.map(v => String.fromCharCode.apply(null, new Uint16Array(v))).join(', ');
-                this._type = 'string[]';
+            case 'bt': {
+                this._value = null;
+                if (value.bt.length !== 0) {
+                    this._type = 'tensor';
+                    this._value = new om.Tensor('Constant', new om.TensorType('float32', [ value.bt.length / 4 ], null), value.bt);
+                }
+                break;
             }
-            else if (list.b && list.b.length > 0) {
-                this._value = list.b;
-                this._type = 'boolean[]';
+            case 'dt': {
+                this._type = 'DataType';
+                this._value = om.Utility.dtype(value.dt.toNumber());
+                break;
             }
-            else if (list.i && list.i.length > 0) {
-                this._value = list.i;
-                this._type = 'int64[]';
+            case 's': {
+                if (typeof value.s === 'string') {
+                    this._value = value.s;
+                }
+                else if (value.s.filter(c => c <= 32 && c >= 128).length === 0) {
+                    this._value = om.Utility.decodeText(value.s);
+                }
+                else {
+                    this._value = value.s;
+                }
+                this._type = 'string';
+                break;
             }
-            else if (list.f && list.f.length > 0) {
-                this._value = list.f;
-                this._type = 'float32[]';
+            case 'func': {
+                break;
             }
-            else if (list.type && list.type.length > 0) {
-                this._type = 'type[]';
-                this._value = list.type.map((type) => om.Node.enum2Dtype(type) || '?');
+            case 'list': {
+                const list = value.list;
+                this._value = [];
+                if (list.s && list.s.length > 0) {
+                    this._value = list.s.map(v => String.fromCharCode.apply(null, new Uint16Array(v))).join(', ');
+                    this._type = 'string[]';
+                }
+                else if (list.b && list.b.length > 0) {
+                    this._value = list.b;
+                    this._type = 'boolean[]';
+                }
+                else if (list.i && list.i.length > 0) {
+                    this._value = list.i;
+                    this._type = 'int64[]';
+                }
+                else if (list.f && list.f.length > 0) {
+                    this._value = list.f;
+                    this._type = 'float32[]';
+                }
+                else if (list.type && list.type.length > 0) {
+                    this._type = 'type[]';
+                    this._value = list.type.map((type) => om.Node.enum2Dtype(type) || '?');
+                }
+                else if (list.shape && list.shape.length > 0) {
+                    this._type = 'shape[]';
+                    this._value = list.shape.map((shape) => new om.TensorShape(shape));
+                }
+                break;
             }
-            else if (list.shape && list.shape.length > 0) {
-                this._type = 'shape[]';
-                this._value = list.shape.map((shape) => new om.TensorShape(shape));
+            case undefined: {
+                this._value = null;
+                break;
             }
         }
     }
@@ -580,12 +602,16 @@ om.Utility = class {
         }
         throw new om.Error("Unknown dtype '" + value + "'.");
     }
+
+    static decodeText(value) {
+        om.Utility._textDecoder = om.Utility._textDecoder || new TextDecoder('utf-8');
+        return om.Utility._textDecoder.decode(value);
+    }
 };
 
 om.Metadata = class {
 
     static open(context) {
-        om.Metadata.textDecoder = om.Metadata.textDecoder || new TextDecoder('utf-8');
         if (om.Metadata._metadata) {
             return Promise.resolve(om.Metadata._metadata);
         }