Lutz Roeder пре 3 година
родитељ
комит
838061f88e
1 измењених фајлова са 331 додато и 263 уклоњено
  1. 331 263
      source/kmodel.js

+ 331 - 263
source/kmodel.js

@@ -369,22 +369,24 @@ kmodel.Attribute = class {
 kmodel.Reader = class {
 
     static open(stream) {
-        const reader = new base.BinaryReader(stream);
-        if (reader.length > 4) {
-            const signature = reader.uint32();
-            if (signature === 3) {
-                return new kmodel.Reader(reader, 3);
+        if (stream && stream.length >= 4) {
+            const length = Math.min(8, stream.length);
+            const buffer = stream.peek(length);
+            if ([ 0x03, 0x00, 0x00, 0x00 ].every((value, index) => value === buffer[index])) {
+                return new kmodel.Reader(stream, 3);
             }
-            if (signature === 0x4B4D444C) {
+            if ([ 0x4C, 0x44, 0x4D, 0x4B ].every((value, index) => value === buffer[index]) && buffer.length >= 8) {
+                const reader = new base.BinaryReader(buffer);
+                reader.skip(4);
                 const version = reader.uint32();
-                return new kmodel.Reader(reader, version);
+                return new kmodel.Reader(stream, version);
             }
         }
         return null;
     }
 
-    constructor(reader, version) {
-        this._reader = reader;
+    constructor(stream, version) {
+        this._stream = stream;
         this._version = version;
         this._modules = [];
     }
@@ -399,8 +401,7 @@ kmodel.Reader = class {
     }
 
     _read() {
-        if (this._reader) {
-            const reader = this._reader;
+        if (this._stream) {
             if (this._version < 3 || this._version > 5) {
                 throw new kmodel.Error("Unsupported model version '" + this.version.toString() + "'.");
             }
@@ -408,60 +409,9 @@ kmodel.Reader = class {
             const register = (type, name, category, callback) => {
                 types.set(type, { type: { name: name, category: category || '' }, callback: callback });
             };
-            reader.uint64_bits = function(fields) {
-                const buffer = reader.read(8);
-                fields = Object.entries(fields);
-                fields.push([ null, Math.min(64, fields[fields.length - 1][1] + 56)]);
-                const obj = {};
-                for (let i = 0; i < fields.length - 1; i++) {
-                    const key = fields[i][0];
-                    let value = 0;
-                    let position = fields[i][1];
-                    const end = fields[i + 1][1];
-                    while (position < end) {
-                        const offset = (position / 8) >> 0;
-                        const start = (position & 7);
-                        const count = Math.min((offset + 1) * 8, end) - position;
-                        value = value | ((buffer[offset] >>> start) & ((1 << count) - 1)) << (position - fields[i][1]);
-                        position += count;
-                    }
-                    obj[key] = value;
-                }
-                return obj;
-            };
             switch (this._version) {
                 case 3: {
-                    reader.kpu_model_header_t = function() {
-                        return {
-                            flags: reader.uint32(),
-                            arch: reader.uint32(),
-                            layers_length: reader.uint32(),
-                            max_start_address: reader.uint32(),
-                            main_mem_usage: reader.uint32(),
-                            output_count: reader.uint32()
-                        };
-                    };
-                    reader.kpu_model_output_t = function(name) {
-                        return {
-                            address: [ this.parameter(name) ],
-                            size: reader.uint32()
-                        };
-                    };
-                    reader.kpu_model_layer_header_t = function() {
-                        return {
-                            type: reader.uint32(),
-                            body_size: reader.uint32()
-                        };
-                    };
-                    reader.argument = function(memory_type) {
-                        memory_type = memory_type || 'main';
-                        const address = this.uint32();
-                        return { name: memory_type + ':' + address.toString() };
-                    };
-                    reader.parameter = function(name, memory_type) {
-                        const argument = this.argument(memory_type);
-                        return { name: name, arguments: [ argument ] };
-                    };
+                    const reader = new kmodel.BinaryReader.v3(this._stream);
                     const model_header = reader.kpu_model_header_t();
                     const layers = new Array(model_header.layers_length);
                     const outputs = new Array(model_header.output_count);
@@ -735,6 +685,7 @@ kmodel.Reader = class {
                     break;
                 }
                 case 4: {
+                    const reader = new kmodel.BinaryReader.v4(this._stream);
                     const model_header = {
                         flags: reader.uint32(),
                         target: reader.uint32(), // 0=CPU, 1=K210
@@ -745,87 +696,6 @@ kmodel.Reader = class {
                         outputs: reader.uint32(),
                         reserved0: reader.uint32(),
                     };
-                    reader.memory_types = [ 'const', 'main', 'kpu' ];
-                    reader.memory_type_t = function() {
-                        const value = this.uint32();
-                        return this.memory_types[value];
-                    };
-                    reader.datatypes = [ 'float32', 'uint8' ];
-                    reader.datatype_t = function() {
-                        const value = this.uint32();
-                        return this.datatypes[value];
-                    };
-                    reader.memory_range = function() {
-                        return {
-                            memory_type: this.memory_type_t(),
-                            datatype: this.datatype_t(),
-                            start: this.uint32(),
-                            size: this.uint32()
-                        };
-                    };
-                    reader.argument = function() {
-                        const memory = this.memory_range();
-                        const value = {
-                            name: memory.memory_type + ':' + memory.start.toString(),
-                            datatype: memory.datatype
-                        };
-                        if (memory.memory_type === 'const') {
-                            value.data = constants.slice(memory.start, memory.start + memory.size);
-                            switch (value.datatype) {
-                                case 'uint8': value.shape = [ value.data.length ]; break;
-                                case 'float32': value.shape = [ value.data.length >> 2 ]; break;
-                                default: break;
-                            }
-                        }
-                        return value;
-                    };
-                    reader.parameter = function(name) {
-                        const argument = this.argument();
-                        return { name: name, arguments: [ argument ] };
-                    };
-                    reader.runtime_shape_t = function() {
-                        return [ reader.uint32(), reader.uint32(), reader.uint32(), reader.uint32() ];
-                    };
-                    reader.padding = function() {
-                        return { before: reader.int32(), after: reader.int32() };
-                    };
-                    reader.runtime_paddings_t = function() {
-                        return [ this.padding(), this.padding(), this.padding(), this.padding() ];
-                    };
-                    reader.scalar = function() {
-                        return {
-                            datatype_t: reader.uint32(),
-                            storage: reader.read(4)
-                        };
-                    };
-                    reader.kpu_activate_table_t = function() {
-                        const value = {};
-                        value.activate_para = new Array(16);
-                        for (let i = 0; i < 16; i++) {
-                            value.activate_para[i] = this.uint64_bits({ shift_number: 0, y_mul: 8, x_start: 24, reserved: 60 });
-                            delete value.activate_para[i].reserved;
-                        }
-                        for (let i = 0; i < 16; i++) {
-                            value.activate_para[i].bias = reader.int8();
-                        }
-                        return value;
-                    };
-                    reader.unary_op_t = function() {
-                        const value = reader.uint32();
-                        return [ 'abs', 'ceil', 'cos', 'exp', 'floor', 'log', 'neg', 'rsqrt', 'sin', 'square' ][value];
-                    };
-                    reader.binary_op_t = function() {
-                        const value = reader.uint32();
-                        return [ 'add', 'sub', 'mul', 'div', 'min', 'max' ][value];
-                    };
-                    reader.reduce_op_t = function() {
-                        const value = reader.uint32();
-                        return [ 'mean', 'min', 'max', 'sum' ][value];
-                    };
-                    reader.image_resize_mode_t = function() {
-                        const value = reader.uint32();
-                        return [ 'bilinear', 'nearest_neighbor' ][value];
-                    };
                     const inputs = new Array(model_header.inputs);
                     for (let i = 0; i < inputs.length; i++) {
                         inputs[i] = reader.parameter('input' + (i == 0 ? '' : (i + 1).toString()));
@@ -837,7 +707,7 @@ kmodel.Reader = class {
                     for (let i = 0; i < outputs.length; i++) {
                         outputs[i] = reader.parameter('output' + (i == 0 ? '' : (i + 1).toString()));
                     }
-                    const constants = reader.read(model_header.constants);
+                    reader.constants(model_header.constants);
                     const layers = new Array(model_header.nodes);
                     for (let i = 0; i < layers.length; i++) {
                         layers[i] = {
@@ -1100,123 +970,7 @@ kmodel.Reader = class {
                     break;
                 }
                 case 5: {
-                    reader.model_header = function() {
-                        return {
-                            header_size: reader.uint32(),
-                            flags: reader.uint32(),
-                            alignment: reader.uint32(),
-                            modules: reader.uint32(),
-                            entry_module: reader.uint32(),
-                            entry_function: reader.uint32()
-                        };
-                    };
-                    reader.module_type_t = function() {
-                        const buffer = reader.read(16);
-                        const decoder = new TextDecoder('ascii');
-                        const text = decoder.decode(buffer);
-                        return text.replace(/\0.*$/, '');
-                    };
-                    reader.module_header = function() {
-                        return {
-                            type: reader.module_type_t(),
-                            version: reader.uint32(),
-                            header_size: reader.uint32(),
-                            size: reader.uint32(),
-                            mempools: reader.uint32(),
-                            shared_mempools: reader.uint32(),
-                            sections: reader.uint32(),
-                            functions: reader.uint32(),
-                            reserved0: reader.uint32()
-                        };
-                    };
-                    reader.mempool_desc = function() {
-                        return {
-                            location: reader.byte(),
-                            reserved0: reader.read(3),
-                            size: reader.uint32()
-                        };
-                    };
-                    reader.section_header = function() {
-                        const buffer = reader.read(16);
-                        const decoder = new TextDecoder('ascii');
-                        const name = decoder.decode(buffer);
-                        return {
-                            name: name.replace(/\0.*$/, ''),
-                            flags: reader.uint32(),
-                            body_start: reader.uint32(),
-                            body_size: reader.uint32(),
-                            reserved0: reader.uint32()
-                        };
-                    };
-                    reader.function_header = function() {
-                        return {
-                            header_size: reader.uint32(),
-                            size: reader.uint32(),
-                            input_pool_size: reader.uint32(),
-                            output_pool_size: reader.uint32(),
-                            inputs: reader.uint32(),
-                            outputs: reader.uint32(),
-                            entrypoint: reader.uint32(),
-                            text_size: reader.uint32()
-                        };
-                    };
-                    reader.memory_locations = new Map([ [ 0, 'input' ], [ 1, 'output' ], [ 2, 'rdata' ], [ 3, 'data' ], [ 4, 'shared_data' ], [ 64, 'kpu' ] ]);
-                    reader.memory_location_t = function() {
-                        const value = this.byte();
-                        if (!this.memory_locations.has(value)) {
-                            throw new kmodel.Error("Unsupported memory location '" + value + "'.");
-                        }
-                        return this.memory_locations.get(value);
-                    };
-                    reader.datatypes = [ 'int8', 'int16', 'int32', 'int64', 'uint8', 'uint16', 'uint32', 'uint64', 'float16', 'float32', 'float64', 'bfloat16' ];
-                    reader.datatype_t = function() {
-                        const value = this.byte();
-                        return this.datatypes[value];
-                    };
-                    reader.memory_range = function() {
-                        return {
-                            memory_location: this.memory_location_t(),
-                            datatype: this.datatype_t(),
-                            shared_module: this.uint16(),
-                            start: this.uint32(),
-                            size: this.uint32()
-                        };
-                    };
-                    reader.argument = function() {
-                        const memory = this.memory_range();
-                        const value = {
-                            name: memory.memory_location + ':' + memory.start.toString(),
-                            datatype: memory.datatype
-                        };
-                        /*
-                        if (memory.memory_type === 'const') {
-                            value.data = constants.slice(memory.start, memory.start + memory.size);
-                            switch (value.datatype) {
-                                case 'uint8': value.shape = [ value.data.length ]; break;
-                                case 'float32': value.shape = [ value.data.length >> 2 ]; break;
-                                default: break;
-                            }
-                        }
-                        */
-                        return value;
-                    };
-                    reader.parameter = function(name) {
-                        const argument = this.argument();
-                        return { name: name, arguments: [ argument ] };
-                    };
-                    reader.shape = function() {
-                        const array = new Array(reader.uint32());
-                        for (let i = 0; i < array.length; i++) {
-                            array[i] = reader.uint32();
-                        }
-                        return array;
-                    };
-                    reader.align_position = function(alignment) {
-                        const remainder = this._position % alignment;
-                        if (remainder !== 0) {
-                            this.skip(alignment - remainder);
-                        }
-                    };
+                    const reader = new kmodel.BinaryReader.v5(this._stream);
                     const model_header = reader.model_header();
                     if (model_header.header_size < 32) {
                         throw new kmodel.Error("Invalid header size '" + model_header.header_size + "'.");
@@ -1318,7 +1072,321 @@ kmodel.Reader = class {
                     throw new kmodel.Error("Unsupported model version '" + this.version.toString() + "'.");
                 }
             }
-            delete this._reader;
+            delete this._stream;
+        }
+    }
+};
+
+kmodel.BinaryReader = class extends base.BinaryReader {
+
+    uint64_bits(fields) {
+        const buffer = this.read(8);
+        fields = Object.entries(fields);
+        fields.push([ null, Math.min(64, fields[fields.length - 1][1] + 56)]);
+        const obj = {};
+        for (let i = 0; i < fields.length - 1; i++) {
+            const key = fields[i][0];
+            let value = 0;
+            let position = fields[i][1];
+            const end = fields[i + 1][1];
+            while (position < end) {
+                const offset = (position / 8) >> 0;
+                const start = (position & 7);
+                const count = Math.min((offset + 1) * 8, end) - position;
+                value = value | ((buffer[offset] >>> start) & ((1 << count) - 1)) << (position - fields[i][1]);
+                position += count;
+            }
+            obj[key] = value;
+        }
+        return obj;
+    }
+};
+
+kmodel.BinaryReader.v3 = class extends kmodel.BinaryReader {
+
+    constructor(buffer) {
+        super(buffer);
+        this.skip(4);
+    }
+
+    kpu_model_header_t() {
+        return {
+            flags: this.uint32(),
+            arch: this.uint32(),
+            layers_length: this.uint32(),
+            max_start_address: this.uint32(),
+            main_mem_usage: this.uint32(),
+            output_count: this.uint32()
+        };
+    }
+
+    kpu_model_output_t(name) {
+        return {
+            address: [ this.parameter(name) ],
+            size: this.uint32()
+        };
+    }
+
+    kpu_model_layer_header_t() {
+        return {
+            type: this.uint32(),
+            body_size: this.uint32()
+        };
+    }
+
+    argument(memory_type) {
+        memory_type = memory_type || 'main';
+        const address = this.uint32();
+        return { name: memory_type + ':' + address.toString() };
+    }
+
+    parameter(name, memory_type) {
+        const argument = this.argument(memory_type);
+        return { name: name, arguments: [ argument ] };
+    }
+};
+
+kmodel.BinaryReader.v4 = class extends kmodel.BinaryReader {
+
+    constructor(buffer) {
+        super(buffer);
+        this.skip(8);
+        this._memory_types = [ 'const', 'main', 'kpu' ];
+        this._datatypes = [ 'float32', 'uint8' ];
+    }
+
+    memory_type_t() {
+        const value = this.uint32();
+        return this._memory_types[value];
+    }
+
+    datatype_t() {
+        const value = this.uint32();
+        return this._datatypes[value];
+    }
+
+    memory_range() {
+        return {
+            memory_type: this.memory_type_t(),
+            datatype: this.datatype_t(),
+            start: this.uint32(),
+            size: this.uint32()
+        };
+    }
+
+    argument() {
+        const memory = this.memory_range();
+        const value = {
+            name: memory.memory_type + ':' + memory.start.toString(),
+            datatype: memory.datatype
+        };
+        if (memory.memory_type === 'const') {
+            value.data = this._constants.slice(memory.start, memory.start + memory.size);
+            switch (value.datatype) {
+                case 'uint8': value.shape = [ value.data.length ]; break;
+                case 'float32': value.shape = [ value.data.length >> 2 ]; break;
+                default: break;
+            }
+        }
+        return value;
+    }
+
+    parameter(name) {
+        const argument = this.argument();
+        return { name: name, arguments: [ argument ] };
+    }
+
+    runtime_shape_t() {
+        return [ this.uint32(), this.uint32(), this.uint32(), this.uint32() ];
+    }
+
+    padding() {
+        return { before: this.int32(), after: this.int32() };
+    }
+
+    runtime_paddings_t() {
+        return [ this.padding(), this.padding(), this.padding(), this.padding() ];
+    }
+
+    scalar() {
+        return {
+            datatype_t: this.uint32(),
+            storage: this.read(4)
+        };
+    }
+
+    kpu_activate_table_t() {
+        const value = {};
+        value.activate_para = new Array(16);
+        for (let i = 0; i < 16; i++) {
+            value.activate_para[i] = this.uint64_bits({ shift_number: 0, y_mul: 8, x_start: 24, reserved: 60 });
+            delete value.activate_para[i].reserved;
+        }
+        for (let i = 0; i < 16; i++) {
+            value.activate_para[i].bias = this.int8();
+        }
+        return value;
+    }
+
+    unary_op_t() {
+        const value = this.uint32();
+        return [ 'abs', 'ceil', 'cos', 'exp', 'floor', 'log', 'neg', 'rsqrt', 'sin', 'square' ][value];
+    }
+
+    binary_op_t() {
+        const value = this.uint32();
+        return [ 'add', 'sub', 'mul', 'div', 'min', 'max' ][value];
+    }
+
+    reduce_op_t() {
+        const value = this.uint32();
+        return [ 'mean', 'min', 'max', 'sum' ][value];
+    }
+
+    image_resize_mode_t() {
+        const value = this.uint32();
+        return [ 'bilinear', 'nearest_neighbor' ][value];
+    }
+
+    constants(size) {
+        this._constants = this.read(size);
+    }
+};
+
+kmodel.BinaryReader.v5 = class extends kmodel.BinaryReader {
+
+    constructor(buffer) {
+        super(buffer);
+        this.skip(8);
+        this._datatypes = [ 'int8', 'int16', 'int32', 'int64', 'uint8', 'uint16', 'uint32', 'uint64', 'float16', 'float32', 'float64', 'bfloat16' ];
+        this._memory_locations = new Map([ [ 0, 'input' ], [ 1, 'output' ], [ 2, 'rdata' ], [ 3, 'data' ], [ 4, 'shared_data' ], [ 64, 'kpu' ] ]);
+    }
+
+    model_header() {
+        return {
+            header_size: this.uint32(),
+            flags: this.uint32(),
+            alignment: this.uint32(),
+            modules: this.uint32(),
+            entry_module: this.uint32(),
+            entry_function: this.uint32()
+        };
+    }
+
+    module_type_t() {
+        const buffer = this.read(16);
+        const decoder = new TextDecoder('ascii');
+        const text = decoder.decode(buffer);
+        return text.replace(/\0.*$/, '');
+    }
+
+    module_header() {
+        return {
+            type: this.module_type_t(),
+            version: this.uint32(),
+            header_size: this.uint32(),
+            size: this.uint32(),
+            mempools: this.uint32(),
+            shared_mempools: this.uint32(),
+            sections: this.uint32(),
+            functions: this.uint32(),
+            reserved0: this.uint32()
+        };
+    }
+
+    mempool_desc() {
+        return {
+            location: this.byte(),
+            reserved0: this.read(3),
+            size: this.uint32()
+        };
+    }
+
+    section_header() {
+        const buffer = this.read(16);
+        const decoder = new TextDecoder('ascii');
+        const name = decoder.decode(buffer);
+        return {
+            name: name.replace(/\0.*$/, ''),
+            flags: this.uint32(),
+            body_start: this.uint32(),
+            body_size: this.uint32(),
+            reserved0: this.uint32()
+        };
+    }
+
+    function_header() {
+        return {
+            header_size: this.uint32(),
+            size: this.uint32(),
+            input_pool_size: this.uint32(),
+            output_pool_size: this.uint32(),
+            inputs: this.uint32(),
+            outputs: this.uint32(),
+            entrypoint: this.uint32(),
+            text_size: this.uint32()
+        };
+    }
+
+    memory_location_t() {
+        const value = this.byte();
+        if (!this._memory_locations.has(value)) {
+            throw new kmodel.Error("Unsupported memory location '" + value + "'.");
+        }
+        return this._memory_locations.get(value);
+    }
+
+    datatype_t() {
+        const value = this.byte();
+        return this._datatypes[value];
+    }
+
+    memory_range() {
+        return {
+            memory_location: this.memory_location_t(),
+            datatype: this.datatype_t(),
+            shared_module: this.uint16(),
+            start: this.uint32(),
+            size: this.uint32()
+        };
+    }
+
+    argument() {
+        const memory = this.memory_range();
+        const value = {
+            name: memory.memory_location + ':' + memory.start.toString(),
+            datatype: memory.datatype
+        };
+        /*
+        if (memory.memory_type === 'const') {
+            value.data = constants.slice(memory.start, memory.start + memory.size);
+            switch (value.datatype) {
+                case 'uint8': value.shape = [ value.data.length ]; break;
+                case 'float32': value.shape = [ value.data.length >> 2 ]; break;
+                default: break;
+            }
+        }
+        */
+        return value;
+    }
+
+    parameter(name) {
+        const argument = this.argument();
+        return { name: name, arguments: [ argument ] };
+    }
+
+    shape() {
+        const array = new Array(this.uint32());
+        for (let i = 0; i < array.length; i++) {
+            array[i] = this.uint32();
+        }
+        return array;
+    }
+
+    align_position(alignment) {
+        const remainder = this._position % alignment;
+        if (remainder !== 0) {
+            this.skip(alignment - remainder);
         }
     }
 };