| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633 |
// Namespace object that every TNN loader class in this module is attached to.
const tnn = {};

/**
 * Recognizes and opens TNN models.
 *
 * A model is handled as a pair of sibling files sharing a base name: a
 * '.tnnproto' text file and a '.tnnmodel' binary file.
 */
tnn.ModelFactory = class {

    /**
     * Inspects the context and tags it when it looks like a TNN file.
     *
     * @param context Host context exposing `identifier`, `stream`, `read()` and `set()`.
     * @returns The result of `context.set('tnn.model')` or `context.set('tnn.params')`
     *          on a match, otherwise `null`.
     */
    async match(context) {
        const identifier = context.identifier.toLowerCase();
        const stream = context.stream;
        if (stream && identifier.endsWith('.tnnproto')) {
            try {
                const reader = await context.read('text', 0x10000);
                const content = reader.read('\n');
                if (content !== undefined) {
                    const line = content.trim();
                    if (line.startsWith('"') && line.endsWith('"')) {
                        // First comma-separated field of the quoted line, split on spaces.
                        const header = line.replace(/(^")|("$)/g, '').split(',').shift().trim().split(' ');
                        // 4206624770 / 4206624772 are 0xFABC0002 / 0xFABC0004 written in decimal.
                        if (header.length === 3 || (header.length >= 4 && (header[3] === '4206624770' || header[3] === '4206624772'))) {
                            return context.set('tnn.model');
                        }
                    }
                }
            } catch {
                // Detection is best-effort: an unreadable file simply does not match.
            }
        }
        if (stream && identifier.endsWith('.tnnmodel')) {
            for (const signature of [[0x02, 0x00, 0xbc, 0xfa], [0x04, 0x00, 0xbc, 0xfa]]) {
                if (signature.length <= stream.length && stream.peek(signature.length).every((value, index) => value === signature[index])) {
                    return context.set('tnn.params');
                }
            }
        }
        return null;
    }

    /**
     * Builds a `tnn.Model` from a context previously tagged by `match()`.
     *
     * For 'tnn.model' the sibling '.tnnmodel' is optional: if it cannot be
     * fetched or parsed, the model is built with an empty resource reader.
     * Only resource loading is guarded. The model is constructed exactly once,
     * because its constructor drains the text reader; retrying with the same
     * reader could only fail again and would hide the original error.
     *
     * @param context Host context exposing `identifier`, `type`, `metadata()`, `read()` and `fetch()`.
     * @returns {Promise<tnn.Model>}
     * @throws {tnn.Error} When `context.type` is not a TNN type.
     */
    async open(context) {
        const metadata = await context.metadata('tnn-metadata.json');
        // Both '.tnnproto' and '.tnnmodel' are 9 characters long.
        const base = context.identifier.substring(0, context.identifier.length - 9);
        switch (context.type) {
            case 'tnn.model': {
                const reader = await context.read('text');
                let resources = null;
                try {
                    const content = await context.fetch(`${base}.tnnmodel`);
                    resources = await tnn.LayerResourceReader.open(content);
                } catch {
                    // Weights are optional; fall back to a reader without resources.
                    resources = await tnn.LayerResourceReader.open(null);
                }
                return new tnn.Model(metadata, reader, resources);
            }
            case 'tnn.params': {
                const content = await context.fetch(`${base}.tnnproto`, null);
                const reader = await content.read('text');
                const resources = await tnn.LayerResourceReader.open(context);
                return new tnn.Model(metadata, reader, resources);
            }
            default: {
                throw new tnn.Error(`Unsupported TNN format '${context.type}'.`);
            }
        }
    }
};
/**
 * Top-level model object: a format label plus a single graph.
 */
tnn.Model = class {

    /**
     * @param metadata  Type metadata forwarded to the graph.
     * @param tnnproto  Text reader over the '.tnnproto' content.
     * @param resources Layer resource reader forwarded to the graph.
     */
    constructor(metadata, tnnproto, resources) {
        const label = 'TNN';
        this.format = label;
        const graph = new tnn.Graph(metadata, tnnproto, resources);
        this.modules = [graph];
    }
};
/**
 * A graph built from a parsed '.tnnproto': graph inputs, graph outputs and
 * one node per layer.
 */
