Lutz Roeder 4 лет назад
Родитель
Сommit
c088c0cb28
1 измененных файлов с 72 добавлено и 25 удалено
  1. 72 25
      source/onnx.js

+ 72 - 25
source/onnx.js

@@ -125,6 +125,12 @@ onnx.ModelFactory = class {
                 }
             }
         }
+        if (onnx.Text.Reader.open(stream)) {
+            return 'onnx.text';
+        }
+        if (onnx.Runtime.Reader.open(stream, extension)) {
+            return 'onnx.flatbuffers';
+        }
         tags = context.tags('pbtxt');
         if (tags.has('ir_version')) {
             return 'onnx.pbtxt.ModelProto';
@@ -132,12 +138,6 @@ onnx.ModelFactory = class {
         if (tags.has('graph') && extension !== 'model') {
             return 'onnx.pbtxt.ModelProto';
         }
-        if (onnx.Runtime.Reader.open(stream, extension)) {
-            return 'onnx.flatbuffers';
-        }
-        if (onnx.Text.Reader.open(stream)) {
-            return 'onnx.text';
-        }
         return undefined;
     }
 
@@ -1898,7 +1898,7 @@ onnx.Text.Reader = class {
                 switch (keyword) {
                     case 'ir_version':
                     case 'model_version':
-                        model[keyword] = this._int64();
+                        model[keyword] = this._integer();
                         break;
                     case 'opset_import':
                         model[keyword] = this._operatorSetId();
@@ -1949,7 +1949,9 @@ onnx.Text.Reader = class {
                 do {
                     const valueInfo = this._valueInfo();
                     if (this._match('=')) {
-                        this._throw('Initializer not implemented.');
+                        const tensor = this._tensor(valueInfo.type);
+                        tensor.name = valueInfo.name;
+                        graph.initializer.push(tensor);
                     }
                     graph.input.push(valueInfo);
                 }
@@ -1964,7 +1966,9 @@ onnx.Text.Reader = class {
                 do {
                     const valueInfo = this._valueInfo();
                     if (this._match('=')) {
-                        this._throw('Initializer not implemented.');
+                        const tensor = this._tensor(valueInfo.type);
+                        tensor.name = valueInfo.name;
+                        graph.initializer.push(tensor);
                     }
                     else {
                         graph.value_info.push(valueInfo);
@@ -2063,7 +2067,17 @@ onnx.Text.Reader = class {
                 const identifier = this._identifier();
                 if (this._dataTypes.has(identifier)) {
                     attribute.type = onnx.AttributeType.TENSOR;
-                    attribute.t = this._tensor(identifier);
+                    if (!this._dataTypes.has(identifier)) {
+                        this._throw("Unexpected type '" + identifier + "'.");
+                    }
+                    const type = this._type(this._dataTypes.get(identifier));
+                    if (!type.tensor_type.elem_type) {
+                        this._throw('Expected tensor data type.');
+                    }
+                    if (!type.tensor_type.shape || !type.tensor_type.shape.dim) {
+                        this._throw('Expected tensor shape.');
+                    }
+                    attribute.t = this._tensor(type);
                 }
                 else {
                     attribute.type = onnx.AttributeType.GRAPH;
@@ -2148,7 +2162,7 @@ onnx.Text.Reader = class {
                     dimension.dim_param = identifier;
                 }
                 else {
-                    dimension.dim_value = this._int64();
+                    dimension.dim_value = this._integer();
                 }
             }
             shape.dim.push(dimension);
@@ -2157,26 +2171,51 @@ onnx.Text.Reader = class {
         return shape;
     }
 
-    _tensor(elem_type) {
+    _tensor(type) {
         const tensor = new onnx.proto.TensorProto();
-        if (!this._dataTypes.has(elem_type)) {
-            this._throw("Unexpected type '" + elem_type + "'.");
-        }
-        const type = this._type(this._dataTypes.get(elem_type));
-        if (!type.tensor_type.elem_type) {
-            this._throw('Expected tensor data type.');
+        if (!type.tensor_type || !type.tensor_type.elem_type) {
+            this._throw('Expected tensor type.');
         }
-        if (!type.tensor_type.shape || !type.tensor_type.shape.dim) {
-            this._throw('Expected tensor shape.');
+        if (!type.tensor_type.shape || !type.tensor_type.shape.dim || !type.tensor_type.shape.dim.every((dim) => dim.dim_value)) {
+            this._throw('Expected numeric tensor shape.');
         }
-        tensor.data_type = type.tensor_type.elem_type;
+        const elem_type = type.tensor_type.elem_type;
+        tensor.data_type = elem_type;
         tensor.dims = type.tensor_type.shape.dim.map((dim) => dim.dim_value);
         this._match('=');
         this._expect('{');
         if (!this._match('}')) {
             do {
-                this._next();
-            } while (!this._match('}'));
+                switch (elem_type) {
+                    case onnx.DataType.INT8:
+                    case onnx.DataType.INT16:
+                    case onnx.DataType.INT32:
+                    case onnx.DataType.UINT8:
+                    case onnx.DataType.UINT16:
+                    case onnx.DataType.BOOL:
+                        tensor.int32_data.push(this._integer());
+                        break;
+                    case onnx.DataType.INT64:
+                        tensor.int64_data.push(this._integer());
+                        break;
+                    case onnx.DataType.UINT32:
+                    case onnx.DataType.UINT64:
+                        tensor.uint64_data.push(this._integer());
+                        break;
+                    case onnx.DataType.FLOAT:
+                        tensor.float_data.push(this._float());
+                        break;
+                    case onnx.DataType.DOUBLE:
+                        tensor.double_data.push(this._float());
+                        break;
+                    case onnx.DataType.STRING:
+                        tensor.string_data.push(this.string());
+                        break;
+                    default:
+                        return this._throw("Unsupported tensor element type '" + elem_type.toString() + "'.");
+                }
+            } while (this._match(','));
+            this._expect('}');
         }
         return tensor;
     }
@@ -2299,7 +2338,7 @@ onnx.Text.Reader = class {
         return undefined;
     }
 
-    _int64() {
+    _integer() {
         const value = this._literal();
         if (!Number.isInteger(value)) {
             this._throw('Integer value expected.');
@@ -2307,6 +2346,14 @@ onnx.Text.Reader = class {
         return value;
     }
 
+    _float() {
+        const value = this._literal();
+        if (typeof value !== 'number') {
+            this._throw('Float value expected.');
+        }
+        return value;
+    }
+
     _string() {
         const value = this._literal();
         if (typeof value !== 'string') {
@@ -2323,7 +2370,7 @@ onnx.Text.Reader = class {
                 const value = new onnx.proto.OperatorSetIdProto();
                 value.domain = this._string();
                 this._expect(':');
-                value.version = this._int64();
+                value.version = this._integer();
                 list.push(value);
             }
             while (this._match(','));