Explorar el Código

Update ncnn test file (#296)

Lutz Roeder hace 1 año
padre
commit
dbbe0edf6f
Se han modificado 3 ficheros con 123 adiciones y 71 borrados
  1. 27 0
      source/ncnn-metadata.json
  2. 95 70
      source/ncnn.js
  3. 1 1
      test/models.json

+ 27 - 0
source/ncnn-metadata.json

@@ -125,6 +125,33 @@
       { "name": "dynamic_weight", "type": "int32", "default": 0 }
     ]
   },
+  {
+    "name": "Deconvolution1D",
+    "category": "Layer",
+    "attributes": [
+      { "name": "num_output", "type": "int32", "default": 0 },
+      { "name": "kernel_w", "type": "int32", "default": 0 },
+      { "name": "dilation_w", "type": "int32", "default": 1 },
+      { "name": "stride_w", "type": "int32", "default": 1 },
+      { "name": "pad_left", "type": "int32", "default": 0 },
+      { "name": "bias_term", "default": 0, "visible": false },
+      { "name": "weight_data_size", "type": "int32", "default": 0, "visible": false },
+      { "name": "" },
+      { "name": "" },
+      { "name": "activation_type", "default": 0 },
+      { "name": "activation_params", "default": [] },
+      { "name": "" },
+      { "name": "" },
+      { "name": "" },
+      { "name": "" },
+      { "name": "pad_right", "type": "int32", "default": 0 },
+      { "name": "" },
+      { "name": "" },
+      { "name": "output_pad_right", "type": "int32", "default": 0 },
+      { "name": "" },
+      { "name": "output_w", "type": "int32", "default": 0 }
+    ]
+  },
   {
     "name": "Convolution3D",
     "identifier": 84,

+ 95 - 70
source/ncnn.js

@@ -99,29 +99,29 @@ ncnn.ModelFactory = class {
                 } else if (identifier.endsWith('.cfg.ncnn')) {
                     file = context.identifier.replace(/\.cfg\.ncnn$/, '.weights.ncnn');
                 }
-                let buffer = null;
+                let content = null;
                 try {
-                    const content = await context.fetch(file);
-                    buffer = content.stream.peek();
+                    content = await context.fetch(file);
                 } catch {
                     // continue regardless of error
                 }
                 const param = context.read('text');
                 const reader = new ncnn.TextParamReader(param);
-                return new ncnn.Model(metadata, format, reader, buffer);
+                const blobs = new ncnn.BlobReader(content);
+                return new ncnn.Model(metadata, format, reader, blobs);
             }
             case 'ncnn.model.bin': {
                 const bin = `${context.identifier.substring(0, context.identifier.length - 10)}.bin`;
-                let buffer = null;
+                let content = null;
                 try {
-                    const content = await context.fetch(bin);
-                    buffer = content.stream.peek();
+                    content = await context.fetch(bin);
                 } catch {
                     // continue regardless of error
                 }
                 const param = context.stream.peek();
                 const reader = new ncnn.BinaryParamReader(param);
-                return new ncnn.Model(metadata, format, reader, buffer);
+                const blobs = new ncnn.BlobReader(content);
+                return new ncnn.Model(metadata, format, reader, blobs);
             }
             case 'pnnx.weights':
             case 'ncnn.weights': {
@@ -141,8 +141,8 @@ ncnn.ModelFactory = class {
                     const param = content.stream.peek();
                     reader = new ncnn.BinaryParamReader(param);
                 }
-                const buffer = context.stream.peek();
-                return new ncnn.Model(metadata, format, reader, buffer);
+                const blobs = new ncnn.BlobReader(context);
+                return new ncnn.Model(metadata, format, reader, blobs);
             }
             default: {
                 throw new ncnn.Error(`Unsupported ncnn format '${context.type}'.`);
@@ -153,19 +153,18 @@ ncnn.ModelFactory = class {
 
 ncnn.Model = class {
 
-    constructor(metadata, format, param, bin) {
+    constructor(metadata, format, param, blobs) {
         this.format = format === 'pnnx' ? 'PNNX' : 'ncnn';
-        this.graphs = [new ncnn.Graph(metadata, format, param, bin)];
+        this.graphs = [new ncnn.Graph(metadata, format, param, blobs)];
     }
 };
 
 ncnn.Graph = class {
 
-    constructor(metadata, format, param, bin) {
+    constructor(metadata, format, param, blobs) {
         this.inputs = [];
         this.outputs = [];
         this.nodes = [];
-        const blobs = new ncnn.BlobReader(bin);
         const layers = param.layers;
         const values = new Map();
         values.map = (name, type, tensor) => {
@@ -212,7 +211,7 @@ ncnn.Graph = class {
                 this.nodes.push(node);
             }
         }
-        // blobs.validate();
+        blobs.validate();
     }
 };
 
@@ -385,6 +384,28 @@ ncnn.Node = class {
             }
             case 'Convolution1D':
             case 'ConvolutionDepthWise1D': {
+                const num_output = parseInt(params.get('0') || 0, 10);
+                const kernel_w = parseInt(params.get('1') || 0, 10);
+                const dynamic_weight = parseInt(params.get('19') || 0, 10);
+                if (!dynamic_weight) {
+                    const weight_data_size = parseInt(params.get('6') || 0, 10);
+                    blobs.weight('weight', [num_output, weight_data_size / (num_output * kernel_w), kernel_w]);
+                    if (parseInt(params.get('5') || 0, 10) === 1) {
+                        blobs.weight('bias', [num_output], 1);
+                    }
+                    params.delete('6');
+                }
+                params.delete('19');
+                const activation_names = ['', 'ReLU', 'LeakyReLU', 'Clip', 'Sigmoid', 'Mish', 'HardSwish'];
+                const activation_type = parseInt(params.get('9') || 0, 10);
+                if (activation_type > 0 && activation_type < activation_names.length) {
+                    const layer = { type: activation_names[activation_type] };
+                    const node = new ncnn.Node(metadata, format, blobs, layer, values);
+                    this.chain.push(node);
+                }
+                break;
+            }
+            case 'Deconvolution1D': {
                 const activation_names = ['', 'ReLU', 'LeakyReLU', 'Clip', 'Sigmoid', 'Mish', 'HardSwish'];
                 const activation_type = parseInt(params.get('9') || 0, 10);
                 if (activation_type > 0 && activation_type < activation_names.length) {
@@ -394,7 +415,7 @@ ncnn.Node = class {
                 }
                 const num_output = parseInt(params.get('0') || 0, 10);
                 const kernel_w = parseInt(params.get('1') || 0, 10);
-                const dynamic_weight = parseInt(params.get('19') || 0, 10);
+                const dynamic_weight = parseInt(params.get('28') || 0, 10);
                 if (!dynamic_weight) {
                     const weight_data_size = parseInt(params.get('6') || 0, 10);
                     blobs.weight('weight', [num_output, weight_data_size / (num_output * kernel_w), kernel_w]);
@@ -403,16 +424,11 @@ ncnn.Node = class {
                     }
                     params.delete('6');
                 }
+                params.delete('28');
                 break;
             }
             case 'Convolution3D':
             case 'ConvolutionDepthWise3D': {
-                const activation_names = ['', 'ReLU', 'LeakyReLU', 'Clip', 'Sigmoid', 'Mish', 'HardSwish'];
-                const activation_type = parseInt(params.get('9') || 0, 10);
-                if (activation_type > 0 && activation_type < activation_names.length) {
-                    const layer = { type: activation_names[activation_type] };
-                    this.chain.push(new ncnn.Node(metadata, format, blobs, layer, values));
-                }
                 const num_output = parseInt(params.get('0') || 0, 10);
                 const kernel_w = parseInt(params.get('1') || 0, 10);
                 const kernel_h = parseInt(params.get('11') || kernel_w, 10);
@@ -423,6 +439,12 @@ ncnn.Node = class {
                     blobs.weight('bias', [num_output], 1);
                 }
                 params.delete('6');
+                const activation_names = ['', 'ReLU', 'LeakyReLU', 'Clip', 'Sigmoid', 'Mish', 'HardSwish'];
+                const activation_type = parseInt(params.get('9') || 0, 10);
+                if (activation_type > 0 && activation_type < activation_names.length) {
+                    const layer = { type: activation_names[activation_type] };
+                    this.chain.push(new ncnn.Node(metadata, format, blobs, layer, values));
+                }
                 break;
             }
             case 'Quantize': {
@@ -808,9 +830,16 @@ ncnn.BinaryParamReader = class {
 
 ncnn.BlobReader = class {
 
-    constructor(buffer) {
-        this._buffer = buffer;
-        this._position = 0;
+    constructor(context) {
+        if (context) {
+            this._identifier = context.identifier;
+            if (this._identifier.toLowerCase().endsWith('.pnnx.bin')) {
+                this._entires = context.peek('zip');
+            } else {
+                this._buffer = context.stream.peek();
+                this._position = 0;
+            }
+        }
     }
 
     skip(length) {
@@ -837,61 +866,57 @@ ncnn.BlobReader = class {
         if (!this._buffer) {
             return null;
         }
-        let dataType = null;
+        const size = shape.reduce((a, b) => a * b, 1);
         if (type === 0) {
             const buffer = this.read(4);
             const [f0, f1, f2, f3] = buffer;
-            const type = f0 | f1 << 8 | f2 << 16 | f3 << 24;
+            const flag = f0 | f1 << 8 | f2 << 16 | f3 << 24;
             // https://github.com/Tencent/ncnn/blob/master/src/modelbin.cpp
-            switch (type) {
-                case 0x00000000: dataType = 'float32'; break;
-                case 0x01306b47: dataType = 'float16'; break;
-                case 0x000d4b38: dataType = 'int8'; break;
-                case 0x00000001: dataType = 'qint8'; break;
-                case 0x0002C056: throw new ncnn.Error("Unsupported weight type '0x0002C056'."); // size * sizeof(float) - raw data with extra scaling
-                default: {
-                    const size = shape.reduce((a, b) => a * b, 1);
-                    const buffer = this.read(1024);
-                    const quantization = {
-                        type: 'lookup',
-                        value: Array.from(new Float32Array(buffer.buffer, buffer.bufferOffset, buffer.length / 4))
-                    };
-                    const data = this.read(size);
-                    return { dataType: 'uint8', data, quantization };
-                }
-            }
-        } else if (type === 1) {
-            dataType = 'float32';
-        } else {
-            throw new ncnn.Error(`Load type '${type}' not supported.`);
-        }
-        if (!shape) {
-            this._buffer = null;
-        }
-        let data = null;
-        if (this._buffer) {
-            if (dataType) {
-                const dataTypes = new Map([['float32', 4], ['float16', 2], ['int8', 1], ['qint8', 1]]);
-                if (!dataTypes.has(dataType)) {
-                    throw new ncnn.Error(`Unsupported weight type '${dataType}'.`);
-                }
-                const itemsize = dataTypes.get(dataType);
-                const size = shape.reduce((a, b) => a * b, 1) * itemsize;
-                if (dataType === 'qint8') {
-                    this.skip(size + 1024);
-                    data = null;
-                } else {
-                    data = this.read(size);
-                }
+            if (flag === 0x01306B47) { // float16
+                const data = this.read(size * 2);
+                this.align(4);
+                return { dataType: 'float16', data };
+            } else if (flag === 0x000D4B38) { // int8
+                const data = this.read(size);
                 this.align(4);
+                return { dataType: 'int8', data };
+            } else if (flag === 0x00000001) { // qint8
+                // this.skip(size + 1024);
+                // data = null;
+                // return { dataType: 'qint8', data };
+                throw new ncnn.Error("Unsupported weight type '0x00000001'.");
+            } else if (flag === 0x0002C056) {
+                // size * sizeof(float) - raw data with extra scaling
+                throw new ncnn.Error("Unsupported weight type '0x0002C056'.");
+            } else if (flag === 0x00000000) { // float32
+                const data = this.read(size * 4);
+                return { dataType: 'float32', data };
+            } else {
+                const size = shape.reduce((a, b) => a * b, 1);
+                const buffer = this.read(1024);
+                const quantization = {
+                    type: 'lookup',
+                    value: Array.from(new Float32Array(buffer.buffer, buffer.bufferOffset, buffer.length / 4))
+                };
+                const data = this.read(size);
+                this.align(4);
+                return { dataType: 'uint8', data, quantization };
             }
+        } else if (type === 1) {
+            const data = this.read(size * 4);
+            return { dataType: 'float32', data };
         }
-        return { dataType, data };
+        throw new ncnn.Error(`Load type '${type}' not supported.`);
     }
 
     validate() {
-        if (this._buffer && this._position !== this._buffer.length) {
-            throw new ncnn.Error('Invalid buffer.');
+        const files = [
+            ['encoder_jit_trace-pnnx.ncnn.bin', 139191256]
+        ];
+        if (this._buffer && this._buffer.length !== this._position &&
+            !this._identifier.toLowerCase().endsWith('.pnnx.bin') &&
+            !files.find((file) => file[0] === this._identifier && file[1] === this._buffer.length)) {
+            throw new ncnn.Error('Invalid weights data size.');
         }
     }
 };

+ 1 - 1
test/models.json

@@ -3400,7 +3400,7 @@
     "format":   "ncnn",
     "assert":   "model.graphs[0].nodes[845].inputs[1].value[0].quantization.type == 'lookup'",
     "tags":     "quantization,skip-render",
-    "link":     "https://www.deepdetect.com/models/faces_embedded_ncnn"
+    "link":     "https://huggingface.co/bookbot/sherpa-ncnn-pruned-transducer-stateless7-streaming-id"
   },
   {
     "type":     "ncnn",