tnn.Graph = class {

    /**
     * @param metadata  Type metadata passed through to every node.
     * @param tnnproto  Text reader handed to `tnn.TextProtoReader`.
     * @param resources Layer resource reader passed through to every node.
     * @throws {tnn.Error} When a named value is declared twice with a type or tensor.
     */
    constructor(metadata, tnnproto, resources) {
        this.inputs = [];
        this.outputs = [];
        this.nodes = [];
        const reader = new tnn.TextProtoReader(tnnproto);
        reader.read('\n');
        // Registry of named values, so every reference to a name shares one tnn.Value.
        const values = new Map();
        values.map = (name, type, tensor) => {
            // Unnamed values are never shared: each call gets a fresh instance.
            if (name.length === 0) {
                return new tnn.Value(name, type || null, tensor || null);
            }
            if (!values.has(name)) {
                values.set(name, new tnn.Value(name, type || null, tensor || null));
            } else if (type || tensor) {
                // Redefining an existing name with a type/tensor is an error
                // (this used to throw a tnn.Value instance instead of an Error).
                throw new tnn.Error(`Duplicate value '${name}'.`);
            }
            return values.get(name);
        };
        for (const input of reader.inputs) {
            const shape = new tnn.TensorShape(input.shape);
            const type = new tnn.TensorType(input.data_type, shape);
            const argument = new tnn.Argument(input.name, [values.map(input.name, type)]);
            this.inputs.push(argument);
        }
        for (const output of reader.outputs) {
            const argument = new tnn.Argument(output.name, [values.map(output.name)]);
            this.outputs.push(argument);
        }
        for (const layer of reader.layers) {
            const node = new tnn.Node(metadata, resources, layer, values);
            this.nodes.push(node);
        }
    }
};
/**
 * A named slot on a graph or node (input, output or attribute).
 */
tnn.Argument = class {

    /**
     * @param name    Slot name.
     * @param value   Slot payload, stored as given.
     * @param type    Optional type label; defaults to null.
     * @param visible Optional visibility flag; defaults to true.
     */
    constructor(name, value, type = null, visible = true) {
        Object.assign(this, { name, value, type, visible });
    }
};
/**
 * A value flowing through the graph, optionally backed by an initializer tensor.
 */
tnn.Value = class {

    /**
     * @param name        Value identifier; must be a string.
     * @param type        Type used when there is no initializer.
     * @param initializer Optional tensor; when present its own type wins.
     * @throws {tnn.Error} When `name` is not a string.
     */
    constructor(name, type, initializer = null) {
        const isString = typeof name === 'string';
        if (!isString) {
            throw new tnn.Error(`Invalid value identifier '${JSON.stringify(name)}'.`);
        }
        this.name = name;
        if (initializer) {
            this.type = initializer.type;
        } else {
            this.type = type;
        }
        this.initializer = initializer;
    }
};
/**
 * A graph node built from one parsed layer: its attributes, its input and
 * output arguments, and any weight tensors found for it in the layer resources.
 */
