Przeglądaj źródła

Add PNNX test file (#1314)

Lutz Roeder 1 rok temu
rodzic
commit
4bd032d852
3 zmienionych plików z 57 dodań i 39 usunięć
  1. 2 1
      source/ncnn-metadata.json
  2. 48 38
      source/ncnn.js
  3. 7 0
      test/models.json

+ 2 - 1
source/ncnn-metadata.json

@@ -48,6 +48,7 @@
   {
     "name": "Clip",
     "identifier": 54,
+    "category": "Activation",
     "attributes": [
       { "name": "min", "type": "float32" },
       { "name": "max", "type": "float32" }
@@ -651,7 +652,7 @@
   {
     "name": "Padding",
     "identifier": 43,
-    "category": "Layer",
+    "category": "Tensor",
     "attributes": [
       { "name": "top", "default": 0 },
       { "name": "bottom", "default": 0 },

+ 48 - 38
source/ncnn.js

@@ -22,23 +22,30 @@ ncnn.ModelFactory = class {
             }
         } else if (identifier.endsWith('.param') || identifier.endsWith('.cfg.ncnn')) {
             const reader = context.read('text', 0x10000);
+            const type = identifier.endsWith('.pnnx.param') ? 'pnnx.model' : 'ncnn.model';
             if (reader) {
                 try {
                     const signature = reader.read('\n');
                     if (signature !== undefined) {
                         if (signature.trim() === '7767517') {
-                            context.type = 'ncnn.model';
+                            context.type = type;
                             return;
                         }
                         const header = signature.trim().split(' ');
                         if (header.length === 2 && header.every((value) => value >>> 0 === parseFloat(value))) {
-                            context.type = 'ncnn.model';
+                            context.type = type;
                         }
                     }
                 } catch {
                     // continue regardless of error
                 }
             }
+        } else if (identifier.endsWith('.pnnx.bin')) {
+            const entries = context.peek('zip');
+            if (entries.size > 0) {
+                context.type = 'pnnx.weights';
+                context.target = entries;
+            }
         } else if (identifier.endsWith('.bin') || identifier.endsWith('.weights.ncnn')) {
             const stream = context.stream;
             if (stream.length > 4) {
@@ -78,27 +85,29 @@ ncnn.ModelFactory = class {
     async open(context) {
         const metadata = await context.metadata('ncnn-metadata.json');
         const identifier = context.identifier.toLowerCase();
-        let bin = null;
+        const format = context.type.split('.').shift();
         switch (context.type) {
+            case 'pnnx.model':
             case 'ncnn.model': {
+                let file = null;
                 if (identifier.endsWith('.param')) {
-                    bin = `${context.identifier.substring(0, context.identifier.length - 6)}.bin`;
+                    file = context.identifier.replace(/\.param$/, '.bin');
                 } else if (identifier.endsWith('.cfg.ncnn')) {
-                    bin = `${context.identifier.substring(0, context.identifier.length - 9)}.weights.ncnn`;
+                    file = context.identifier.replace(/\.cfg\.ncnn$/, '.weights.ncnn');
                 }
                 let buffer = null;
                 try {
-                    const content = await context.fetch(bin);
+                    const content = await context.fetch(file);
                     buffer = content.stream.peek();
                 } catch {
                     // continue regardless of error
                 }
                 const param = context.read('text');
                 const reader = new ncnn.TextParamReader(param);
-                return new ncnn.Model(metadata, reader, buffer);
+                return new ncnn.Model(metadata, format, reader, buffer);
             }
             case 'ncnn.model.bin': {
-                bin = `${context.identifier.substring(0, context.identifier.length - 10)}.bin`;
+                const bin = `${context.identifier.substring(0, context.identifier.length - 10)}.bin`;
                 let buffer = null;
                 try {
                     const content = await context.fetch(bin);
@@ -108,14 +117,15 @@ ncnn.ModelFactory = class {
                 }
                 const param = context.stream.peek();
                 const reader = new ncnn.BinaryParamReader(param);
-                return new ncnn.Model(metadata, reader, buffer);
+                return new ncnn.Model(metadata, format, reader, buffer);
             }
+            case 'pnnx.weights':
             case 'ncnn.weights': {
                 let file = null;
-                if (identifier.endsWith('bin')) {
-                    file = `${context.identifier.substring(0, context.identifier.length - 4)}.param`;
+                if (identifier.endsWith('.bin')) {
+                    file = context.identifier.replace(/\.bin$/, '.param');
                 } else if (identifier.endsWith('.weights.ncnn')) {
-                    file = `${context.identifier.substring(0, context.identifier.length - 13)}.cfg.ncnn`;
+                    file = context.identifier.replace(/\.weights\.ncnn$/, '.cfg.ncnn');
                 }
                 let reader = null;
                 try {
@@ -128,7 +138,7 @@ ncnn.ModelFactory = class {
                     reader = new ncnn.BinaryParamReader(param);
                 }
                 const buffer = context.stream.peek();
-                return new ncnn.Model(metadata, reader, buffer);
+                return new ncnn.Model(metadata, format, reader, buffer);
             }
             default: {
                 throw new ncnn.Error(`Unsupported ncnn format '${context.type}'.`);
@@ -139,15 +149,15 @@ ncnn.ModelFactory = class {
 
 ncnn.Model = class {
 
-    constructor(metadata, param, bin) {
-        this.format = 'ncnn';
-        this.graphs = [new ncnn.Graph(metadata, param, bin)];
+    constructor(metadata, format, param, bin) {
+        this.format = format === 'pnnx' ? 'PNNX' : 'ncnn';
+        this.graphs = [new ncnn.Graph(metadata, format, param, bin)];
     }
 };
 
 ncnn.Graph = class {
 
-    constructor(metadata, param, bin) {
+    constructor(metadata, format, param, bin) {
         this.inputs = [];
         this.outputs = [];
         this.nodes = [];
@@ -192,7 +202,7 @@ ncnn.Graph = class {
                 const input = new ncnn.Argument(layer.name, layer.outputs.map((output) => values.map(output, type)));
                 this.inputs.push(input);
             } else {
-                const node = new ncnn.Node(metadata, blobs, layer, values);
+                const node = new ncnn.Node(metadata, format, blobs, layer, values);
                 this.nodes.push(node);
             }
         }
@@ -223,7 +233,7 @@ ncnn.Value = class {
 
 ncnn.Node = class {
 
-    constructor(metadata, blobs, layer, values) {
+    constructor(metadata, format, blobs, layer, values) {
         this.inputs = [];
         this.outputs = [];
         this.chain = [];
@@ -287,6 +297,13 @@ ncnn.Node = class {
                 break;
             }
             case 'InnerProduct': {
+                const num_output = parseInt(attributes.get('0') || 0, 10);
+                const weight_data_size = parseInt(attributes.get('2') || 0, 10);
+                blobs.weight('weight', [num_output, weight_data_size / num_output]);
+                if (parseInt(attributes.get('1') || 0, 10) === 1) {
+                    blobs.weight('bias', [num_output], 'float32');
+                }
+                attributes.delete('2');
                 const activation_names = ['', 'ReLU', 'Leaky ReLU', 'Clip', 'Sigmoid', 'Mish', 'HardSwish'];
                 const activation_type = parseInt(attributes.get('9') || 0, 10);
                 if (activation_type > 0 && activation_type < activation_names.length) {
@@ -294,15 +311,8 @@ ncnn.Node = class {
                         type: activation_names[activation_type],
                         attributes: new Map()
                     };
-                    this.chain.push(new ncnn.Node(metadata, blobs, layer, values));
-                }
-                const num_output = parseInt(attributes.get('0') || 0, 10);
-                const weight_data_size = parseInt(attributes.get('2') || 0, 10);
-                blobs.weight('weight', [num_output, weight_data_size / num_output]);
-                if (parseInt(attributes.get('1') || 0, 10) === 1) {
-                    blobs.weight('bias', [num_output], 'float32');
+                    this.chain.push(new ncnn.Node(metadata, format, blobs, layer, values));
                 }
-                attributes.delete('2');
                 break;
             }
             case 'Bias': {
@@ -324,15 +334,6 @@ ncnn.Node = class {
             case 'ConvolutionDepthWise':
             case 'Deconvolution':
             case 'DeconvolutionDepthWise': {
-                const activation_names = ['', 'ReLU', 'LeakyReLU', 'Clip', 'Sigmoid', 'Mish', 'HardSwish'];
-                const activation_type = parseInt(attributes.get('9') || 0, 10);
-                if (activation_type > 0 && activation_type < activation_names.length) {
-                    const layer = {
-                        type: activation_names[activation_type],
-                        attributes: new Map()
-                    };
-                    this.chain.push(new ncnn.Node(metadata, blobs, layer, values));
-                }
                 const num_output = parseInt(attributes.get('0') || 0, 10);
                 const kernel_w = parseInt(attributes.get('1') || 0, 10);
                 const kernel_h = parseInt(attributes.get('11') || kernel_w, 10);
@@ -364,6 +365,15 @@ ncnn.Node = class {
                     }
                 }
                 attributes.delete('6');
+                const activation_names = ['', 'ReLU', 'LeakyReLU', 'Clip', 'Sigmoid', 'Mish', 'HardSwish'];
+                const activation_type = parseInt(attributes.get('9') || 0, 10);
+                if (activation_type > 0 && activation_type < activation_names.length) {
+                    const layer = {
+                        type: activation_names[activation_type],
+                        attributes: new Map()
+                    };
+                    this.chain.push(new ncnn.Node(metadata, format, blobs, layer, values));
+                }
                 break;
             }
             case 'Convolution1D':
@@ -375,7 +385,7 @@ ncnn.Node = class {
                         type: activation_names[activation_type],
                         attributes: new Map()
                     };
-                    const node = new ncnn.Node(metadata, blobs, layer, values);
+                    const node = new ncnn.Node(metadata, format, blobs, layer, values);
                     this.chain.push(node);
                 }
                 const num_output = parseInt(attributes.get('0') || 0, 10);
@@ -397,7 +407,7 @@ ncnn.Node = class {
                         type: activation_names[activation_type],
                         attributes: new Map()
                     };
-                    this.chain.push(new ncnn.Node(metadata, blobs, layer, values));
+                    this.chain.push(new ncnn.Node(metadata, format, blobs, layer, values));
                 }
                 const num_output = parseInt(attributes.get('0') || 0, 10);
                 const kernel_w = parseInt(attributes.get('1') || 0, 10);

+ 7 - 0
test/models.json

@@ -4836,6 +4836,13 @@
     "error":    "Unsupported Pickle type 'gensim.models.word2vec.Word2Vec'.",
     "link":     "https://github.com/lutzroeder/netron/issues/901"
   },
+  {
+    "type":     "pnnx",
+    "target":   "demo_net.pnnx.param,demo_net.pnnx.bin",
+    "source":   "https://github.com/user-attachments/files/16325286/demo_net.pnnx.zip[demo_net.pnnx.param,demo_net.pnnx.bin]",
+    "format":   "PNNX",
+    "link":     "https://github.com/lutzroeder/netron/issues/296"
+  },
   {
     "type":     "pytorch",
     "target":   "add.pte",