Lutz Roeder 4 роки тому
батько
коміт
07ddb0b105
2 змінених файлів з 99 додано та 79 видалено
  1. 98 78
      source/om.js
  2. 1 1
      source/view.js

+ 98 - 78
source/om.js

@@ -8,31 +8,28 @@ var protobuf = protobuf || require('./protobuf');
 om.ModelFactory = class {
 
     match(context) {
-        const stream = context.stream;
-        const signature = [ 0x49, 0x4D, 0x4F, 0x44 ]; // IMOD
-        if (stream.length >= signature.length && stream.peek(4).every((value, index) => value === signature[index])) {
-            return 'om';
-        }
-        return undefined;
+        return om.File.open(context);
     }
 
-    open(context) {
-        const file = om.File.open(context);
+    open(context, match) {
+        const file = match;
         if (!file.model) {
             throw om.Error('File does not contain a model definition.');
         }
         return context.require('./om-proto').then(() => {
-            om.proto = protobuf.get('om').ge.proto;
+            let model = null;
             try {
+                om.proto = protobuf.get('om').ge.proto;
                 const reader = protobuf.BinaryReader.open(file.model);
-                file.model = om.proto.ModelDef.decode(reader);
+                model = om.proto.ModelDef.decode(reader);
             }
             catch (error) {
                 const message = error && error.message ? error.message : error.toString();
                 throw new om.Error('File format is not ge.proto.ModelDef (' + message.replace(/\.$/, '') + ').');
             }
             return om.Metadata.open(context).then((metadata) => {
-                return new om.Model(metadata, file);
+                const weights = file.weights;
+                return new om.Model(metadata, model, weights);
             });
         });
     }
@@ -40,10 +37,10 @@ om.ModelFactory = class {
 
 om.Model = class {
 
-    constructor(metadata, file) {
+    constructor(metadata, model, weights) {
         this._graphs = [];
-        const context = { metadata: metadata, weights: file.weights };
-        for (const graph of file.model.graph) {
+        const context = { metadata: metadata, weights: weights };
+        for (const graph of model.graph) {
             this._graphs.push(new om.Graph(context, graph));
         }
     }
@@ -479,74 +476,97 @@ om.File = class {
 
     static open(context) {
         const stream = context.stream;
-        const buffer = stream.peek();
-        const reader = new om.File.BinaryReader(buffer);
-        return new om.File(reader);
+        if (stream.length >= 256) {
+            const signature = [ 0x49, 0x4D, 0x4F, 0x44 ]; // IMOD
+            if (stream.peek(4).every((value, index) => value === signature[index])) {
+                const reader = new om.File.BinaryReader(stream.peek());
+                return new om.File(reader);
+            }
+        }
+        return null;
     }
 
     constructor(reader) {
-        const decoder = new TextDecoder('utf-8');
-        this.header = reader.read(4);
-        const size = reader.uint32();
-        this.version = reader.uint32();
-        this.checksum = reader.read(64);
-        reader.skip(4);
-        this.is_encrypt = reader.byte();
-        this.is_checksum = reader.byte();
-        this.type = reader.byte(); // 0=IR model, 1=standard model, 2=OM Tiny model
-        this.mode = reader.byte(); // 0=offline, 1=online
-        this.name = decoder.decode(reader.read(32));
-        this.ops = reader.uint32();
-        this.userdefineinfo = reader.read(32);
-        this.ir_version = reader.uint32();
-        this.model_num = reader.uint32();
-        this.platform_version = reader.read(20);
-        this.platform_type = reader.byte();
-        reader.seek(0);
-        reader.skip(size);
-        const partitions = new Array(reader.uint32());
-        for (let i = 0; i < partitions.length; i++) {
-            partitions[i] = {
-                type: reader.uint32(),
-                offset: reader.uint32(),
-                size: reader.uint32()
-            };
-        }
-        const offset = 256 + 4 + 12 * partitions.length;
-        for (const partition of partitions) {
-            reader.seek(offset + partition.offset);
-            const buffer = reader.read(partition.size);
-            switch (partition.type) {
-                case 0: { // MODEL_DEF
-                    this.model = buffer;
-                    break;
-                }
-                case 1: { // MODEL_WEIGHT
-                    this.weights = buffer;
-                    break;
-                }
-                case 2: // TASK_INFO
-                case 3: // TBE_KERNELS
-                case 4: { // CUST_AICPU_KERNELS
-                    break;
-                }
-                case 5: { // DEVICE_CONFIG
-                    this.devices = new Map();
-                    const decoder = new TextDecoder('ascii');
-                    const reader = new om.File.BinaryReader(buffer);
-                    reader.uint32();
-                    for (let position = 4; position < partition.size; ) {
-                        const length = reader.uint32();
-                        const buffer = reader.read(length);
-                        const name = decoder.decode(buffer);
-                        const device = reader.uint32();
-                        this.devices.set(name, device);
-                        position += 4 + length + 4;
+        this._reader = reader;
+    }
+
+    get model() {
+        this._read();
+        return this._model;
+    }
+
+    get weights() {
+        this._read();
+        return this._weights;
+    }
+
+    _read() {
+        if (this._reader) {
+            const reader = this._reader;
+            delete this._reader;
+            const decoder = new TextDecoder('utf-8');
+            this.header = reader.uint32();
+            const size = reader.uint32();
+            this.version = reader.uint32();
+            this.checksum = reader.read(64);
+            reader.skip(4);
+            this.is_encrypt = reader.byte();
+            this.is_checksum = reader.byte();
+            this.type = reader.byte(); // 0=IR model, 1=standard model, 2=OM Tiny model
+            this.mode = reader.byte(); // 0=offline, 1=online
+            this.name = decoder.decode(reader.read(32));
+            this.ops = reader.uint32();
+            this.userdefineinfo = reader.read(32);
+            this.ir_version = reader.uint32();
+            this.model_num = reader.uint32();
+            this.platform_version = reader.read(20);
+            this.platform_type = reader.byte();
+            reader.seek(0);
+            reader.skip(size);
+            const partitions = new Array(reader.uint32());
+            for (let i = 0; i < partitions.length; i++) {
+                partitions[i] = {
+                    type: reader.uint32(),
+                    offset: reader.uint32(),
+                    size: reader.uint32()
+                };
+            }
+            const offset = 256 + 4 + 12 * partitions.length;
+            for (const partition of partitions) {
+                reader.seek(offset + partition.offset);
+                const buffer = reader.read(partition.size);
+                switch (partition.type) {
+                    case 0: { // MODEL_DEF
+                        this._model = buffer;
+                        break;
+                    }
+                    case 1: { // MODEL_WEIGHT
+                        this._weights = buffer;
+                        break;
+                    }
+                    case 2: // TASK_INFO
+                    case 3: // TBE_KERNELS
+                    case 4: { // CUST_AICPU_KERNELS
+                        break;
+                    }
+                    case 5: { // DEVICE_CONFIG
+                        this.devices = new Map();
+                        const decoder = new TextDecoder('ascii');
+                        const reader = new om.File.BinaryReader(buffer);
+                        reader.uint32();
+                        for (let position = 4; position < partition.size; ) {
+                            const length = reader.uint32();
+                            const buffer = reader.read(length);
+                            const name = decoder.decode(buffer);
+                            const device = reader.uint32();
+                            this.devices.set(name, device);
+                            position += 4 + length + 4;
+                        }
+                        break;
+                    }
+                    default: {
+                        throw new om.Error("Unknown partition type '" + partition.type + "'.");
                     }
-                    break;
-                }
-                default: {
-                    throw new om.Error("Unknown partition type '" + partition.type + "'.");
                 }
             }
         }

+ 1 - 1
source/view.js

@@ -1477,7 +1477,7 @@ view.ModelFactoryService = class {
         this.register('./mlnet', [ '.zip' ]);
         this.register('./acuity', [ '.json' ]);
         this.register('./imgdnn', [ '.dnn', 'params', '.json' ]);
-        this.register('./om', [ '.om' ]);
+        this.register('./om', [ '.om', '.onnx', '.pb' ]);
         this.register('./nb', [ '.nb' ]);
     }