|
|
@@ -46,7 +46,7 @@ ncnn.ModelFactory = class {
|
|
|
if (!line) {
|
|
|
break;
|
|
|
}
|
|
|
- if (line.startsWith('pnnx.') || line.startsWith('nn.') || line.startsWith('F.')) {
|
|
|
+ if (line.startsWith('pnnx.') || line.startsWith('nn.') || line.startsWith('F.') || line.startsWith('torch.') || line.startsWith('Tensor.')) {
|
|
|
type = 'pnnx.model';
|
|
|
break;
|
|
|
}
|
|
|
@@ -243,12 +243,10 @@ ncnn.Graph = class {
|
|
|
const argument = new ncnn.Argument(layer.name, layer.outputs.map((output) => values.map(output, type)));
|
|
|
this.inputs.push(argument);
|
|
|
} else if (layer.type === 'pnnx.Input' && layer.params) {
|
|
|
- const type = ncnn.Utility.route(layer.params, '0');
|
|
|
- const argument = new ncnn.Argument(layer.name, layer.outputs.map((output) => values.map(output, type)));
|
|
|
+ const argument = new ncnn.Argument(layer.name, layer.outputs.map((output) => values.map(output, ncnn.Utility.route(layer.params, output))));
|
|
|
this.inputs.push(argument);
|
|
|
} else if (layer.type === 'pnnx.Output' && layer.params) {
|
|
|
- const type = ncnn.Utility.route(layer.params, '0');
|
|
|
- const argument = new ncnn.Argument(layer.name, layer.inputs.map((input) => values.map(input, type)));
|
|
|
+ const argument = new ncnn.Argument(layer.name, layer.inputs.map((input) => values.map(input, ncnn.Utility.route(layer.params, input))));
|
|
|
this.outputs.push(argument);
|
|
|
} else {
|
|
|
const node = new ncnn.Node(metadata, format, blobs, layer, values);
|
|
|
@@ -785,9 +783,7 @@ ncnn.TensorShape = class {
|
|
|
}
|
|
|
|
|
|
equals(obj) {
|
|
|
- return obj && Array.isArray(obj.dimensions) &&
|
|
|
- Array.isArray(this.dimensions) && this.dimensions.length === obj.dimensions.length
|
|
|
- && obj.dimensions.every((value, index) => this.dimensions[index] === value);
|
|
|
+ return obj && Array.isArray(obj.dimensions) && Array.isArray(this.dimensions) && this.dimensions.length === obj.dimensions.length && obj.dimensions.every((value, index) => Object.is(this.dimensions[index], value));
|
|
|
}
|
|
|
|
|
|
toString() {
|
|
|
@@ -1069,7 +1065,18 @@ pnnx.Metadata = class {
|
|
|
const items = JSON.parse(data);
|
|
|
for (const item of items) {
|
|
|
item.name = item.name.replace(/^torch\.nn\.modules\.(\w)+\./, 'nn.');
|
|
|
- item.name = item.name.replace(/aten::([a-z_]+)(\.\w+)?/g, (match, p1) => `torch.${p1}`);
|
|
|
+ const match = item.name.match(/^aten::([a-z_]+)/);
|
|
|
+ if (match) {
|
|
|
+ const name = match[1];
|
|
|
+ if (item.category) {
|
|
|
+ if (!this._types.has(`torch.${name}`)) {
|
|
|
+ this._types.set(`torch.${name}`, { name: `torch.${name}`, category: item.category });
|
|
|
+ }
|
|
|
+ if (!this._types.has(`F.${name}`)) {
|
|
|
+ this._types.set(`F.${name}`, { name: `F.${name}`, category: item.category });
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
this._types.set(item.name, { name: item.name, category: item.category });
|
|
|
}
|
|
|
}
|