tnn.Node = class {

    /**
     * @param metadata  Object whose `type(name)` returns the type description for a layer type.
     * @param resources Object whose `get(name)` returns the layer resource (or a falsy value).
     * @param layer     Parsed layer with `name`, `type`, `inputs`, `outputs` and a `params` Map.
     * @param values    Value registry exposing `map(name, type, tensor)`.
     * @throws {tnn.Error} On an unsupported attribute type or a missing layer initializer.
     */
    constructor(metadata, resources, layer, values) {
        this.inputs = [];
        this.outputs = [];
        this.attributes = [];
        this.name = layer.name;
        // Work on a shallow copy so the shared metadata entry is not modified.
        this.type = { ...metadata.type(layer.type) };
        delete this.type.identifier;

        // Attributes: params are positional; the schema entry at the same index describes them.
        const entries = Array.from(layer.params);
        for (let i = 0; i < entries.length;) {
            const schema = this.type && Array.isArray(this.type.attributes) ? this.type.attributes[i] : null;
            let name = '';
            let value = null;
            let type = '';
            let visible = true;
            if (schema && schema.type === 'int32[]' && schema.size) {
                // The element count is stored in another param, named by `schema.size`.
                const size = parseInt(layer.params.get(schema.size), 10);
                const count = size > 0 ? size : 0;
                value = entries.slice(i, i + count).map(([, item]) => parseInt(item, 10));
                // A missing, zero or negative count cannot advance the cursor; stop
                // after this attribute instead of looping forever on the same index.
                i = count > 0 ? i + count : entries.length;
            } else {
                [name, value] = entries[i];
                i += 1;
            }
            if (schema) {
                name = schema.name ? schema.name : name;
                type = schema.type ? schema.type : type;
                switch (type) {
                    case '':
                        break;
                    case 'int32':
                        value = parseInt(value, 10);
                        break;
                    case 'float32':
                        value = parseFloat(value);
                        break;
                    case 'int32[]':
                        value = value.map((v) => parseInt(v, 10));
                        break;
                    default:
                        throw new tnn.Error(`Unsupported attribute type '${type}'.`);
                }
                // Hide attributes flagged invisible or equal to their declared default.
                const isDefault = schema.default !== undefined && (value === schema.default || (value && value.toString() === schema.default.toString()));
                if (schema.visible === false || isDefault) {
                    visible = false;
                }
            }
            const argument = new tnn.Argument(name, value, type, visible);
            this.attributes.push(argument);
        }

        // Inputs: follow the declared input list when there is one, else name them positionally.
        const inputs = layer.inputs;
        let inputIndex = 0;
        if (this.type && this.type.inputs) {
            for (const inputDef of this.type.inputs) {
                if (inputIndex < inputs.length || inputDef.option !== 'optional') {
                    // A 'Tensor[]' input absorbs every remaining input.
                    const inputCount = (inputDef.type === 'Tensor[]') ? (inputs.length - inputIndex) : 1;
                    const inputArguments = inputs.slice(inputIndex, inputIndex + inputCount).filter((id) => id !== '' || inputDef.option !== 'optional').map((id) => values.map(id));
                    const argument = new tnn.Argument(inputDef.name, inputArguments);
                    this.inputs.push(argument);
                    inputIndex += inputCount;
                }
            }
        } else {
            this.inputs.push(...inputs.slice(inputIndex).map((input, index) => {
                const inputName = ((inputIndex + index) === 0) ? 'input' : (inputIndex + index).toString();
                return new tnn.Argument(inputName, [values.map(input)]);
            }));
        }

        // Outputs: same scheme; a 'variadic' output absorbs every remaining output.
        const outputs = layer.outputs;
        let outputIndex = 0;
        if (this.type && this.type.outputs) {
            for (const outputDef of this.type.outputs) {
                if (outputIndex < outputs.length || outputDef.option !== 'optional') {
                    const outputCount = (outputDef.option === 'variadic') ? (outputs.length - outputIndex) : 1;
                    const outputArguments = outputs.slice(outputIndex, outputIndex + outputCount).map((id) => values.map(id));
                    const argument = new tnn.Argument(outputDef.name, outputArguments);
                    this.outputs.push(argument);
                    outputIndex += outputCount;
                }
            }
        } else {
            this.outputs.push(...outputs.slice(outputIndex).map((output, index) => {
                const outputName = ((outputIndex + index) === 0) ? 'output' : (outputIndex + index).toString();
                return new tnn.Argument(outputName, [values.map(output)]);
            }));
        }

        // Appends the resource entry `name` as an extra input carrying an initializer tensor.
        const weight = (resource, name, shape) => {
            const initializer = resource[name];
            if (!initializer) {
                throw new tnn.Error(`Layer initializer '${resource.type}.${name}' not found.`);
            }
            const tensor = new tnn.Tensor(new tnn.TensorType(initializer.dataType, new tnn.TensorShape(shape)), initializer.value);
            const argument = new tnn.Argument(name, [values.map('', null, tensor)]);
            this.inputs.push(argument);
        };

        const params = layer.params;
        switch (this.type.name) {
            case 'Convolution':
            case 'ConvolutionDepthWise':
            case 'Deconvolution':
            case 'DeconvolutionDepthWise': {
                const resource = resources.get(this.name);
                if (resource) {
                    const num_output = parseInt(params.get('2') || 0, 10);
                    const kernel_w = parseInt(params.get('3') || 0, 10);
                    const kernel_h = parseInt(params.get('4') || kernel_w, 10);
                    const weight_data_size = resource.filter.length;
                    weight(resource, 'filter', [num_output, weight_data_size / (num_output * kernel_w * kernel_h), kernel_w, kernel_h]);
                    if (resource.bias) {
                        weight(resource, 'bias', [num_output]);
                    }
                    if (resource.quantized) {
                        weight(resource, 'quantized', [num_output]);
                    }
                }
                break;
            }
            case 'Conv3D': {
                const resource = resources.get(this.name);
                if (resource) {
                    const num_output = parseInt(params.get('2') || 0, 10);
                    const kernel_w = parseInt(params.get('3') || 0, 10);
                    const kernel_h = parseInt(params.get('4') || kernel_w, 10);
                    const kernel_d = parseInt(params.get('5') || kernel_w, 10);
                    const weight_data_size = resource.filter.length;
                    // Conv3D resources store their weights under 'filter' (as sized above),
                    // and the bias lives on the same layer resource.
                    weight(resource, 'filter', [num_output, weight_data_size / (num_output * kernel_w * kernel_h * kernel_d), kernel_w, kernel_h, kernel_d]);
                    if (resource.bias) {
                        weight(resource, 'bias', [num_output]);
                    }
                }
                break;
            }
            case 'InnerProduct': {
                const resource = resources.get(this.name);
                if (resource) {
                    const num_output = parseInt(params.get('0') || 0, 10);
                    const weight_data_size = resource.weight.length;
                    weight(resource, 'weight', [num_output, weight_data_size / num_output]);
                    weight(resource, 'bias', [num_output]);
                    if (resource.weight.dataType === 'int8') {
                        weight(resource, 'scale', [num_output]);
                    }
                }
                break;
            }
            case 'PReLU': {
                const resource = resources.get(this.name);
                if (resource) {
                    weight(resource, 'slope', [resource.slope.length]);
                }
                break;
            }
            case 'BatchNormCxx':
            case 'InstBatchNormCxx': {
                const resource = resources.get(this.name);
                if (resource) {
                    weight(resource, 'scale', [resource.scale.length]);
                    weight(resource, 'bias', [resource.bias.length]);
                }
                break;
            }
            case 'Div':
            case 'Sub':
            case 'Add':
            case 'Mul':
            case 'MatMul': {
                // Only consult the resources when the node has a single input argument.
                if (this.inputs.length === 1) {
                    const resource = resources.get(this.name);
                    if (resource) {
                        const num_output = resource.slope.length;
                        weight(resource, 'slope', [num_output]);
                    }
                }
                break;
            }
            case 'HdrGuide': {
                const resource = resources.get(this.name);
                if (resource) {
                    const weight_size = resource.ccm_weight.length;
                    weight(resource, 'ccm_weight', [weight_size]);
                    weight(resource, 'ccm_bias', [weight_size]);
                    weight(resource, 'shifts', [weight_size]);
                    weight(resource, 'slopes', [weight_size]);
                    weight(resource, 'projection_weight', [weight_size]);
                    weight(resource, 'projection_bias', [weight_size]);
                }
                break;
            }
            case 'BlobScale': {
                const resource = resources.get(this.name);
                if (resource) {
                    const scale_data_size = resource.scale.length;
                    weight(resource, 'scale', [scale_data_size]);
                    weight(resource, 'bias', [scale_data_size]);
                }
                break;
            }
            case 'Gather': {
                const resource = resources.get(this.name);
                if (resource) {
                    if (resource.data) {
                        weight(resource, 'data', [resource.data.length]);
                    }
                    if (resource.indices) {
                        weight(resource, 'indices', [resource.indices.length]);
                    }
                }
                break;
            }
            default: {
                break;
            }
        }
    }
};
/**
 * A tensor: a type descriptor paired with its raw values.
 */
