|
|
@@ -77,14 +77,6 @@ ncnn.ModelFactory = class {
|
|
|
|
|
|
async open(context) {
|
|
|
const metadata = await context.metadata('ncnn-metadata.json');
|
|
|
- const openBinary = (param, bin) => {
|
|
|
- const reader = new ncnn.BinaryParamReader(param);
|
|
|
- return new ncnn.Model(metadata, reader, bin);
|
|
|
- };
|
|
|
- const openText = (reader, bin) => {
|
|
|
- reader = new ncnn.TextParamReader(reader);
|
|
|
- return new ncnn.Model(metadata, reader, bin);
|
|
|
- };
|
|
|
const identifier = context.identifier.toLowerCase();
|
|
|
let bin = null;
|
|
|
switch (context.type) {
|
|
|
@@ -94,24 +86,29 @@ ncnn.ModelFactory = class {
|
|
|
} else if (identifier.endsWith('.cfg.ncnn')) {
|
|
|
bin = `${context.identifier.substring(0, context.identifier.length - 9)}.weights.ncnn`;
|
|
|
}
|
|
|
- const reader = context.read('text');
|
|
|
+ let buffer = null;
|
|
|
try {
|
|
|
const content = await context.fetch(bin);
|
|
|
- const buffer = content.stream.peek();
|
|
|
- return openText(reader, buffer);
|
|
|
+ buffer = content.stream.peek();
|
|
|
} catch {
|
|
|
- return openText(reader, null);
|
|
|
+ // continue regardless of error
|
|
|
}
|
|
|
+ const param = context.read('text');
|
|
|
+ const reader = new ncnn.TextParamReader(param);
|
|
|
+ return new ncnn.Model(metadata, reader, buffer);
|
|
|
}
|
|
|
case 'ncnn.model.bin': {
|
|
|
bin = `${context.identifier.substring(0, context.identifier.length - 10)}.bin`;
|
|
|
+ let buffer = null;
|
|
|
try {
|
|
|
const content = await context.fetch(bin);
|
|
|
- const buffer = content.stream.peek();
|
|
|
- return openBinary(context.stream.peek(), buffer);
|
|
|
+ buffer = content.stream.peek();
|
|
|
} catch {
|
|
|
- return openBinary(context.stream.peek(), null);
|
|
|
+ // continue regardless of error
|
|
|
}
|
|
|
+ const param = context.stream.peek();
|
|
|
+ const reader = new ncnn.BinaryParamReader(param);
|
|
|
+ return new ncnn.Model(metadata, reader, buffer);
|
|
|
}
|
|
|
case 'ncnn.weights': {
|
|
|
let file = null;
|
|
|
@@ -120,15 +117,18 @@ ncnn.ModelFactory = class {
|
|
|
} else if (identifier.endsWith('.weights.ncnn')) {
|
|
|
file = `${context.identifier.substring(0, context.identifier.length - 13)}.cfg.ncnn`;
|
|
|
}
|
|
|
+ let reader = null;
|
|
|
try {
|
|
|
const content = await context.fetch(file);
|
|
|
- const reader = content.read('text');
|
|
|
- return openText(reader, context.stream.peek());
|
|
|
+ const param = content.read('text');
|
|
|
+ reader = new ncnn.TextParamReader(param);
|
|
|
} catch {
|
|
|
const content = await context.fetch(`${file}.bin`);
|
|
|
- const buffer = content.stream.peek();
|
|
|
- return openBinary(buffer, context.stream.peek());
|
|
|
+ const param = content.stream.peek();
|
|
|
+ reader = new ncnn.BinaryParamReader(param);
|
|
|
}
|
|
|
+ const buffer = context.stream.peek();
|
|
|
+ return new ncnn.Model(metadata, reader, buffer);
|
|
|
}
|
|
|
default: {
|
|
|
throw new ncnn.Error(`Unsupported ncnn format '${context.type}'.`);
|
|
|
@@ -525,6 +525,38 @@ ncnn.Node = class {
|
|
|
attributes.delete('2');
|
|
|
break;
|
|
|
}
|
|
|
+ case 'Gemm': {
|
|
|
+ const transA = parseInt(attributes.get('2') || 0, 10);
|
|
|
+ const transB = parseInt(attributes.get('3') || 0, 10);
|
|
|
+ const constantA = parseInt(attributes.get('4') || 0, 10);
|
|
|
+ const constantB = parseInt(attributes.get('5') || 0, 10);
|
|
|
+ const constantC = parseInt(attributes.get('6') || 0, 10);
|
|
|
+ const M = parseInt(attributes.get('7') || 0, 10);
|
|
|
+ const N = parseInt(attributes.get('8') || 0, 10);
|
|
|
+ const K = parseInt(attributes.get('9') || 0, 10);
|
|
|
+ const constant_broadcast_type_C = parseInt(attributes.get('10') || 0, 10);
|
|
|
+ if (constantA === 1) {
|
|
|
+ weight(blobReader, 'A', transA === 0 ? [K, M] : [M, K]);
|
|
|
+ }
|
|
|
+ if (constantB === 1) {
|
|
|
+ weight(blobReader, 'B', transB === 1 ? [N, K] : [K, N]);
|
|
|
+ }
|
|
|
+ if (constantC === 1 && constant_broadcast_type_C !== -1) {
|
|
|
+ let shape = null;
|
|
|
+ switch (constant_broadcast_type_C) {
|
|
|
+ case 0: shape = [1]; break;
|
|
|
+ case 1: shape = [M]; break;
|
|
|
+ case 2: shape = [1, M]; break;
|
|
|
+ case 3: shape = [N, M]; break;
|
|
|
+ case 4: shape = [N, 1]; break;
|
|
|
+ default: break;
|
|
|
+ }
|
|
|
+ if (shape) {
|
|
|
+ weight(blobReader, 'C', shape);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ break;
|
|
|
+ }
|
|
|
default: {
|
|
|
break;
|
|
|
}
|
|
|
@@ -753,21 +785,15 @@ ncnn.BlobReader = class {
|
|
|
const f3 = this._buffer[this._position++];
|
|
|
const type = f0 | f1 << 8 | f2 << 16 | f3 << 24;
|
|
|
switch (type) {
|
|
|
- case 0x00000000:
|
|
|
- dataType = 'float32';
|
|
|
- break;
|
|
|
- case 0x01306B47:
|
|
|
- dataType = 'float16';
|
|
|
- break;
|
|
|
- case 0x000D4B38:
|
|
|
- dataType = 'int8';
|
|
|
- break;
|
|
|
- case 0x00000001:
|
|
|
- dataType = 'qint8';
|
|
|
- break;
|
|
|
+ case 0x00000000: dataType = 'float32'; break;
|
|
|
+ case 0x01306B47: dataType = 'float16'; break;
|
|
|
+ case 0x000D4B38: dataType = 'int8'; break;
|
|
|
+ case 0x00000001: dataType = 'qint8'; break;
|
|
|
case 0x0002C056: // size * sizeof(float) - raw data with extra scaling
|
|
|
- default:
|
|
|
- throw new ncnn.Error(`Unsupported weight type '${type}'.`);
|
|
|
+ default: {
|
|
|
+ const hex = (type >>> 0).toString(16).padStart(8, '0');
|
|
|
+ throw new ncnn.Error(`Unsupported weight type '${hex}'.`);
|
|
|
+ }
|
|
|
}
|
|
|
} else {
|
|
|
this._buffer = null;
|
|
|
@@ -785,27 +811,17 @@ ncnn.BlobReader = class {
|
|
|
if (this._buffer) {
|
|
|
if (dataType) {
|
|
|
const position = this._position;
|
|
|
- switch (dataType) {
|
|
|
- case 'float32':
|
|
|
- size *= 4;
|
|
|
- this._position += size;
|
|
|
- data = this._buffer.subarray(position, this._position);
|
|
|
- break;
|
|
|
- case 'float16':
|
|
|
- size *= 2;
|
|
|
- this._position += size;
|
|
|
- data = this._buffer.subarray(position, this._position);
|
|
|
- break;
|
|
|
- case 'int8':
|
|
|
- this._position += size;
|
|
|
- data = this._buffer.subarray(position, this._position);
|
|
|
- break;
|
|
|
- case 'qint8':
|
|
|
- this._position += size + 1024;
|
|
|
- data = null;
|
|
|
- break;
|
|
|
- default:
|
|
|
- throw new ncnn.Error(`Unsupported weight type '${dataType}'.`);
|
|
|
+ const dataTypes = new Map([['float32', 4], ['float16', 2], ['int8', 1], ['qint8', 1]]);
|
|
|
+ if (!dataTypes.has(dataType)) {
|
|
|
+ throw new ncnn.Error(`Unsupported weight type '${dataType}'.`);
|
|
|
+ }
|
|
|
+ size *= dataTypes.get(dataType);
|
|
|
+ if (dataType === 'qint8') {
|
|
|
+ this._position += size + 1024;
|
|
|
+ data = null;
|
|
|
+ } else {
|
|
|
+ this._position += size;
|
|
|
+ data = this._buffer.subarray(position, this._position);
|
|
|
}
|
|
|
}
|
|
|
}
|