2
0
Lutz Roeder 3 долоо хоног өмнө
parent
commit
4723145ce7
1 өөрчлөгдсөн 16 нэмэгдсэн , 9 устгасан
  1. 16 9
      source/ncnn.js

+ 16 - 9
source/ncnn.js

@@ -46,7 +46,7 @@ ncnn.ModelFactory = class {
                             if (!line) {
                                 break;
                             }
-                            if (line.startsWith('pnnx.') || line.startsWith('nn.') || line.startsWith('F.')) {
+                            if (line.startsWith('pnnx.') || line.startsWith('nn.') || line.startsWith('F.') || line.startsWith('torch.') || line.startsWith('Tensor.')) {
                                 type = 'pnnx.model';
                                 break;
                             }
@@ -243,12 +243,10 @@ ncnn.Graph = class {
                 const argument = new ncnn.Argument(layer.name, layer.outputs.map((output) => values.map(output, type)));
                 this.inputs.push(argument);
             } else if (layer.type === 'pnnx.Input' && layer.params) {
-                const type = ncnn.Utility.route(layer.params, '0');
-                const argument = new ncnn.Argument(layer.name, layer.outputs.map((output) => values.map(output, type)));
+                const argument = new ncnn.Argument(layer.name, layer.outputs.map((output) => values.map(output, ncnn.Utility.route(layer.params, output))));
                 this.inputs.push(argument);
             } else if (layer.type === 'pnnx.Output' && layer.params) {
-                const type = ncnn.Utility.route(layer.params, '0');
-                const argument = new ncnn.Argument(layer.name, layer.inputs.map((input) => values.map(input, type)));
+                const argument = new ncnn.Argument(layer.name, layer.inputs.map((input) => values.map(input, ncnn.Utility.route(layer.params, input))));
                 this.outputs.push(argument);
             } else {
                 const node = new ncnn.Node(metadata, format, blobs, layer, values);
@@ -785,9 +783,7 @@ ncnn.TensorShape = class {
     }
 
     equals(obj) {
-        return obj && Array.isArray(obj.dimensions) &&
-            Array.isArray(this.dimensions) && this.dimensions.length === obj.dimensions.length
-            && obj.dimensions.every((value, index) => this.dimensions[index] === value);
+        return obj && Array.isArray(obj.dimensions) && Array.isArray(this.dimensions) && this.dimensions.length === obj.dimensions.length && obj.dimensions.every((value, index) => Object.is(this.dimensions[index], value));
     }
 
     toString() {
@@ -1069,7 +1065,18 @@ pnnx.Metadata = class {
             const items = JSON.parse(data);
             for (const item of items) {
                 item.name = item.name.replace(/^torch\.nn\.modules\.(\w)+\./, 'nn.');
-                item.name = item.name.replace(/aten::([a-z_]+)(\.\w+)?/g, (match, p1) => `torch.${p1}`);
+                const match = item.name.match(/^aten::([a-z_]+)/);
+                if (match) {
+                    const name = match[1];
+                    if (item.category) {
+                        if (!this._types.has(`torch.${name}`)) {
+                            this._types.set(`torch.${name}`, { name: `torch.${name}`, category: item.category });
+                        }
+                        if (!this._types.has(`F.${name}`)) {
+                            this._types.set(`F.${name}`, { name: `F.${name}`, category: item.category });
+                        }
+                    }
+                }
                 this._types.set(item.name, { name: item.name, category: item.category });
             }
         }