|
|
@@ -185,16 +185,18 @@ rknn.Argument = class {
|
|
|
rknn.Node = class {
|
|
|
|
|
|
constructor(metadata, node, args) {
|
|
|
- this._metadata = metadata;
|
|
|
this._name = node.name || '';
|
|
|
+ this._metadata = metadata.type(node.op);
|
|
|
this._type = node.op;
|
|
|
+ for (const prefix of [ 'VSI_NN_OP_', 'RKNN_OP_' ]) {
|
|
|
+ this._type = this._type.startsWith(prefix) ? this._type.substring(prefix.length) : this._type;
|
|
|
+ }
|
|
|
this._inputs = [];
|
|
|
this._outputs = [];
|
|
|
this._attributes = [];
|
|
|
- const schema = this._metadata.type(this._type);
|
|
|
node.input = node.input || [];
|
|
|
for (let i = 0; i < node.input.length; ) {
|
|
|
- const input = schema && schema.inputs && i < schema.inputs.length ? schema.inputs[i] : { name: i === 0 ? 'input' : i.toString() };
|
|
|
+ const input = this._metadata && this._metadata.inputs && i < this._metadata.inputs.length ? this._metadata.inputs[i] : { name: i === 0 ? 'input' : i.toString() };
|
|
|
const count = input.list ? node.input.length - i : 1;
|
|
|
const list = node.input.slice(i, i + count).map((input) => {
|
|
|
if (input.right_tensor) {
|
|
|
@@ -220,7 +222,7 @@ rknn.Node = class {
|
|
|
}
|
|
|
node.output = node.output || [];
|
|
|
for (let i = 0; i < node.output.length; ) {
|
|
|
- const output = schema && schema.outputs && i < schema.outputs.length ? schema.outputs[i] : { name: i === 0 ? 'output' : i.toString() };
|
|
|
+ const output = this._metadata && this._metadata.outputs && i < this._metadata.outputs.length ? this._metadata.outputs[i] : { name: i === 0 ? 'output' : i.toString() };
|
|
|
const count = output.list ? node.output.length - i : 1;
|
|
|
const list = node.output.slice(i, i + count).map((output) => {
|
|
|
if (output.right_tensor) {
|
|
|
@@ -261,12 +263,11 @@ rknn.Node = class {
|
|
|
}
|
|
|
|
|
|
get type() {
|
|
|
- const prefix = 'VSI_NN_OP_';
|
|
|
- return this._type.startsWith(prefix) ? this._type.substring(prefix.length) : this._type;
|
|
|
+ return this._type;
|
|
|
}
|
|
|
|
|
|
get metadata() {
|
|
|
- return this._metadata.type(this._type);
|
|
|
+ return this._metadata;
|
|
|
}
|
|
|
|
|
|
get inputs() {
|
|
|
@@ -420,14 +421,17 @@ rknn.Tensor = class {
|
|
|
rknn.TensorType = class {
|
|
|
|
|
|
constructor(dataType, shape) {
|
|
|
- switch (dataType.vx_type) {
|
|
|
- case 'VSI_NN_TYPE_UINT8': this._dataType = 'uint8'; break;
|
|
|
- case 'VSI_NN_TYPE_INT8': this._dataType = 'int8'; break;
|
|
|
- case 'VSI_NN_TYPE_INT16': this._dataType = 'int16'; break;
|
|
|
- case 'VSI_NN_TYPE_INT32': this._dataType = 'int32'; break;
|
|
|
- case 'VSI_NN_TYPE_INT64': this._dataType = 'int64'; break;
|
|
|
- case 'VSI_NN_TYPE_FLOAT16': this._dataType = 'float16'; break;
|
|
|
- case 'VSI_NN_TYPE_FLOAT32': this._dataType = 'float32'; break;
|
|
|
+ const type = dataType.vx_type.startsWith('VSI_NN_TYPE_') ? dataType.vx_type.split('_').pop().toLowerCase() : dataType.vx_type;
|
|
|
+ switch (type) {
|
|
|
+ case 'uint8':
|
|
|
+ case 'int8':
|
|
|
+ case 'int16':
|
|
|
+ case 'int32':
|
|
|
+ case 'int64':
|
|
|
+ case 'float16':
|
|
|
+ case 'float32':
|
|
|
+ this._dataType = type;
|
|
|
+ break;
|
|
|
default:
|
|
|
throw new rknn.Error("Invalid data type '" + JSON.stringify(dataType) + "'.");
|
|
|
}
|