|
|
@@ -5,13 +5,13 @@ var json = json || require('./json');
|
|
|
rknn.ModelFactory = class {
|
|
|
|
|
|
match(context) {
|
|
|
- return rknn.Container.open(context);
|
|
|
+ return rknn.Reader.open(context);
|
|
|
}
|
|
|
|
|
|
open(context, match) {
|
|
|
return rknn.Metadata.open(context).then((metadata) => {
|
|
|
- const container = match;
|
|
|
- return new rknn.Model(metadata, container.model, container.weights);
|
|
|
+ const reader = match;
|
|
|
+ return new rknn.Model(metadata, reader.model, reader.weights);
|
|
|
});
|
|
|
}
|
|
|
};
|
|
|
@@ -401,10 +401,15 @@ rknn.TensorType = class {
|
|
|
case 'int64':
|
|
|
case 'float16':
|
|
|
case 'float32':
|
|
|
+ case 'vdata':
|
|
|
this._dataType = type;
|
|
|
break;
|
|
|
default:
|
|
|
- throw new rknn.Error("Invalid data type '" + JSON.stringify(dataType) + "'.");
|
|
|
+ if (dataType.vx_type !== '') {
|
|
|
+ throw new rknn.Error("Invalid data type '" + JSON.stringify(dataType) + "'.");
|
|
|
+ }
|
|
|
+ this._dataType = '?';
|
|
|
+ break;
|
|
|
}
|
|
|
this._shape = shape;
|
|
|
}
|
|
|
@@ -440,74 +445,68 @@ rknn.TensorShape = class {
|
|
|
}
|
|
|
};
|
|
|
|
|
|
-rknn.Container = class {
|
|
|
+rknn.Reader = class {
|
|
|
|
|
|
static open(context) {
|
|
|
const stream = context.stream;
|
|
|
- const signature = [ 0x52, 0x4B, 0x4E, 0x4E, 0x00, 0x00, 0x00, 0x00 ];
|
|
|
- if (signature.length <= stream.length && stream.peek(signature.length).every((value, index) => value === signature[index])) {
|
|
|
- return new rknn.Container(stream);
|
|
|
+ if (stream.length >= 8) {
|
|
|
+ const buffer = stream.read(8);
|
|
|
+ const decoder = new TextDecoder();
|
|
|
+ const signature = decoder.decode(buffer);
|
|
|
+ if (signature === 'RKNN\0\0\0\0' || signature === 'CYPTRKNN') {
|
|
|
+ return new rknn.Reader(stream, signature);
|
|
|
+ }
|
|
|
}
|
|
|
return null;
|
|
|
}
|
|
|
|
|
|
- constructor(stream) {
|
|
|
- this._reader = new rknn.Container.StreamReader(stream);
|
|
|
+ constructor(stream, signature) {
|
|
|
+ this._stream = stream;
|
|
|
+ this._signature = signature;
|
|
|
}
|
|
|
|
|
|
get version() {
|
|
|
- this._read();
|
|
|
+ this._decode();
|
|
|
return this._version;
|
|
|
}
|
|
|
|
|
|
get weights() {
|
|
|
- this._read();
|
|
|
+ this._decode();
|
|
|
return this._weights;
|
|
|
}
|
|
|
|
|
|
get model() {
|
|
|
- this._read();
|
|
|
+ this._decode();
|
|
|
return this._model;
|
|
|
}
|
|
|
|
|
|
- _read() {
|
|
|
- if (this._reader) {
|
|
|
- this._reader.uint64();
|
|
|
- this._version = this._reader.uint64();
|
|
|
- this._weights = this._reader.read();
|
|
|
- const buffer = this._reader.read();
|
|
|
- const reader = json.TextReader.open(buffer);
|
|
|
+ _decode() {
|
|
|
+ if (this._stream) {
|
|
|
+ if (this._signature === 'CYPTRKNN') {
|
|
|
+ throw new rknn.Error('Invalid file content. File contains undocumented encrypted RKNN data.');
|
|
|
+ }
|
|
|
+ this._version = this._uint64();
|
|
|
+ const weights_size = this._uint64();
|
|
|
+ if (this._version > 1) {
|
|
|
+ this._stream.read(40);
|
|
|
+ }
|
|
|
+ this._weights = this._stream.read(weights_size);
|
|
|
+ const model_size = this._uint64();
|
|
|
+ const model_buffer = this._stream.read(model_size);
|
|
|
+ const reader = json.TextReader.open(model_buffer);
|
|
|
this._model = reader.read();
|
|
|
- delete this._reader;
|
|
|
+ delete this._stream;
|
|
|
}
|
|
|
}
|
|
|
-};
|
|
|
-
|
|
|
-rknn.Container.StreamReader = class {
|
|
|
|
|
|
- constructor(stream) {
|
|
|
- this._stream = stream;
|
|
|
- this._length = stream.length;
|
|
|
- this._position = 0;
|
|
|
- }
|
|
|
-
|
|
|
- skip(offset) {
|
|
|
- this._position += offset;
|
|
|
- if (this._position > this._length) {
|
|
|
- throw new rknn.Error('Expected ' + (this._position - this._length) + ' more bytes. The file might be corrupted. Unexpected end of file.');
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- uint64() {
|
|
|
- this.skip(8);
|
|
|
+ _uint64() {
|
|
|
const buffer = this._stream.read(8);
|
|
|
const view = new DataView(buffer.buffer, buffer.byteOffset, buffer.byteLength);
|
|
|
return view.getUint64(0, true).toNumber();
|
|
|
}
|
|
|
|
|
|
- read() {
|
|
|
- const size = this.uint64();
|
|
|
- this.skip(size);
|
|
|
+ _read() {
|
|
|
+ const size = this._uint64();
|
|
|
return this._stream.read(size);
|
|
|
}
|
|
|
};
|
|
|
@@ -528,36 +527,35 @@ rknn.Metadata = class {
|
|
|
}
|
|
|
|
|
|
constructor(data) {
|
|
|
- this._map = new Map();
|
|
|
+ this._types = new Map();
|
|
|
+ this._attributes = new Map();
|
|
|
if (data) {
|
|
|
- const metadata = JSON.parse(data);
|
|
|
- this._map = new Map(metadata.map((item) => [ item.name, item ]));
|
|
|
+ const items = JSON.parse(data);
|
|
|
+ for (const item of items) {
|
|
|
+ this._types.set(item.name, item);
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
|
|
|
type(name) {
|
|
|
- return this._map.has(name) ? this._map.get(name) : null;
|
|
|
+ if (!this._types.has(name)) {
|
|
|
+ this._types.set(name, { name: name });
|
|
|
+ }
|
|
|
+ return this._types.get(name);
|
|
|
}
|
|
|
|
|
|
attribute(type, name) {
|
|
|
- const schema = this.type(type);
|
|
|
- if (schema) {
|
|
|
- let attributeMap = schema.attributeMap;
|
|
|
- if (!attributeMap) {
|
|
|
- attributeMap = {};
|
|
|
- if (schema.attributes) {
|
|
|
- for (const attribute of schema.attributes) {
|
|
|
- attributeMap[attribute.name] = attribute;
|
|
|
- }
|
|
|
+ const key = type + ':' + name;
|
|
|
+ if (!this._attributes.has(key)) {
|
|
|
+ this._attributes.set(key, null);
|
|
|
+ const metadata = this.type(type);
|
|
|
+ if (metadata && Array.isArray(metadata.attributes)) {
|
|
|
+ for (const attribute of metadata.attributes) {
|
|
|
+ this._attributes.set(type + ':' + attribute.name, attribute);
|
|
|
}
|
|
|
- schema.attributeMap = attributeMap;
|
|
|
- }
|
|
|
- const attributeSchema = attributeMap[name];
|
|
|
- if (attributeSchema) {
|
|
|
- return attributeSchema;
|
|
|
}
|
|
|
}
|
|
|
- return null;
|
|
|
+ return this._attributes.get(key);
|
|
|
}
|
|
|
};
|
|
|
|