|
|
@@ -111,12 +111,32 @@ rknn.Graph = class {
|
|
|
this._nodes = [];
|
|
|
switch (type) {
|
|
|
case 'json': {
|
|
|
+ const dataType = (value) => {
|
|
|
+ const type = value.vx_type.startsWith('VSI_NN_TYPE_') ? value.vx_type.split('_').pop().toLowerCase() : value.vx_type;
|
|
|
+ switch (type) {
|
|
|
+ case 'uint8':
|
|
|
+ case 'int8':
|
|
|
+ case 'int16':
|
|
|
+ case 'int32':
|
|
|
+ case 'int64':
|
|
|
+ case 'float16':
|
|
|
+ case 'float32':
|
|
|
+ case 'float64':
|
|
|
+ case 'vdata':
|
|
|
+ return type;
|
|
|
+ default:
|
|
|
+ if (value.vx_type !== '') {
|
|
|
+ throw new rknn.Error("Invalid data type '" + JSON.stringify(dataType) + "'.");
|
|
|
+ }
|
|
|
+ return '?';
|
|
|
+ }
|
|
|
+ };
|
|
|
const model = obj;
|
|
|
const args = new Map();
|
|
|
for (const const_tensor of model.const_tensor) {
|
|
|
const name = 'const_tensor:' + const_tensor.tensor_id.toString();
|
|
|
const shape = new rknn.TensorShape(const_tensor.size);
|
|
|
- const type = new rknn.TensorType(const_tensor.dtype, shape);
|
|
|
+ const type = new rknn.TensorType(dataType(const_tensor.dtype), shape);
|
|
|
const tensor = new rknn.Tensor(type, const_tensor.offset, next.value);
|
|
|
const argument = new rknn.Argument(name, type, tensor);
|
|
|
args.set(name, argument);
|
|
|
@@ -129,7 +149,7 @@ rknn.Graph = class {
|
|
|
for (const norm_tensor of model.norm_tensor) {
|
|
|
const name = 'norm_tensor:' + norm_tensor.tensor_id.toString();
|
|
|
const shape = new rknn.TensorShape(norm_tensor.size);
|
|
|
- const type = new rknn.TensorType(norm_tensor.dtype, shape);
|
|
|
+ const type = new rknn.TensorType(dataType(norm_tensor.dtype), shape);
|
|
|
const argument = new rknn.Argument(name, type, null);
|
|
|
args.set(name, argument);
|
|
|
}
|
|
|
@@ -180,7 +200,17 @@ rknn.Graph = class {
|
|
|
}
|
|
|
case 'flatbuffers': {
|
|
|
const graph = obj;
|
|
|
- const args = graph.tensors.map((tensor) => new rknn.Argument(tensor.name));
|
|
|
+ const dataTypes = [ 'unk0', '?', '?', 'int8', '?', 'int16', 'float32', 'int64', '?', '?', 'float16', '?', '?', 'unk13' ];
|
|
|
+ const args = graph.tensors.map((tensor) => {
|
|
|
+ const shape = new rknn.TensorShape(Array.from(tensor.shape));
|
|
|
+ const dataType = tensor.data_type < dataTypes.length ? dataTypes[tensor.data_type] : '?';
|
|
|
+ if (dataType === '?') {
|
|
|
+ throw new rknn.Error("Unsupported tensor data type '" + tensor.data_type + "'.");
|
|
|
+ }
|
|
|
+ const type = new rknn.TensorType(dataType, shape);
|
|
|
+ const initializer = tensor.kind !== 4 && tensor.kind !== 5 ? null : new rknn.Tensor(type, 0, null);
|
|
|
+ return new rknn.Argument(tensor.name, type, initializer);
|
|
|
+ });
|
|
|
const arg = (index) => {
|
|
|
if (index >= args.length) {
|
|
|
throw new rknn.Error("Invalid tensor index '" + index.toString() + "'.");
|
|
|
@@ -265,6 +295,9 @@ rknn.Argument = class {
|
|
|
rknn.Node = class {
|
|
|
|
|
|
constructor(metadata, type, node, arg, next) {
|
|
|
+ this._inputs = [];
|
|
|
+ this._outputs = [];
|
|
|
+ this._attributes = [];
|
|
|
switch (type) {
|
|
|
case 'json': {
|
|
|
this._name = node.name || '';
|
|
|
@@ -285,9 +318,6 @@ rknn.Node = class {
|
|
|
this._type.name = this._type.name.startsWith(prefix) ? this._type.name.substring(prefix.length) : this._type.name;
|
|
|
}
|
|
|
}
|
|
|
- this._inputs = [];
|
|
|
- this._outputs = [];
|
|
|
- this._attributes = [];
|
|
|
node.input = node.input || [];
|
|
|
for (let i = 0; i < node.input.length; ) {
|
|
|
const input = this._type && this._type.inputs && i < this._type.inputs.length ? this._type.inputs[i] : { name: i === 0 ? 'input' : i.toString() };
|
|
|
@@ -335,23 +365,35 @@ rknn.Node = class {
|
|
|
case 'flatbuffers': {
|
|
|
this._name = node.name;
|
|
|
this._type = metadata.type(node.type);
|
|
|
- this._inputs = Array.from(node.inputs).map((input, index) => {
|
|
|
- const argument = arg(input);
|
|
|
- return new rknn.Parameter(index.toString(), [ argument ]);
|
|
|
- });
|
|
|
- this._outputs = Array.from(node.outputs).map((output, index) => {
|
|
|
- const argument = arg(output);
|
|
|
- return new rknn.Parameter(index.toString(), [ argument ]);
|
|
|
- });
|
|
|
- this._attributes = [];
|
|
|
+ if (node.inputs.length > 0) {
|
|
|
+ const inputs = this._type.inputs || (node.inputs.length === 1 ? [ { name: "input" } ] : [ { name: "inputs", list: true } ]);
|
|
|
+ if (Array.isArray(inputs) && inputs.length > 0 && inputs[0].list === true) {
|
|
|
+ this._inputs = [new rknn.Parameter(inputs[0].name, Array.from(node.inputs).map((input) => arg(input))) ];
|
|
|
+ }
|
|
|
+ else {
|
|
|
+ this._inputs = Array.from(node.inputs).map((input, index) => {
|
|
|
+ const argument = arg(input);
|
|
|
+ return new rknn.Parameter(index < inputs.length ? inputs[index].name : index.toString(), [ argument ]);
|
|
|
+ });
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if (node.outputs.length > 0) {
|
|
|
+ const outputs = this._type.outputs || (node.outputs.length === 1 ? [ { name: "output" } ] : [ { name: "outputs", list: true } ]);
|
|
|
+ if (Array.isArray(outputs) && outputs.length > 0 && outputs[0].list === true) {
|
|
|
+ this._outputs = [ new rknn.Parameter(outputs[0].name, Array.from(node.outputs).map((output) => arg(output))) ];
|
|
|
+ }
|
|
|
+ else {
|
|
|
+ this._outputs = Array.from(node.outputs).map((output, index) => {
|
|
|
+ const argument = arg(output);
|
|
|
+ return new rknn.Parameter(index < outputs.length ? outputs[index].name : index.toString(), [ argument ]);
|
|
|
+ });
|
|
|
+ }
|
|
|
+ }
|
|
|
break;
|
|
|
}
|
|
|
case 'openvx': {
|
|
|
this._name = '';
|
|
|
this._type = metadata.type(node.type);
|
|
|
- this._inputs = [];
|
|
|
- this._outputs = [];
|
|
|
- this._attributes = [];
|
|
|
break;
|
|
|
}
|
|
|
default: {
|
|
|
@@ -401,6 +443,7 @@ rknn.Tensor = class {
|
|
|
|
|
|
constructor(type, offset, weights) {
|
|
|
this._type = type;
|
|
|
+ this._data = null;
|
|
|
let size = 0;
|
|
|
switch (this._type.dataType) {
|
|
|
case 'uint8': size = 1; break;
|
|
|
@@ -414,10 +457,12 @@ rknn.Tensor = class {
|
|
|
case 'vdata': size = 1; break;
|
|
|
default: throw new rknn.Error("Unsupported tensor data type '" + this._type.dataType + "'.");
|
|
|
}
|
|
|
- const shape = type.shape.dimensions;
|
|
|
- size = size * shape.reduce((a, b) => a * b, 1);
|
|
|
- if (size > 0) {
|
|
|
- this._data = weights.slice(offset, offset + size);
|
|
|
+ if (weights) {
|
|
|
+ const shape = type.shape.dimensions;
|
|
|
+ size = size * shape.reduce((a, b) => a * b, 1);
|
|
|
+ if (size > 0) {
|
|
|
+ this._data = weights.slice(offset, offset + size);
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
|
|
|
@@ -541,26 +586,7 @@ rknn.Tensor = class {
|
|
|
rknn.TensorType = class {
|
|
|
|
|
|
constructor(dataType, shape) {
|
|
|
- const type = dataType.vx_type.startsWith('VSI_NN_TYPE_') ? dataType.vx_type.split('_').pop().toLowerCase() : dataType.vx_type;
|
|
|
- switch (type) {
|
|
|
- case 'uint8':
|
|
|
- case 'int8':
|
|
|
- case 'int16':
|
|
|
- case 'int32':
|
|
|
- case 'int64':
|
|
|
- case 'float16':
|
|
|
- case 'float32':
|
|
|
- case 'float64':
|
|
|
- case 'vdata':
|
|
|
- this._dataType = type;
|
|
|
- break;
|
|
|
- default:
|
|
|
- if (dataType.vx_type !== '') {
|
|
|
- throw new rknn.Error("Invalid data type '" + JSON.stringify(dataType) + "'.");
|
|
|
- }
|
|
|
- this._dataType = '?';
|
|
|
- break;
|
|
|
- }
|
|
|
+ this._dataType = dataType;
|
|
|
this._shape = shape;
|
|
|
}
|
|
|
|