Lutz Roeder пре 3 година
родитељ
комит
c54f7b68fa
2 измењених фајлова са 156 додато и 44 уклоњено
  1. 149 42
      source/kmodel.js
  2. 7 2
      test/models.json

+ 149 - 42
source/kmodel.js

@@ -9,10 +9,7 @@ kmodel.ModelFactory = class {
     }
 
     open(context, match) {
-        return Promise.resolve().then(() => {
-            const reader = match;
-            return new kmodel.Model(reader);
-        });
+        return Promise.resolve().then(() => new kmodel.Model(match));
     }
 };
 
@@ -20,7 +17,7 @@ kmodel.Model = class {
 
     constructor(model) {
         this._format = 'kmodel v' + model.version.toString();
-        this._graphs = [ new kmodel.Graph(model) ];
+        this._graphs = model.modules.map((module) => new kmodel.Graph(module));
     }
 
     get format() {
@@ -34,13 +31,15 @@ kmodel.Model = class {
 
 kmodel.Graph = class {
 
-    constructor(model) {
+    constructor(module) {
+        this._name = module.name || '';
+        this._type = module.type || '';
         this._inputs = [];
         this._outputs = [];
         this._nodes = [];
         const scopes = new Map();
         let index = 0;
-        for (const layer of model.layers) {
+        for (const layer of module.layers) {
             for (const input of layer.inputs || []) {
                 for (const argument of input.arguments) {
                     argument.name = scopes.has(argument.name) ? scopes.get(argument.name) : argument.name;
@@ -55,7 +54,7 @@ kmodel.Graph = class {
             }
             index++;
         }
-        for (const layer of model.layers) {
+        for (const layer of module.layers) {
             switch (layer.type.name) {
                 case 'INPUT':
                 case 'input': {
@@ -82,6 +81,14 @@ kmodel.Graph = class {
         }
     }
 
+    get name() {
+        return this._name;
+    }
+
+    get type() {
+        return this._type;
+    }
+
     get inputs() {
         return this._inputs;
     }
@@ -305,7 +312,8 @@ kmodel.Node = class {
         }
         for (const output of layer.outputs || []) {
             this._outputs.push(new kmodel.Parameter(output.name, output.arguments.map((argument) => {
-                return new kmodel.Argument(argument.name);
+                const type = argument.shape ? new kmodel.TensorType(argument.datatype || '?', argument.shape) : null;
+                return new kmodel.Argument(argument.name, type);
             })));
         }
         for (const chain of layer.chain || []) {
@@ -378,15 +386,16 @@ kmodel.Reader = class {
     constructor(reader, version) {
         this._reader = reader;
         this._version = version;
+        this._modules = [];
     }
 
     get version() {
         return this._version;
     }
 
-    get layers() {
+    get modules() {
         this._read();
-        return this._layers;
+        return this._modules;
     }
 
     _read() {
@@ -454,17 +463,17 @@ kmodel.Reader = class {
                         return { name: name, arguments: [ argument ] };
                     };
                     const model_header = reader.kpu_model_header_t();
-                    this._layers = new Array(model_header.layers_length);
+                    const layers = new Array(model_header.layers_length);
                     const outputs = new Array(model_header.output_count);
                     for (let i = 0; i < model_header.output_count; i++) {
                         outputs[i] = reader.kpu_model_output_t('output' + (i > 0 ? i.toString() : ''));
                     }
-                    for (let i = 0; i < this._layers.length; i++) {
-                        this._layers[i] = reader.kpu_model_layer_header_t();
-                        this._layers[i].location = i;
+                    for (let i = 0; i < layers.length; i++) {
+                        layers[i] = reader.kpu_model_layer_header_t();
+                        layers[i].location = i;
                     }
                     let offset = reader.position;
-                    for (const layer of this._layers) {
+                    for (const layer of layers) {
                         layer.offset = offset;
                         offset += layer.body_size;
                     }
@@ -669,7 +678,7 @@ kmodel.Reader = class {
                         layer.inputs[0].arguments[0].shape = shape;
                         layer.outputs[0].arguments[0].shape = shape;
                     });
-                    for (const layer of this._layers) {
+                    for (const layer of layers) {
                         const type = types.get(layer.type);
                         if (!type) {
                             throw new kmodel.Error("Unsupported version '" + this._version.toString() + "' layer type '" + layer.type.toString() + "'.");
@@ -683,18 +692,22 @@ kmodel.Reader = class {
                         delete layer.offset;
                         delete layer.body_size;
                     }
-                    if (this._layers.length > 0) {
-                        this._layers.unshift({
+                    if (layers.length > 0) {
+                        layers.unshift({
                             type: { name: 'input' },
-                            outputs: [ this._layers[0].inputs[0] ]
+                            outputs: [ layers[0].inputs[0] ]
                         });
                     }
                     for (const output of outputs) {
-                        this._layers.push({
+                        layers.push({
                             type: { name: 'output' },
                             inputs: output.address
                         });
                     }
+                    this._modules.push({
+                        name: '',
+                        layers: layers
+                    });
                     break;
                 }
                 case 4: {
@@ -708,13 +721,15 @@ kmodel.Reader = class {
                         outputs: reader.uint32(),
                         reserved0: reader.uint32(),
                     };
+                    reader.memory_types = [ 'const', 'main', 'kpu' ];
                     reader.memory_type_t = function() {
                         const value = this.uint32();
-                        return [ 'const', 'main', 'kpu' ][value];
+                        return this.memory_types[value];
                     };
+                    reader.datatypes = [ 'float32', 'uint8' ];
                     reader.datatype_t = function() {
                         const value = this.uint32();
-                        return [ 'float32', 'uint8' ][value];
+                        return this.datatypes[value];
                     };
                     reader.memory_range = function() {
                         return {
@@ -795,16 +810,16 @@ kmodel.Reader = class {
                         outputs[i] = reader.parameter('output' + (i == 0 ? '' : (i + 1).toString()));
                     }
                     const constants = reader.read(model_header.constants);
-                    this._layers = new Array(model_header.nodes);
-                    for (let i = 0; i < this._layers.length; i++) {
-                        this._layers[i] = {
+                    const layers = new Array(model_header.nodes);
+                    for (let i = 0; i < layers.length; i++) {
+                        layers[i] = {
                             location: i,
                             opcode: reader.uint32(),
                             body_size: reader.uint32()
                         };
                     }
                     let offset = reader.position;
-                    for (const layer of this._layers) {
+                    for (const layer of layers) {
                         layer.offset = offset;
                         offset += layer.body_size;
                     }
@@ -1009,7 +1024,7 @@ kmodel.Reader = class {
                             } ]
                         });
                     });
-                    for (const layer of this._layers) {
+                    for (const layer of layers) {
                         const type = types.get(layer.opcode);
                         if (!type) {
                             throw new kmodel.Error("Unsupported version '" + this._version.toString() + "' layer type '" + layer.type.toString() + "'.");
@@ -1030,17 +1045,21 @@ kmodel.Reader = class {
                         }
                     }
                     for (const input of inputs) {
-                        this._layers.unshift({
+                        layers.unshift({
                             type: { name: 'INPUT' },
                             outputs: [ input ]
                         });
                     }
                     for (const output of outputs) {
-                        this._layers.push({
+                        layers.push({
                             type: { name: 'OUTPUT' },
                             inputs: [ output ]
                         });
                     }
+                    this._modules.push({
+                        name: '',
+                        layers: layers
+                    });
                     break;
                 }
                 case 5: {
@@ -1057,7 +1076,8 @@ kmodel.Reader = class {
                     reader.module_type_t = function() {
                         const buffer = reader.read(16);
                         const decoder = new TextDecoder('ascii');
-                        return decoder.decode(buffer);
+                        const text = decoder.decode(buffer);
+                        return text.replace(/\0.*$/, '');
                     };
                     reader.module_header = function() {
                         return {
@@ -1080,9 +1100,11 @@ kmodel.Reader = class {
                         };
                     };
                     reader.section_header = function() {
+                        const buffer = reader.read(16);
                         const decoder = new TextDecoder('ascii');
+                        const name = decoder.decode(buffer);
                         return {
-                            name: decoder.decode(reader.read(16)),
+                            name: name.replace(/\0.*$/, ''),
                             flags: reader.uint32(),
                             body_start: reader.uint32(),
                             body_size: reader.uint32(),
@@ -1101,15 +1123,50 @@ kmodel.Reader = class {
                             text_size: reader.uint32()
                         };
                     };
+                    reader.memory_locations = new Map([ [ 0, 'input' ], [ 1, 'output' ], [ 2, 'rdata' ], [ 3, 'data' ], [ 4, 'shared_data' ], [ 64, 'kpu' ] ]);
+                    reader.memory_location_t = function() {
+                        const value = this.byte();
+                        if (!this.memory_locations.has(value)) {
+                            throw new kmodel.Error("Unsupported memory location '" + value + "'.");
+                        }
+                        return this.memory_locations.get(value);
+                    };
+                    reader.datatypes = [ 'int8', 'int16', 'int32', 'int64', 'uint8', 'uint16', 'uint32', 'uint64', 'float16', 'float32', 'float64', 'bfloat16' ];
+                    reader.datatype_t = function() {
+                        const value = this.byte();
+                        return this.datatypes[value];
+                    };
                     reader.memory_range = function() {
                         return {
-                            memory_type: this.byte(), // 0=const, 1=main, 2=k210_kpu
-                            datatype: this.byte(),
+                            memory_location: this.memory_location_t(),
+                            datatype: this.datatype_t(),
                             shared_module: this.uint16(),
                             start: this.uint32(),
                             size: this.uint32()
                         };
                     };
+                    reader.argument = function() {
+                        const memory = this.memory_range();
+                        const value = {
+                            name: memory.memory_location + ':' + memory.start.toString(),
+                            datatype: memory.datatype
+                        };
+                        /*
+                        if (memory.memory_type === 'const') {
+                            value.data = constants.slice(memory.start, memory.start + memory.size);
+                            switch (value.datatype) {
+                                case 'uint8': value.shape = [ value.data.length ]; break;
+                                case 'float32': value.shape = [ value.data.length >> 2 ]; break;
+                                default: break;
+                            }
+                        }
+                        */
+                        return value;
+                    };
+                    reader.parameter = function(name) {
+                        const argument = this.argument();
+                        return { name: name, arguments: [ argument ] };
+                    };
                     reader.shape = function() {
                         const array = new Array(reader.uint32());
                         for (let i = 0; i < array.length; i++) {
@@ -1124,9 +1181,13 @@ kmodel.Reader = class {
                         }
                     };
                     const model_header = reader.model_header();
+                    if (model_header.header_size < 32) {
+                        throw new kmodel.Error("Invalid header size '" + model_header.header_size + "'.");
+                    }
                     if (model_header.header_size > reader.position) {
                         reader.skip(model_header.header_size - reader.position);
                     }
+                    delete model_header.header_size;
                     this._modules = new Array(model_header.modules);
                     for (let i = 0; i < this._modules.length; i++) {
                         const start = reader.position;
@@ -1142,31 +1203,77 @@ kmodel.Reader = class {
                         for (let i = 0; i < shared_mempools.length; i++) {
                             shared_mempools[i] = reader.mempool_desc();
                         }
+                        const function_headers = new Array(module_header.functions);
                         const functions = new Array(module_header.functions);
                         for (let i = 0; i < functions.length; i++) {
+                            const position = reader.position;
                             const function_header = reader.function_header();
+                            const header_size = reader.position - position;
+                            if (function_header.header_size > header_size) {
+                                reader.skip(function_header.header_size - header_size);
+                            }
                             const inputs = new Array(function_header.inputs);
                             for (let i = 0; i < inputs.length; i++) {
-                                inputs[i] = reader.memory_range();
+                                inputs[i] = reader.parameter('input' + (i == 0 ? '' : (i + 1).toString()));
                             }
                             for (let i = 0; i < inputs.length; i++) {
-                                inputs[i].shape = reader.shape();
+                                inputs[i].arguments[0].shape = reader.shape();
                             }
                             const outputs = new Array(function_header.outputs);
                             for (let i = 0; i < outputs.length; i++) {
-                                outputs[i] = reader.memory_range();
+                                outputs[i] = reader.parameter('output' + (i == 0 ? '' : (i + 1).toString()));
                             }
                             for (let i = 0; i < outputs.length; i++) {
-                                outputs[i].shape = reader.shape();
+                                outputs[i].arguments[0].shape = reader.shape();
                             }
                             reader.align_position(8);
+                            const size = reader.size - position;
+                            if (function_header.size > size) {
+                                reader.skip(function_header.size - size);
+                            }
+                            function_headers[i] = function_header;
+                            functions[i] = {
+                                type: { name: 'Unknown' },
+                                inputs: inputs,
+                                outputs: outputs
+                            };
+                        }
+                        const sections = new Map();
+                        for (let i = 0; i < module_header.sections; i++) {
+                            const section_header = reader.section_header();
+                            reader.skip(section_header.body_start);
+                            const body = reader.read(section_header.body_size);
+                            const section = {
+                                reader: new base.BinaryReader(body),
+                                flags: section_header.flags
+                            };
+                            reader.align_position(8);
+                            sections.set(section_header.name, section);
                         }
-                        const sections = new Array(module_header.sections);
-                        for (let i = 0; i < sections.length; i++) {
-                            sections[i] = reader.section_header();
+                        for (let i = 0; i < function_headers.length; i++) {
+                            const function_header = function_headers[i];
+                            const reader = sections.get('.text').reader;
+                            reader.seek(function_header.entrypoint);
+                            function_header.text = reader.read(function_header.text_size);
+                            const layer = functions[i];
+                            switch (module_header.type) {
+                                case 'stackvm':
+                                    layer.type = { name: 'stackvm' };
+                                    break;
+                                case 'k210':
+                                    break;
+                                default:
+                                    throw new kmodel.Error("Unsupported module type '" + module_header.type + "'.");
+                            }
                         }
+                        const name = this._modules.length > 1 ? i.toString() : '';
+                        this._modules[i] = {
+                            name: name,
+                            type: module_header.type,
+                            layers: functions
+                        };
                     }
-                    throw new kmodel.Error("Unsupported model version '" + this.version.toString() + "'.");
+                    break;
                 }
                 default: {
                     throw new kmodel.Error("Unsupported model version '" + this.version.toString() + "'.");

+ 7 - 2
test/models.json

@@ -2361,7 +2361,6 @@
     "target":   "mobilenet_v2.kmodel",
     "source":   "https://github.com/lutzroeder/netron/files/7965168/mobilenet_v2.kmodel.zip[mobilenet_v2.kmodel]",
     "format":   "kmodel v5",
-    "error":    "Unsupported model version '5' in 'mobilenet_v2.kmodel'.",
     "link":     "https://github.com/lutzroeder/netron/issues/871"
   },
   {
@@ -2376,7 +2375,6 @@
     "target":   "mnist.kmodel",
     "source":   "https://github.com/lutzroeder/netron/files/8111312/mnist.kmodel.zip[mnist.kmodel]",
     "format":   "kmodel v5",
-    "error":    "Unsupported model version '5' in 'mnist.kmodel'.",
     "link":     "https://github.com/lutzroeder/netron/issues/871"
   },
   {
@@ -2400,6 +2398,13 @@
     "format":   "kmodel v4",
     "link":     "https://github.com/lutzroeder/netron/issues/871"
   },
+  {
+    "type":     "kmodel",
+    "target":   "yolox_nano_224.kmodel",
+    "source":   "https://github.com/lutzroeder/netron/files/8695383/yolox_nano_224.kmodel.zip[yolox_nano_224.kmodel]",
+    "format":   "kmodel v5",
+    "link":     "https://github.com/lutzroeder/netron/issues/871"
+  },
   {
     "type":     "lasagne",
     "target":   "net2.pkl",