|
|
@@ -8,31 +8,28 @@ var protobuf = protobuf || require('./protobuf');
|
|
|
om.ModelFactory = class {
|
|
|
|
|
|
match(context) {
|
|
|
- const stream = context.stream;
|
|
|
- const signature = [ 0x49, 0x4D, 0x4F, 0x44 ]; // IMOD
|
|
|
- if (stream.length >= signature.length && stream.peek(4).every((value, index) => value === signature[index])) {
|
|
|
- return 'om';
|
|
|
- }
|
|
|
- return undefined;
|
|
|
+ return om.File.open(context);
|
|
|
}
|
|
|
|
|
|
- open(context) {
|
|
|
- const file = om.File.open(context);
|
|
|
+ open(context, match) {
|
|
|
+ const file = match;
|
|
|
if (!file.model) {
|
|
|
throw om.Error('File does not contain a model definition.');
|
|
|
}
|
|
|
return context.require('./om-proto').then(() => {
|
|
|
- om.proto = protobuf.get('om').ge.proto;
|
|
|
+ let model = null;
|
|
|
try {
|
|
|
+ om.proto = protobuf.get('om').ge.proto;
|
|
|
const reader = protobuf.BinaryReader.open(file.model);
|
|
|
- file.model = om.proto.ModelDef.decode(reader);
|
|
|
+ model = om.proto.ModelDef.decode(reader);
|
|
|
}
|
|
|
catch (error) {
|
|
|
const message = error && error.message ? error.message : error.toString();
|
|
|
throw new om.Error('File format is not ge.proto.ModelDef (' + message.replace(/\.$/, '') + ').');
|
|
|
}
|
|
|
return om.Metadata.open(context).then((metadata) => {
|
|
|
- return new om.Model(metadata, file);
|
|
|
+ const weights = file.weights;
|
|
|
+ return new om.Model(metadata, model, weights);
|
|
|
});
|
|
|
});
|
|
|
}
|
|
|
@@ -40,10 +37,10 @@ om.ModelFactory = class {
|
|
|
|
|
|
om.Model = class {
|
|
|
|
|
|
- constructor(metadata, file) {
|
|
|
+ constructor(metadata, model, weights) {
|
|
|
this._graphs = [];
|
|
|
- const context = { metadata: metadata, weights: file.weights };
|
|
|
- for (const graph of file.model.graph) {
|
|
|
+ const context = { metadata: metadata, weights: weights };
|
|
|
+ for (const graph of model.graph) {
|
|
|
this._graphs.push(new om.Graph(context, graph));
|
|
|
}
|
|
|
}
|
|
|
@@ -479,74 +476,97 @@ om.File = class {
|
|
|
|
|
|
static open(context) {
|
|
|
const stream = context.stream;
|
|
|
- const buffer = stream.peek();
|
|
|
- const reader = new om.File.BinaryReader(buffer);
|
|
|
- return new om.File(reader);
|
|
|
+ if (stream.length >= 256) {
|
|
|
+ const signature = [ 0x49, 0x4D, 0x4F, 0x44 ]; // IMOD
|
|
|
+ if (stream.peek(4).every((value, index) => value === signature[index])) {
|
|
|
+ const reader = new om.File.BinaryReader(stream.peek());
|
|
|
+ return new om.File(reader);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return null;
|
|
|
}
|
|
|
|
|
|
constructor(reader) {
|
|
|
- const decoder = new TextDecoder('utf-8');
|
|
|
- this.header = reader.read(4);
|
|
|
- const size = reader.uint32();
|
|
|
- this.version = reader.uint32();
|
|
|
- this.checksum = reader.read(64);
|
|
|
- reader.skip(4);
|
|
|
- this.is_encrypt = reader.byte();
|
|
|
- this.is_checksum = reader.byte();
|
|
|
- this.type = reader.byte(); // 0=IR model, 1=standard model, 2=OM Tiny model
|
|
|
- this.mode = reader.byte(); // 0=offline, 1=online
|
|
|
- this.name = decoder.decode(reader.read(32));
|
|
|
- this.ops = reader.uint32();
|
|
|
- this.userdefineinfo = reader.read(32);
|
|
|
- this.ir_version = reader.uint32();
|
|
|
- this.model_num = reader.uint32();
|
|
|
- this.platform_version = reader.read(20);
|
|
|
- this.platform_type = reader.byte();
|
|
|
- reader.seek(0);
|
|
|
- reader.skip(size);
|
|
|
- const partitions = new Array(reader.uint32());
|
|
|
- for (let i = 0; i < partitions.length; i++) {
|
|
|
- partitions[i] = {
|
|
|
- type: reader.uint32(),
|
|
|
- offset: reader.uint32(),
|
|
|
- size: reader.uint32()
|
|
|
- };
|
|
|
- }
|
|
|
- const offset = 256 + 4 + 12 * partitions.length;
|
|
|
- for (const partition of partitions) {
|
|
|
- reader.seek(offset + partition.offset);
|
|
|
- const buffer = reader.read(partition.size);
|
|
|
- switch (partition.type) {
|
|
|
- case 0: { // MODEL_DEF
|
|
|
- this.model = buffer;
|
|
|
- break;
|
|
|
- }
|
|
|
- case 1: { // MODEL_WEIGHT
|
|
|
- this.weights = buffer;
|
|
|
- break;
|
|
|
- }
|
|
|
- case 2: // TASK_INFO
|
|
|
- case 3: // TBE_KERNELS
|
|
|
- case 4: { // CUST_AICPU_KERNELS
|
|
|
- break;
|
|
|
- }
|
|
|
- case 5: { // DEVICE_CONFIG
|
|
|
- this.devices = new Map();
|
|
|
- const decoder = new TextDecoder('ascii');
|
|
|
- const reader = new om.File.BinaryReader(buffer);
|
|
|
- reader.uint32();
|
|
|
- for (let position = 4; position < partition.size; ) {
|
|
|
- const length = reader.uint32();
|
|
|
- const buffer = reader.read(length);
|
|
|
- const name = decoder.decode(buffer);
|
|
|
- const device = reader.uint32();
|
|
|
- this.devices.set(name, device);
|
|
|
- position += 4 + length + 4;
|
|
|
+ this._reader = reader;
|
|
|
+ }
|
|
|
+
|
|
|
+ get model() {
|
|
|
+ this._read();
|
|
|
+ return this._model;
|
|
|
+ }
|
|
|
+
|
|
|
+ get weights() {
|
|
|
+ this._read();
|
|
|
+ return this._weights;
|
|
|
+ }
|
|
|
+
|
|
|
+ _read() {
|
|
|
+ if (this._reader) {
|
|
|
+ const reader = this._reader;
|
|
|
+ delete this._reader;
|
|
|
+ const decoder = new TextDecoder('utf-8');
|
|
|
+ this.header = reader.uint32();
|
|
|
+ const size = reader.uint32();
|
|
|
+ this.version = reader.uint32();
|
|
|
+ this.checksum = reader.read(64);
|
|
|
+ reader.skip(4);
|
|
|
+ this.is_encrypt = reader.byte();
|
|
|
+ this.is_checksum = reader.byte();
|
|
|
+ this.type = reader.byte(); // 0=IR model, 1=standard model, 2=OM Tiny model
|
|
|
+ this.mode = reader.byte(); // 0=offline, 1=online
|
|
|
+ this.name = decoder.decode(reader.read(32));
|
|
|
+ this.ops = reader.uint32();
|
|
|
+ this.userdefineinfo = reader.read(32);
|
|
|
+ this.ir_version = reader.uint32();
|
|
|
+ this.model_num = reader.uint32();
|
|
|
+ this.platform_version = reader.read(20);
|
|
|
+ this.platform_type = reader.byte();
|
|
|
+ reader.seek(0);
|
|
|
+ reader.skip(size);
|
|
|
+ const partitions = new Array(reader.uint32());
|
|
|
+ for (let i = 0; i < partitions.length; i++) {
|
|
|
+ partitions[i] = {
|
|
|
+ type: reader.uint32(),
|
|
|
+ offset: reader.uint32(),
|
|
|
+ size: reader.uint32()
|
|
|
+ };
|
|
|
+ }
|
|
|
+ const offset = 256 + 4 + 12 * partitions.length;
|
|
|
+ for (const partition of partitions) {
|
|
|
+ reader.seek(offset + partition.offset);
|
|
|
+ const buffer = reader.read(partition.size);
|
|
|
+ switch (partition.type) {
|
|
|
+ case 0: { // MODEL_DEF
|
|
|
+ this._model = buffer;
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ case 1: { // MODEL_WEIGHT
|
|
|
+ this._weights = buffer;
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ case 2: // TASK_INFO
|
|
|
+ case 3: // TBE_KERNELS
|
|
|
+ case 4: { // CUST_AICPU_KERNELS
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ case 5: { // DEVICE_CONFIG
|
|
|
+ this.devices = new Map();
|
|
|
+ const decoder = new TextDecoder('ascii');
|
|
|
+ const reader = new om.File.BinaryReader(buffer);
|
|
|
+ reader.uint32();
|
|
|
+ for (let position = 4; position < partition.size; ) {
|
|
|
+ const length = reader.uint32();
|
|
|
+ const buffer = reader.read(length);
|
|
|
+ const name = decoder.decode(buffer);
|
|
|
+ const device = reader.uint32();
|
|
|
+ this.devices.set(name, device);
|
|
|
+ position += 4 + length + 4;
|
|
|
+ }
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ default: {
|
|
|
+ throw new om.Error("Unknown partition type '" + partition.type + "'.");
|
|
|
}
|
|
|
- break;
|
|
|
- }
|
|
|
- default: {
|
|
|
- throw new om.Error("Unknown partition type '" + partition.type + "'.");
|
|
|
}
|
|
|
}
|
|
|
}
|