Explorar el Código

Update ncnn.js (#1314)

Lutz Roeder hace 1 año
padre
commit
fe77af5aa0
Se han modificado 5 ficheros con 207 adiciones y 46 borrados
  1. 176 18
      source/ncnn.js
  2. 8 9
      source/onnx.js
  3. 7 0
      source/pytorch-metadata.json
  4. 8 10
      source/pytorch.js
  5. 8 9
      source/tengine.js

+ 176 - 18
source/ncnn.js

@@ -2,6 +2,7 @@
 import * as base from './base.js';
 
 const ncnn = {};
+const pnnx = {};
 
 // https://github.com/Tencent/ncnn/wiki/param-and-model-file-structure
 // https://github.com/Tencent/ncnn/wiki/operation-param-weight-table
@@ -23,29 +24,46 @@ ncnn.ModelFactory = class {
             }
         } else if (identifier.endsWith('.param') || identifier.endsWith('.cfg.ncnn')) {
             const reader = context.read('text', 0x10000);
-            const type = identifier.endsWith('.pnnx.param') ? 'pnnx.model' : 'ncnn.model';
             if (reader) {
+                let type = '';
                 try {
+                    let match = false;
                     const signature = reader.read('\n');
                     if (signature !== undefined) {
                         if (signature.trim() === '7767517') {
-                            context.type = type;
-                            return;
+                            match = true;
+                        } else {
+                            const header = signature.trim().split(' ');
+                            if (header.length === 2 && header.every((value) => value >>> 0 === parseFloat(value))) {
+                                match = true;
+                            }
                         }
-                        const header = signature.trim().split(' ');
-                        if (header.length === 2 && header.every((value) => value >>> 0 === parseFloat(value))) {
-                            context.type = type;
+                    }
+                    if (match) {
+                        type = 'ncnn.model';
+                        for (let i = 0; i < 32; i++) {
+                            const line = reader.read('\n');
+                            if (!line) {
+                                break;
+                            }
+                            if (line.startsWith('pnnx.') || line.startsWith('nn.') || line.startsWith('F.')) {
+                                type = 'pnnx.model';
+                                break;
+                            }
                         }
                     }
                 } catch {
                     // continue regardless of error
                 }
+                if (type) {
+                    context.type = type;
+                }
             }
         } else if (identifier.endsWith('.ncnn.bin')) {
             context.type = 'ncnn.weights';
         } else if (identifier.endsWith('.pnnx.bin')) {
             const entries = context.peek('zip');
-            if (entries && entries.size > 0) {
+            if (entries) { // can be empty
                 context.type = 'pnnx.weights';
                 context.target = entries;
             }
@@ -87,7 +105,12 @@ ncnn.ModelFactory = class {
     }
 
     async open(context) {
-        const metadata = await context.metadata('ncnn-metadata.json');
+        let metadata = null;
+        if (context.type.startsWith('pnnx.')) {
+            metadata = await pnnx.Metadata.open(context);
+        } else {
+            metadata = await context.metadata('ncnn-metadata.json');
+        }
         const identifier = context.identifier.toLowerCase();
         const format = context.type.split('.').shift();
         switch (context.type) {
@@ -95,7 +118,8 @@ ncnn.ModelFactory = class {
             case 'ncnn.model': {
                 let file = null;
                 if (identifier.endsWith('.param')) {
-                    file = context.identifier.replace(/\.param$/, '.bin');
+                    const extension = context.type === 'pnnx.model' && !identifier.endsWith('.pnnx.param') ? '.pnnx.bin' : '.bin';
+                    file = context.identifier.replace(/\.param$/, extension);
                 } else if (identifier.endsWith('.cfg.ncnn')) {
                     file = context.identifier.replace(/\.cfg\.ncnn$/, '.weights.ncnn');
                 }
@@ -204,8 +228,16 @@ ncnn.Graph = class {
                 const dimensions = Array.from(layer.params.values()).map((value) => isNaN(parseInt(value, 10)) ? value : parseInt(value, 10));
                 const shape = new ncnn.TensorShape(dimensions);
                 const type = new ncnn.TensorType('float32', shape);
-                const input = new ncnn.Argument(layer.name, layer.outputs.map((output) => values.map(output, type)));
-                this.inputs.push(input);
+                const argument = new ncnn.Argument(layer.name, layer.outputs.map((output) => values.map(output, type)));
+                this.inputs.push(argument);
+            } else if (layer.type === 'pnnx.Input' && layer.params) {
+                const type = ncnn.Utility.route(layer.params, '0');
+                const argument = new ncnn.Argument(layer.name, layer.outputs.map((output) => values.map(output, type)));
+                this.inputs.push(argument);
+            } else if (layer.type === 'pnnx.Output' && layer.params) {
+                const type = ncnn.Utility.route(layer.params, '0');
+                const argument = new ncnn.Argument(layer.name, layer.inputs.map((input) => values.map(input, type)));
+                this.outputs.push(argument);
             } else {
                 const node = new ncnn.Node(metadata, format, blobs, layer, values);
                 this.nodes.push(node);
@@ -251,11 +283,20 @@ ncnn.Node = class {
         const params = layer.params;
         const inputs = layer.inputs || [];
         let inputIndex = 0;
+        const names = new Map();
+        if (params) {
+            for (const [key, value] of params) {
+                if (key.startsWith('$')) {
+                    names.set(value, key.substring(1));
+                    params.delete(key);
+                }
+            }
+        }
         if (this.type && Array.isArray(this.type.inputs)) {
             for (const input of this.type.inputs) {
                 if (inputIndex < inputs.length || input.optional === false) {
                     const count = (input.type === 'Tensor[]') ? (inputs.length - inputIndex) : 1;
-                    const list = inputs.slice(inputIndex, inputIndex + count).filter((id) => id !== '' || input.option !== 'optional').map((id) => values.map(id));
+                    const list = inputs.slice(inputIndex, inputIndex + count).filter((id) => id !== '' || input.option !== 'optional').map((id) => values.map(id, ncnn.Utility.route(params, id)));
                     const argument = new ncnn.Argument(input.name, list);
                     this.inputs.push(argument);
                     inputIndex += count;
@@ -263,8 +304,14 @@ ncnn.Node = class {
             }
         }
         this.inputs.push(...inputs.slice(inputIndex).map((input, index) => {
-            const name = ((inputIndex + index) === 0) ? 'input' : (inputIndex + index).toString();
-            return new ncnn.Argument(name, [values.map(input)]);
+            index = inputIndex + index;
+            let name = 'input';
+            if (names.has(input)) {
+                name = names.get(input);
+            } else if (index !== 0) {
+                name = index.toString();
+            }
+            return new ncnn.Argument(name, [values.map(input, ncnn.Utility.route(params, input))]);
         }));
         const outputs = layer.outputs || [];
         let outputIndex = 0;
@@ -272,7 +319,7 @@ ncnn.Node = class {
             for (const output of this.type.outputs) {
                 if (outputIndex < outputs.length || output.option !== 'optional') {
                     const count = (output.type === 'Tensor[]') ? (outputs.length - outputIndex) : 1;
-                    const list = outputs.slice(outputIndex, outputIndex + count).map((id) => values.map(id));
+                    const list = outputs.slice(outputIndex, outputIndex + count).map((id) => values.map(id, ncnn.Utility.route(params, id)));
                     const argument = new ncnn.Argument(output.name, list);
                     this.outputs.push(argument);
                     outputIndex += count;
@@ -280,8 +327,8 @@ ncnn.Node = class {
             }
         }
         this.outputs.push(...outputs.slice(outputIndex).map((output, index) => {
-            const name = ((outputIndex + index) === 0) ? 'output' : (outputIndex + index).toString();
-            return new ncnn.Argument(name, [values.map(output)]);
+            const name = (outputIndex + index) === 0 ? 'output' : (outputIndex + index).toString();
+            return new ncnn.Argument(name, [values.map(output, ncnn.Utility.route(params, output))]);
         }));
         blobs.weight = (name, shape, code) => {
             const blob = blobs.load(shape, code || 0);
@@ -623,7 +670,20 @@ ncnn.Node = class {
             }
         }
         if (params && params.size > 0) {
-            const attributes = this.type && Array.isArray(this.type.attributes) ? this.type.attributes : [];
+            for (const [key, signature] of params) {
+                if (key.startsWith('@')) {
+                    const name = key.substring(1);
+                    const identifier = `${this.name}.${name}`;
+                    const data = blobs.entry(identifier);
+                    const type = ncnn.Utility.type(signature);
+                    const tensor = new ncnn.Tensor(type, data, null);
+                    const value = new ncnn.Value(identifier, null, tensor);
+                    const argument = new ncnn.Argument(name, [value]);
+                    this.inputs.push(argument);
+                    params.delete(key);
+                }
+            }
+            const attributes = Array.isArray(this.type.attributes) ? this.type.attributes : [];
             for (const [index, obj] of params) {
                 const metadata = attributes[index];
                 let name = index;
@@ -659,6 +719,20 @@ ncnn.Node = class {
                         }
                     }
                 }
+                if (!type && typeof value === 'string') {
+                    if (value === 'True') {
+                        value = true;
+                    } else if (value === 'False') {
+                        value = false;
+                    } else if (Number.isInteger(Number(value))) {
+                        value = Number(value);
+                    } else if (value.length > 3 && value.startsWith('(') && value.endsWith(')')) {
+                        const list = value.substring(1, value.length - 1).split(',').map((item) => Number(item.trim()));
+                        if (list.every((item) => Number.isInteger(item))) {
+                            value = list.map((item) => parseInt(item, 10));
+                        }
+                    }
+                }
                 const argument = new ncnn.Argument(name, value, type, visible);
                 this.attributes.push(argument);
             }
@@ -731,6 +805,24 @@ ncnn.Utility = class {
         }
         return value;
     }
+
+    static type(signature) {
+        const match = signature.match(/\(([^)]+)\)(\w+)/);
+        const shape = new ncnn.TensorShape(match[1].split(',').map((v) => parseInt(v, 10)));
+        const dataTypes = new Map([['f32', 'float32'], ['f16', 'float16']]);
+        const dataType = dataTypes.get(match[2]) || match[2];
+        return new ncnn.TensorType(dataType, shape);
+    }
+
+    static route(params, id) {
+        const key = `#${id}`;
+        if (params && params.has(key)) {
+            const signature = params.get(key);
+            params.delete(key);
+            return ncnn.Utility.type(signature);
+        }
+        return null;
+    }
 };
 
 ncnn.TextParamReader = class {
@@ -919,6 +1011,72 @@ ncnn.BlobReader = class {
             throw new ncnn.Error('Invalid weights data size.');
         }
     }
+
+    entry(identifier) {
+        if (this._entires && this._entires.has(identifier)) {
+            const reader = this._entires.get(identifier);
+            return reader.peek();
+        }
+        return null;
+    }
+};
+
+pnnx.Metadata = class {
+
+    static async open(context) {
+        if (!pnnx.Metadata._metadata) {
+            let data = null;
+            try {
+                data = await context.request('pytorch-metadata.json');
+            } catch {
+                // continue regardless of error
+            }
+            pnnx.Metadata._metadata = new pnnx.Metadata(data);
+        }
+        return pnnx.Metadata._metadata;
+    }
+
+    constructor(data) {
+        this._types = new Map();
+        this._attributes = new Map();
+        this._index = new Map();
+        if (data) {
+            const items = JSON.parse(data);
+            for (const item of items) {
+                item.name = item.name.replace(/^torch\.nn\.modules\.(\w)+\./, 'nn.');
+                item.name = item.name.replace(/aten::([a-z_]+)(\.\w+)?/g, (match, p1) => `torch.${p1}`);
+                this._types.set(item.name, { name: item.name, category: item.category });
+            }
+        }
+    }
+
+    type(name) {
+        if (!this._types.has(name)) {
+            this._types.set(name, { name: name.toString() });
+        }
+        return this._types.get(name);
+    }
+
+    attribute(type, name) {
+        const key = `${type}:${name}`;
+        if (!this._attributes.has(key)) {
+            this._attributes.set(key, null);
+            const metadata = this.type(type);
+            if (metadata) {
+                if (metadata.inputs) {
+                    for (const input of metadata.inputs) {
+                        this._attributes.set(`${type}:${input.name}`, input);
+                    }
+                }
+                if (metadata.attributes) {
+                    for (const attribute of metadata.attributes) {
+                        this._attributes.set(`${type}:${attribute.name}`, attribute);
+                    }
+                }
+            }
+        }
+        return this._attributes.get(key);
+    }
 };
 
 ncnn.Error = class extends Error {

+ 8 - 9
source/onnx.js

@@ -912,17 +912,16 @@ onnx.Context.Model = class {
 onnx.Metadata = class {
 
     static async open(context) {
-        if (onnx.Metadata._metadata) {
-            return onnx.Metadata._metadata;
-        }
-        try {
-            const data = await context.request('onnx-metadata.json');
+        if (!onnx.Metadata._metadata) {
+            let data = null;
+            try {
+                data = await context.request('onnx-metadata.json');
+            } catch {
+                // continue regardless of error
+            }
             onnx.Metadata._metadata = new onnx.Metadata(data);
-            return onnx.Metadata._metadata;
-        } catch {
-            onnx.Metadata._metadata = new onnx.Metadata(null);
-            return onnx.Metadata._metadata;
         }
+        return onnx.Metadata._metadata;
     }
 
     constructor(data) {

+ 7 - 0
source/pytorch-metadata.json

@@ -2853,6 +2853,7 @@
   },
   {
     "name": "aten::cat.names_out",
+    "category": "Tensor",
     "inputs": [
       { "name": "tensors", "type": "Tensor[]" },
       { "name": "dim", "type": "Dimname" }
@@ -2863,6 +2864,7 @@
   },
   {
     "name": "aten::cat.out",
+    "category": "Tensor",
     "inputs": [
       { "name": "tensors", "type": "Tensor[]" },
       { "name": "dim", "type": "int64", "default": 0 }
@@ -5770,6 +5772,7 @@
   },
   {
     "name": "aten::gather",
+    "category": "Transform",
     "inputs": [
       { "name": "self", "type": "Tensor" },
       { "name": "dim", "type": "int64" },
@@ -5782,6 +5785,7 @@
   },
   {
     "name": "aten::gather.dimname",
+    "category": "Transform",
     "inputs": [
       { "name": "self", "type": "Tensor" },
       { "name": "dim", "type": "Dimname" },
@@ -5794,6 +5798,7 @@
   },
   {
     "name": "aten::gather.dimname_out",
+    "category": "Transform",
     "inputs": [
       { "name": "self", "type": "Tensor" },
       { "name": "dim", "type": "Dimname" },
@@ -5806,6 +5811,7 @@
   },
   {
     "name": "aten::gather.out",
+    "category": "Transform",
     "inputs": [
       { "name": "self", "type": "Tensor" },
       { "name": "dim", "type": "int64" },
@@ -12961,6 +12967,7 @@
   },
   {
     "name": "aten::squeeze.dims",
+    "category": "Transform",
     "inputs": [
       { "name": "self", "type": "Tensor" },
       { "name": "dim", "type": "int64[]" }

+ 8 - 10
source/pytorch.js

@@ -3954,17 +3954,15 @@ pytorch.nnapi.Metadata = class {
 pytorch.Metadata = class {
 
     static async open(context) {
-        if (pytorch.Metadata._metadata) {
-            return pytorch.Metadata._metadata;
-        }
-        try {
-            const data = await context.request('pytorch-metadata.json');
-            pytorch.Metadata._metadata = new pytorch.Metadata(data);
-            return pytorch.Metadata._metadata;
-        } catch {
-            pytorch.Metadata._metadata = new pytorch.Metadata(null);
-            return pytorch.Metadata._metadata;
+        if (!pytorch.Metadata._metadata) {
+            try {
+                const data = await context.request('pytorch-metadata.json');
+                pytorch.Metadata._metadata = new pytorch.Metadata(data);
+            } catch {
+                pytorch.Metadata._metadata = new pytorch.Metadata(null);
+            }
         }
+        return pytorch.Metadata._metadata;
     }
 
     constructor(data) {

+ 8 - 9
source/tengine.js

@@ -187,17 +187,16 @@ tengine.TensorShape = class {
 tengine.Metadata = class {
 
     static async open(context) {
-        if (tengine.Metadata._metadata) {
-            return tengine.Metadata._metadata;
-        }
-        try {
-            const data = await context.request('tengine-metadata.json');
+        if (!tengine.Metadata._metadata) {
+            let data = null;
+            try {
+                data = await context.request('tengine-metadata.json');
+            } catch {
+                // continue regardless of error
+            }
             tengine.Metadata._metadata = new tengine.Metadata(data);
-            return tengine.Metadata._metadata;
-        } catch {
-            tengine.Metadata._metadata = new tengine.Metadata(null);
-            return tengine.Metadata._metadata;
         }
+        return tengine.Metadata._metadata;
     }
 
     constructor(data) {