Răsfoiți Sursa

Update kmodel.js

Lutz Roeder 4 ani în urmă
părinte
comite
9131a2e4de
1 a modificat fișierele cu 38 adăugiri și 7 ștergeri
  1. 38 7
      source/kmodel.js

+ 38 - 7
source/kmodel.js

@@ -37,6 +37,7 @@ kmodel.Graph = class {
     constructor(model) {
         this._inputs = [];
         this._outputs = [];
+        this._nodes = [];
         const scopes = new Map();
         let index = 0;
         for (const layer of model.layers) {
@@ -54,7 +55,25 @@ kmodel.Graph = class {
             }
             index++;
         }
-        this._nodes = model.layers.map((layer) => new kmodel.Node(layer));
+        for (const layer of model.layers) {
+            if (layer.type.name === 'INPUT') {
+                for (const input of layer.outputs) {
+                    this._inputs.push(new kmodel.Parameter('input', input.arguments.map((argument) => {
+                        return new kmodel.Argument(argument.name);
+                    })));
+                }
+                continue;
+            }
+            if (layer.type.name === 'OUTPUT') {
+                for (const output of layer.inputs) {
+                    this._outputs.push(new kmodel.Parameter(output.name, output.arguments.map((argument) => {
+                        return new kmodel.Argument(argument.name);
+                    })));
+                }
+                continue;
+            }
+            this._nodes.push(new kmodel.Node(layer));
+        }
     }
 
     get inputs() {
@@ -305,9 +324,9 @@ kmodel.Reader = class {
                             output_count: reader.uint32()
                         };
                     };
-                    reader.kpu_model_output_t = function() {
+                    reader.kpu_model_output_t = function(name) {
                         return {
-                            address: reader.uint32(),
+                            address: this.mem_address('main', name),
                             size: reader.uint32()
                         };
                     };
@@ -346,9 +365,9 @@ kmodel.Reader = class {
                     };
                     const model_header = reader.kpu_model_header_t();
                     this._layers = new Array(model_header.layers_length);
-                    this._outputs = new Array(model_header.output_count);
-                    for (let i = 0; i < this._outputs.length; i++) {
-                        this._outputs[i] = reader.kpu_model_output_t();
+                    const outputs = new Array(model_header.output_count);
+                    for (let i = 0; i < model_header.output_count; i++) {
+                        outputs[i] = reader.kpu_model_output_t('output' + (i > 0 ? i.toString() : ''));
                     }
                     for (let i = 0; i < this._layers.length; i++) {
                         this._layers[i] = reader.kpu_model_layer_header_t();
@@ -509,7 +528,7 @@ kmodel.Reader = class {
                     register(10242, 'K210_REMOVE_PADDING', '', (layer, reader) => {
                         layer.flags = reader.uint32();
                         layer.inputs = reader.mem_address('main', 'inputs');
-                        layer.outputs = reader.mem_address(layer.flags & 1 ? 'main' : 'kpu', 'outputs');
+                        layer.outputs = reader.mem_address('main', 'outputs');
                         layer.channels = reader.uint32();
                     });
                     register(10243, 'K210_UPLOAD', '', (layer, reader) => {
@@ -535,6 +554,18 @@ kmodel.Reader = class {
                         delete layer.body_size;
                         // console.log(JSON.stringify(Object.fromEntries(Object.entries(layer).filter((entry) => !(entry[1] instanceof Uint8Array))), null, 2));
                     }
+                    if (this._layers.length > 0) {
+                        this._layers.push({
+                            type: { name: 'INPUT' },
+                            outputs: this._layers[0].inputs
+                        });
+                    }
+                    for (const output of outputs) {
+                        this._layers.push({
+                            type: { name: 'OUTPUT' },
+                            inputs: output.address
+                        });
+                    }
                     break;
                 }
                 case 4: {