Browse Source

Add ncnn test file (#296)

Lutz Roeder 1 year ago
parent
commit
5209dfc716
3 changed files with 93 additions and 56 deletions
  1. 15 1
      source/ncnn-metadata.json
  2. 70 54
      source/ncnn.js
  3. 8 1
      test/models.json

+ 15 - 1
source/ncnn-metadata.json

@@ -425,7 +425,21 @@
       { "name": "alpha", "type": "float32", "default": 1 },
       { "name": "beta", "type": "float32", "default": 1 },
       { "name": "transA", "type": "int32", "default": 0 },
-      { "name": "transB", "type": "int32", "default": 0 }
+      { "name": "transB", "type": "int32", "default": 0 },
+      { "name": "constantA", "type": "int32", "default": 0 },
+      { "name": "constantB", "type": "int32", "default": 0 },
+      { "name": "constantC", "type": "int32", "default": 0 },
+      { "name": "constantM", "type": "int32", "default": 0 },
+      { "name": "constantN", "type": "int32", "default": 0 },
+      { "name": "constantK", "type": "int32", "default": 0 },
+      { "name": "constant_broadcast_type_C", "type": "int32", "default": 0 },
+      { "name": "output_N1M", "type": "int32", "default": 0 },
+      { "name": "output_elempack", "type": "int32", "default": 0 },
+      { "name": "output_elemtype", "type": "int32", "default": 0 },
+      { "name": "output_transpose", "type": "int32", "default": 0 },
+      { "name": "constant_TILE_M", "type": "int32", "default": 0 },
+      { "name": "constant_TILE_N", "type": "int32", "default": 0 },
+      { "name": "constant_TILE_K", "type": "int32", "default": 0 }
     ]
   },
   {

+ 70 - 54
source/ncnn.js

@@ -77,14 +77,6 @@ ncnn.ModelFactory = class {
 
     async open(context) {
         const metadata = await context.metadata('ncnn-metadata.json');
-        const openBinary = (param, bin) => {
-            const reader = new ncnn.BinaryParamReader(param);
-            return new ncnn.Model(metadata, reader, bin);
-        };
-        const openText = (reader, bin) => {
-            reader = new ncnn.TextParamReader(reader);
-            return new ncnn.Model(metadata, reader, bin);
-        };
         const identifier = context.identifier.toLowerCase();
         let bin = null;
         switch (context.type) {
@@ -94,24 +86,29 @@ ncnn.ModelFactory = class {
                 } else if (identifier.endsWith('.cfg.ncnn')) {
                     bin = `${context.identifier.substring(0, context.identifier.length - 9)}.weights.ncnn`;
                 }
-                const reader = context.read('text');
+                let buffer = null;
                 try {
                     const content = await context.fetch(bin);
-                    const buffer = content.stream.peek();
-                    return openText(reader, buffer);
+                    buffer = content.stream.peek();
                 } catch {
-                    return openText(reader, null);
+                    // continue regardless of error
                 }
+                const param = context.read('text');
+                const reader = new ncnn.TextParamReader(param);
+                return new ncnn.Model(metadata, reader, buffer);
             }
             case 'ncnn.model.bin': {
                 bin = `${context.identifier.substring(0, context.identifier.length - 10)}.bin`;
+                let buffer = null;
                 try {
                     const content = await context.fetch(bin);
-                    const buffer = content.stream.peek();
-                    return openBinary(context.stream.peek(), buffer);
+                    buffer = content.stream.peek();
                 } catch {
-                    return openBinary(context.stream.peek(), null);
+                    // continue regardless of error
                 }
+                const param = context.stream.peek();
+                const reader = new ncnn.BinaryParamReader(param);
+                return new ncnn.Model(metadata, reader, buffer);
             }
             case 'ncnn.weights': {
                 let file = null;
@@ -120,15 +117,18 @@ ncnn.ModelFactory = class {
                 } else if (identifier.endsWith('.weights.ncnn')) {
                     file = `${context.identifier.substring(0, context.identifier.length - 13)}.cfg.ncnn`;
                 }
+                let reader = null;
                 try {
                     const content = await context.fetch(file);
-                    const reader = content.read('text');
-                    return openText(reader, context.stream.peek());
+                    const param = content.read('text');
+                    reader = new ncnn.TextParamReader(param);
                 } catch {
                     const content = await context.fetch(`${file}.bin`);
-                    const buffer = content.stream.peek();
-                    return openBinary(buffer, context.stream.peek());
+                    const param = content.stream.peek();
+                    reader = new ncnn.BinaryParamReader(param);
                 }
+                const buffer = context.stream.peek();
+                return new ncnn.Model(metadata, reader, buffer);
             }
             default: {
                 throw new ncnn.Error(`Unsupported ncnn format '${context.type}'.`);
@@ -525,6 +525,38 @@ ncnn.Node = class {
                 attributes.delete('2');
                 break;
             }
+            case 'Gemm': {
+                const transA = parseInt(attributes.get('2') || 0, 10);
+                const transB = parseInt(attributes.get('3') || 0, 10);
+                const constantA = parseInt(attributes.get('4') || 0, 10);
+                const constantB = parseInt(attributes.get('5') || 0, 10);
+                const constantC = parseInt(attributes.get('6') || 0, 10);
+                const M = parseInt(attributes.get('7') || 0, 10);
+                const N = parseInt(attributes.get('8') || 0, 10);
+                const K = parseInt(attributes.get('9') || 0, 10);
+                const constant_broadcast_type_C = parseInt(attributes.get('10') || 0, 10);
+                if (constantA === 1) {
+                    weight(blobReader, 'A', transA === 0 ? [K, M] : [M, K]);
+                }
+                if (constantB === 1) {
+                    weight(blobReader, 'B', transB === 1 ? [N, K] : [K, N]);
+                }
+                if (constantC === 1 && constant_broadcast_type_C !== -1) {
+                    let shape = null;
+                    switch (constant_broadcast_type_C) {
+                        case 0: shape = [1]; break;
+                        case 1: shape = [M]; break;
+                        case 2: shape = [1, M]; break;
+                        case 3: shape = [N, M]; break;
+                        case 4: shape = [N, 1]; break;
+                        default: break;
+                    }
+                    if (shape) {
+                        weight(blobReader, 'C', shape);
+                    }
+                }
+                break;
+            }
             default: {
                 break;
             }
@@ -753,21 +785,15 @@ ncnn.BlobReader = class {
                     const f3 = this._buffer[this._position++];
                     const type = f0 | f1 << 8 | f2 << 16 | f3 << 24;
                     switch (type) {
-                        case 0x00000000:
-                            dataType = 'float32';
-                            break;
-                        case 0x01306B47:
-                            dataType = 'float16';
-                            break;
-                        case 0x000D4B38:
-                            dataType = 'int8';
-                            break;
-                        case 0x00000001:
-                            dataType = 'qint8';
-                            break;
+                        case 0x00000000: dataType = 'float32'; break;
+                        case 0x01306B47: dataType = 'float16'; break;
+                        case 0x000D4B38: dataType = 'int8'; break;
+                        case 0x00000001: dataType = 'qint8'; break;
                         case 0x0002C056: // size * sizeof(float) - raw data with extra scaling
-                        default:
-                            throw new ncnn.Error(`Unsupported weight type '${type}'.`);
+                        default: {
+                            const hex = (type >>> 0).toString(16).padStart(8, '0');
+                            throw new ncnn.Error(`Unsupported weight type '${hex}'.`);
+                        }
                     }
                 } else {
                     this._buffer = null;
@@ -785,27 +811,17 @@ ncnn.BlobReader = class {
             if (this._buffer) {
                 if (dataType) {
                     const position = this._position;
-                    switch (dataType) {
-                        case 'float32':
-                            size *= 4;
-                            this._position += size;
-                            data = this._buffer.subarray(position, this._position);
-                            break;
-                        case 'float16':
-                            size *= 2;
-                            this._position += size;
-                            data = this._buffer.subarray(position, this._position);
-                            break;
-                        case 'int8':
-                            this._position += size;
-                            data = this._buffer.subarray(position, this._position);
-                            break;
-                        case 'qint8':
-                            this._position += size + 1024;
-                            data = null;
-                            break;
-                        default:
-                            throw new ncnn.Error(`Unsupported weight type '${dataType}'.`);
+                    const dataTypes = new Map([['float32', 4], ['float16', 2], ['int8', 1], ['qint8', 1]]);
+                    if (!dataTypes.has(dataType)) {
+                        throw new ncnn.Error(`Unsupported weight type '${dataType}'.`);
+                    }
+                    size *= dataTypes.get(dataType);
+                    if (dataType === 'qint8') {
+                        this._position += size + 1024;
+                        data = null;
+                    } else {
+                        this._position += size;
+                        data = this._buffer.subarray(position, this._position);
                     }
                 }
             }

+ 8 - 1
test/models.json

@@ -3367,6 +3367,13 @@
     "tags":     "validation",
     "link":     "https://github.com/MirrorYuChen/ncnn_example"
   },
+  {
+    "type":     "ncnn",
+    "target":   "ch_recv4.ncnn.param,ch_recv4.ncnn.bin",
+    "source":   "https://github.com/user-attachments/files/15937032/ch_recv4.ncnn.param.zip[ch_recv4.ncnn.param,ch_recv4.ncnn.bin]",
+    "format":   "ncnn",
+    "link":     "https://github.com/lutzroeder/netron/issues/296"
+  },
   {
     "type":     "ncnn",
     "target":   "darknet_yolov2.cfg.ncnn,darknet_yolov2.weights.ncnn",
@@ -3407,7 +3414,7 @@
     "target":   "mnet.25.zip",
     "source":   "https://github.com/lutzroeder/netron/files/6813063/mnet.25.zip",
     "format":   "ncnn",
-    "link":     "https://github.com/MirrorYuChen/ncnn_example"
+    "link":     "https://github.com/lutzroeder/netron/issues/768"
   },
   {
     "type":     "ncnn",