|
|
@@ -125,6 +125,12 @@ onnx.ModelFactory = class {
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
+ if (onnx.Text.Reader.open(stream)) {
|
|
|
+ return 'onnx.text';
|
|
|
+ }
|
|
|
+ if (onnx.Runtime.Reader.open(stream, extension)) {
|
|
|
+ return 'onnx.flatbuffers';
|
|
|
+ }
|
|
|
tags = context.tags('pbtxt');
|
|
|
if (tags.has('ir_version')) {
|
|
|
return 'onnx.pbtxt.ModelProto';
|
|
|
@@ -132,12 +138,6 @@ onnx.ModelFactory = class {
|
|
|
if (tags.has('graph') && extension !== 'model') {
|
|
|
return 'onnx.pbtxt.ModelProto';
|
|
|
}
|
|
|
- if (onnx.Runtime.Reader.open(stream, extension)) {
|
|
|
- return 'onnx.flatbuffers';
|
|
|
- }
|
|
|
- if (onnx.Text.Reader.open(stream)) {
|
|
|
- return 'onnx.text';
|
|
|
- }
|
|
|
return undefined;
|
|
|
}
|
|
|
|
|
|
@@ -1898,7 +1898,7 @@ onnx.Text.Reader = class {
|
|
|
switch (keyword) {
|
|
|
case 'ir_version':
|
|
|
case 'model_version':
|
|
|
- model[keyword] = this._int64();
|
|
|
+ model[keyword] = this._integer();
|
|
|
break;
|
|
|
case 'opset_import':
|
|
|
model[keyword] = this._operatorSetId();
|
|
|
@@ -1949,7 +1949,9 @@ onnx.Text.Reader = class {
|
|
|
do {
|
|
|
const valueInfo = this._valueInfo();
|
|
|
if (this._match('=')) {
|
|
|
- this._throw('Initializer not implemented.');
|
|
|
+ const tensor = this._tensor(valueInfo.type);
|
|
|
+ tensor.name = valueInfo.name;
|
|
|
+ graph.initializer.push(tensor);
|
|
|
}
|
|
|
graph.input.push(valueInfo);
|
|
|
}
|
|
|
@@ -1964,7 +1966,9 @@ onnx.Text.Reader = class {
|
|
|
do {
|
|
|
const valueInfo = this._valueInfo();
|
|
|
if (this._match('=')) {
|
|
|
- this._throw('Initializer not implemented.');
|
|
|
+ const tensor = this._tensor(valueInfo.type);
|
|
|
+ tensor.name = valueInfo.name;
|
|
|
+ graph.initializer.push(tensor);
|
|
|
}
|
|
|
else {
|
|
|
graph.value_info.push(valueInfo);
|
|
|
@@ -2063,7 +2067,17 @@ onnx.Text.Reader = class {
|
|
|
const identifier = this._identifier();
|
|
|
if (this._dataTypes.has(identifier)) {
|
|
|
attribute.type = onnx.AttributeType.TENSOR;
|
|
|
- attribute.t = this._tensor(identifier);
|
|
|
+ if (!this._dataTypes.has(identifier)) {
|
|
|
+ this._throw("Unexpected type '" + identifier + "'.");
|
|
|
+ }
|
|
|
+ const type = this._type(this._dataTypes.get(identifier));
|
|
|
+ if (!type.tensor_type.elem_type) {
|
|
|
+ this._throw('Expected tensor data type.');
|
|
|
+ }
|
|
|
+ if (!type.tensor_type.shape || !type.tensor_type.shape.dim) {
|
|
|
+ this._throw('Expected tensor shape.');
|
|
|
+ }
|
|
|
+ attribute.t = this._tensor(type);
|
|
|
}
|
|
|
else {
|
|
|
attribute.type = onnx.AttributeType.GRAPH;
|
|
|
@@ -2148,7 +2162,7 @@ onnx.Text.Reader = class {
|
|
|
dimension.dim_param = identifier;
|
|
|
}
|
|
|
else {
|
|
|
- dimension.dim_value = this._int64();
|
|
|
+ dimension.dim_value = this._integer();
|
|
|
}
|
|
|
}
|
|
|
shape.dim.push(dimension);
|
|
|
@@ -2157,26 +2171,51 @@ onnx.Text.Reader = class {
|
|
|
return shape;
|
|
|
}
|
|
|
|
|
|
- _tensor(elem_type) {
|
|
|
+ _tensor(type) {
|
|
|
const tensor = new onnx.proto.TensorProto();
|
|
|
- if (!this._dataTypes.has(elem_type)) {
|
|
|
- this._throw("Unexpected type '" + elem_type + "'.");
|
|
|
- }
|
|
|
- const type = this._type(this._dataTypes.get(elem_type));
|
|
|
- if (!type.tensor_type.elem_type) {
|
|
|
- this._throw('Expected tensor data type.');
|
|
|
+ if (!type.tensor_type || !type.tensor_type.elem_type) {
|
|
|
+ this._throw('Expected tensor type.');
|
|
|
}
|
|
|
- if (!type.tensor_type.shape || !type.tensor_type.shape.dim) {
|
|
|
- this._throw('Expected tensor shape.');
|
|
|
+ if (!type.tensor_type.shape || !type.tensor_type.shape.dim || !type.tensor_type.shape.dim.every((dim) => dim.dim_value)) {
|
|
|
+ this._throw('Expected numeric tensor shape.');
|
|
|
}
|
|
|
- tensor.data_type = type.tensor_type.elem_type;
|
|
|
+ const elem_type = type.tensor_type.elem_type;
|
|
|
+ tensor.data_type = elem_type;
|
|
|
tensor.dims = type.tensor_type.shape.dim.map((dim) => dim.dim_value);
|
|
|
this._match('=');
|
|
|
this._expect('{');
|
|
|
if (!this._match('}')) {
|
|
|
do {
|
|
|
- this._next();
|
|
|
- } while (!this._match('}'));
|
|
|
+ switch (elem_type) {
|
|
|
+ case onnx.DataType.INT8:
|
|
|
+ case onnx.DataType.INT16:
|
|
|
+ case onnx.DataType.INT32:
|
|
|
+ case onnx.DataType.UINT8:
|
|
|
+ case onnx.DataType.UINT16:
|
|
|
+ case onnx.DataType.BOOL:
|
|
|
+ tensor.int32_data.push(this._integer());
|
|
|
+ break;
|
|
|
+ case onnx.DataType.INT64:
|
|
|
+ tensor.int64_data.push(this._integer());
|
|
|
+ break;
|
|
|
+ case onnx.DataType.UINT32:
|
|
|
+ case onnx.DataType.UINT64:
|
|
|
+ tensor.uint64_data.push(this._integer());
|
|
|
+ break;
|
|
|
+ case onnx.DataType.FLOAT:
|
|
|
+ tensor.float_data.push(this._float());
|
|
|
+ break;
|
|
|
+ case onnx.DataType.DOUBLE:
|
|
|
+ tensor.double_data.push(this._float());
|
|
|
+ break;
|
|
|
+ case onnx.DataType.STRING:
|
|
|
+ tensor.string_data.push(this.string());
|
|
|
+ break;
|
|
|
+ default:
|
|
|
+ return this._throw("Unsupported tensor element type '" + elem_type.toString() + "'.");
|
|
|
+ }
|
|
|
+ } while (this._match(','));
|
|
|
+ this._expect('}');
|
|
|
}
|
|
|
return tensor;
|
|
|
}
|
|
|
@@ -2299,7 +2338,7 @@ onnx.Text.Reader = class {
|
|
|
return undefined;
|
|
|
}
|
|
|
|
|
|
- _int64() {
|
|
|
+ _integer() {
|
|
|
const value = this._literal();
|
|
|
if (!Number.isInteger(value)) {
|
|
|
this._throw('Integer value expected.');
|
|
|
@@ -2307,6 +2346,14 @@ onnx.Text.Reader = class {
|
|
|
return value;
|
|
|
}
|
|
|
|
|
|
+ _float() {
|
|
|
+ const value = this._literal();
|
|
|
+ if (typeof value !== 'number') {
|
|
|
+ this._throw('Float value expected.');
|
|
|
+ }
|
|
|
+ return value;
|
|
|
+ }
|
|
|
+
|
|
|
_string() {
|
|
|
const value = this._literal();
|
|
|
if (typeof value !== 'string') {
|
|
|
@@ -2323,7 +2370,7 @@ onnx.Text.Reader = class {
|
|
|
const value = new onnx.proto.OperatorSetIdProto();
|
|
|
value.domain = this._string();
|
|
|
this._expect(':');
|
|
|
- value.version = this._int64();
|
|
|
+ value.version = this._integer();
|
|
|
list.push(value);
|
|
|
}
|
|
|
while (this._match(','));
|