tnn.Tensor = class {

    /**
     * @param type   Tensor type descriptor, stored as given.
     * @param values Tensor payload, stored as given.
     */
    constructor(type, values) {
        Object.assign(this, { type, values });
    }
};
/**
 * Tensor type: an element data type plus a shape.
 */
tnn.TensorType = class {

    /**
     * @param dataType Element type name; a falsy value is stored as '?'.
     * @param shape    Shape object, stored as given.
     */
    constructor(dataType, shape) {
        this.dataType = dataType ? dataType : '?';
        this.shape = shape;
    }

    /**
     * @returns The data type immediately followed by the shape's string form.
     */
    toString() {
        const { dataType, shape } = this;
        return dataType + shape.toString();
    }
};
/**
 * Tensor shape: a list of dimensions.
 */
tnn.TensorShape = class {

    /**
     * @param dimensions Array of dimensions (may be falsy when unknown).
     */
    constructor(dimensions) {
        this.dimensions = dimensions;
    }

    /**
     * @returns '[d0,d1,...]' with every falsy dimension printed as '?',
     *          or an empty string when there are no dimensions at all.
     */
    toString() {
        if (!this.dimensions) {
            return '';
        }
        const labels = this.dimensions.map((dimension) => {
            if (dimension) {
                return dimension.toString();
            }
            return '?';
        });
        return `[${labels.join(',')}]`;
    }
};
/**
 * Parses the text of a '.tnnproto' file into `inputs`, `outputs` and `layers`.
 *
 * The text is treated as one stream of comma-separated segments (quotes,
 * carriage returns and line breaks are dropped): a header, the input list,
 * one skipped segment, the output list, another skipped segment, and then
 * one segment per layer.
 */
tnn.TextProtoReader = class {

    /**
     * @param reader Line reader exposing `read('\n')`, returning `undefined` at
     *               the end; may be falsy, in which case `read()` does nothing.
     */
    constructor(reader) {
        this.reader = reader;
        this.inputs = [];
        this.outputs = [];
        this.layers = [];
    }

    /**
     * Consumes the whole reader and fills `inputs`, `outputs` and `layers`.
     * The reader reference is dropped afterwards.
     *
     * @throws {tnn.Error} When there are too few segments or the header is invalid.
     */
    read() {
        if (!this.reader) {
            return;
        }
        let lines = [];
        for (;;) {
            const line = this.reader.read('\n');
            if (line === undefined) {
                break;
            }
            lines.push(line.replace(/\r|"/g, ''));
        }
        const split = (line, delimiter, trim, ignore_blank) => {
            return line.split(delimiter).map((v) => trim ? v.trim() : v).filter((v) => !ignore_blank || v);
        };
        // Re-split on commas: segments, not physical lines, are the unit of the format.
        lines = split(lines.join(''), ',', true, false);
        if (lines.length <= 5) {
            throw new tnn.Error('Invalid line count.');
        }
        const header = split(lines.shift(), ' ', true, false);
        if (header.length < 3) {
            throw new tnn.Error('Invalid header size.');
        } else if (header.length > 3 && (header[3] !== '4206624770' && header[3] !== '4206624772')) {
            throw new tnn.Error(`Invalid signature '${header[3]}'.`);
        }
        // Inputs are separated by ':'; each is "name dims..." or, for signature
        // 4206624772, "name dim_count dims... data_type_index".
        this.inputs = split(lines.shift(), ':', true, false).map((input) => {
            const array = split(input, ' ', true, false);
            const name = array.shift();
            if (header[3] === '4206624772') {
                const shape_size = parseInt(array.shift(), 10);
                const data_type_index = parseInt(array[shape_size], 10);
                return {
                    name,
                    data_type: ['float32', 'float16', 'int8', 'int32', 'bfloat16'][data_type_index],
                    shape: array.slice(0, -1).map((dim) => parseInt(dim, 10)),
                };
            }
            return {
                name,
                data_type: 'float32',
                shape: array.map((dim) => parseInt(dim, 10))
            };
        });
        lines.shift();
        this.outputs = split(lines.shift(), ' ', true, false).map((output) => {
            return { name: output };
        });
        lines.shift();
        // Every remaining non-empty segment is
        // "type name input_count output_count inputs... outputs... params...".
        while (lines.length > 0) {
            const line = lines.shift().trim();
            if (line.length > 0) {
                const array = split(line, ' ', true, true);
                const layer = {};
                layer.type = array.shift();
                layer.name = array.shift();
                const inputs = parseInt(array.shift(), 10);
                const outputs = parseInt(array.shift(), 10);
                layer.inputs = array.splice(0, inputs);
                layer.outputs = array.splice(0, outputs);
                // Whatever is left are positional params, keyed by their index as a string.
                layer.params = new Map(array.map((value, index) => [index.toString(), value]));
                this.layers.push(layer);
            }
        }
        delete this.reader;
    }
};
/**
 * Parses a '.tnnmodel' binary blob into a map of per-layer resources
 * (weights, biases, ...) keyed by layer name.
 */
tnn.LayerResourceReader = class {

    /**
     * Creates a reader from a context, or an empty reader when `context` is falsy.
     *
     * @param context Object exposing `read('binary')`, or a falsy value.
     * @returns {Promise<tnn.LayerResourceReader>}
     */
    static async open(context) {
        if (context) {
            const reader = await context.read('binary');
            return new tnn.LayerResourceReader(reader);
        }
        return new tnn.LayerResourceReader(null);
    }

    /**
     * @param reader Binary reader exposing `uint32()`, `int32()`, `string()`, `read(n)`,
     *               `position` and `length`; a falsy value yields an empty resource map.
     * @throws {tnn.Error} On a bad signature, an unsupported resource type, or when the
     *                     blob is not fully consumed.
     */
    constructor(reader) {
        this.resources = new Map();
        if (reader) {
            this.reader = reader;
            const magic_number = this.reader.uint32();
            if (magic_number !== 0xFABC0002 && magic_number !== 0xFABC0004) {
                throw new tnn.Error(`Invalid blob header signature '${magic_number}'.`);
            }
            // Only the low 29 bits of this field are used as the resource count.
            const size = this.reader.int32() & 0x1FFFFFFF;
            for (let i = 0; i < size; i++) {
                const resource = {};
                resource.operator = this.reader.int32();
                resource.type = this.reader.string();
                resource.name = this.reader.string();
                switch (resource.type) {
                    case 'Convolution':
                    case 'ConvolutionDepthWise':
                    case 'Deconvolution':
                    case 'DeconvolutionDepthWise': {
                        this._expect(resource.name);
                        const bias = this.reader.int32();
                        resource.filter = this._read();
                        if (bias) {
                            resource.bias = this._read();
                        }
                        if (resource.filter.dataType === 'int8') {
                            resource.quantized = this._read();
                        }
                        break;
                    }
                    case 'Conv3D': {
                        this._expect(resource.name);
                        const bias = this.reader.int32();
                        resource.filter = this._read();
                        if (bias) {
                            resource.bias = this._read();
                        }
                        break;
                    }
                    case 'InnerProduct': {
                        this._expect(resource.name);
                        resource.weight = this._read();
                        resource.bias = this._read();
                        if (resource.weight.dataType === 'int8') {
                            resource.scale = this._read();
                        }
                        break;
                    }
                    case 'PReLU': {
                        this._expect(resource.name);
                        resource.slope = this._read();
                        break;
                    }
                    case 'Add':
                    case 'Div':
                    case 'Mul':
                    case 'Sub':
                    case 'MatMul': {
                        resource.slope = this._read();
                        break;
                    }
                    case 'BatchNormCxx':
                    case 'InstBatchNormCxx':
                        resource.scale = this._read();
                        resource.bias = this._read();
                        break;
                    case 'HdrGuide':
                        resource.ccm_weight = this._read();
                        resource.ccm_bias = this._read();
                        resource.shifts = this._read();
                        resource.slopes = this._read();
                        resource.projection_weight = this._read();
                        resource.projection_bias = this._read();
                        break;
                    case 'BlobScale':
                        resource.scale = this._read();
                        resource.bias = this._read();
                        break;
                    case 'Gather': {
                        // Unlike the convolution-style entries, no repeated layer name is checked here.
                        const has_data = this.reader.int32();
                        if (has_data) {
                            resource.data = this._read();
                        }
                        const has_indices = this.reader.int32();
                        if (has_indices) {
                            resource.indices = this._read();
                        }
                        break;
                    }
                    default: {
                        throw new tnn.Error(`Unsupported layer resource type '${resource.type}'.`);
                    }
                }
                this.resources.set(resource.name, resource);
            }
            // Trailing bytes mean the blob was not parsed as intended.
            if (this.reader.position !== this.reader.length) {
                throw new tnn.Error("Invalid blob size.");
            }
            delete this.reader;
        }
    }

    /**
     * Reads one raw buffer record: signature, data type index, byte length,
     * an optional dimension block (0xFABC0004 only) and the payload.
     *
     * @returns `{ dataType, length, value, shape }` where `length` is the element
     *          count, or `null` when the record's byte length is not positive.
     * @throws {tnn.Error} On a bad signature or a data type index outside 0..4.
     */
    _read() {
        const magic_number = this.reader.uint32();
        if (magic_number !== 0xFABC0002 && magic_number !== 0xFABC0004) {
            throw new tnn.Error(`Invalid raw signature '${magic_number}'.`);
        }
        const data_type = this.reader.int32();
        // The index selects from five-entry tables below; reject anything outside
        // them (negative values used to slip through as an undefined type).
        if (data_type < 0 || data_type > 4) {
            throw new tnn.Error(`Unsupported data type '${data_type}'.`);
        }
        const length = this.reader.int32();
        if (length <= 0) {
            return null;
        }
        let dims = null;
        if (magic_number === 0xFABC0004) {
            const dim_size = this.reader.int32();
            // Kept as raw bytes: four bytes per dimension.
            dims = this.reader.read(dim_size * 4);
        }
        return {
            dataType: ['float32', 'float16', 'int8', 'int32', 'bfloat16'][data_type],
            // Element count = byte length / bytes per element of the data type.
            length: length / [4, 2, 1, 4, 2][data_type],
            value: this.reader.read(length),
            shape: dims
        };
    }

    /**
     * Reads a string and verifies that it equals `name`.
     *
     * @throws {tnn.Error} When the string read differs from `name`.
     */
    _expect(name) {
        const content = this.reader.string();
        if (name !== content) {
            throw new tnn.Error(`Invalid string '${content}' instead of '${name}'.`);
        }
    }

    /**
     * Looks up the resource of a layer.
     *
     * @param name Layer name.
     * @returns The resource, or `null` when no resources were loaded at all.
     * @throws {tnn.Error} When resources were loaded but none matches `name`.
     */
    get(name) {
        if (this.resources.size === 0) {
            return null;
        }
        if (!this.resources.has(name)) {
            throw new tnn.Error(`Invalid blob layer name '${name}'.`);
        }
        return this.resources.get(name);
    }
};
/**
 * Error type raised by the TNN loader; its `name` is a fixed
 * human-readable label rather than the class name.
 */
tnn.Error = class extends Error {

    /**
     * @param message Description of the failure.
     */
    constructor(message) {
        super(message);
        const label = 'Error loading TNN model.';
        this.name = label;
    }
};
// Public entry point of the module: the factory that detects and opens TNN models.
export const { ModelFactory } = tnn;
|