tnn.js 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633
  // Namespace object that collects every TNN class defined in this module.
  const tnn = {};
  /**
   * Detects and opens TNN models.
   *
   * A TNN model is stored as a text `.tnnproto` graph description plus a
   * binary `.tnnmodel` resource file. Either file may be opened; the sibling
   * file is located by swapping the 9-character extension.
   */
  tnn.ModelFactory = class {
    /**
     * Checks whether `context` refers to a TNN file.
     * @param {object} context - Host file context (identifier, stream, read, set).
     * @returns {Promise<*>} `context.set('tnn.model')` for a recognised
     *   `.tnnproto`, `context.set('tnn.params')` for a recognised `.tnnmodel`,
     *   otherwise null.
     */
    async match(context) {
      const identifier = context.identifier.toLowerCase();
      const stream = context.stream;
      if (stream && identifier.endsWith('.tnnproto')) {
        try {
          // Only the first line (the quoted, comma-terminated header) is inspected.
          const reader = await context.read('text', 0x10000);
          const content = reader.read('\n');
          if (content !== undefined) {
            const line = content.trim();
            if (line.startsWith('"') && line.endsWith('"')) {
              const header = line.replace(/(^")|("$)/g, '').split(',').shift().trim().split(' ');
              // Three fields, or a fourth field holding one of the known signatures.
              if (header.length === 3 || (header.length >= 4 && (header[3] === '4206624770' || header[3] === '4206624772'))) {
                return context.set('tnn.model');
              }
            }
          }
        } catch {
          // An unreadable or non-text file simply is not a TNN text model.
        }
      }
      if (stream && identifier.endsWith('.tnnmodel')) {
        // Little-endian encodings of the 0xFABC0002 / 0xFABC0004 magic numbers.
        for (const signature of [[0x02, 0x00, 0xbc, 0xfa], [0x04, 0x00, 0xbc, 0xfa]]) {
          if (signature.length <= stream.length && stream.peek(signature.length).every((value, index) => value === signature[index])) {
            return context.set('tnn.params');
          }
        }
      }
      return null;
    }

    /**
     * Builds a tnn.Model from whichever TNN file was matched.
     * @param {object} context - Host file context; `context.type` is the value set by `match`.
     * @returns {Promise<tnn.Model>}
     * @throws {tnn.Error} For an unsupported `context.type`, or any graph parsing error.
     */
    async open(context) {
      const metadata = await context.metadata('tnn-metadata.json');
      switch (context.type) {
        case 'tnn.model': {
          const name = `${context.identifier.substring(0, context.identifier.length - 9)}.tnnmodel`;
          const reader = await context.read('text');
          let resources = null;
          try {
            const content = await context.fetch(name);
            resources = await tnn.LayerResourceReader.open(content);
          } catch {
            // The weights file is optional: show the graph without resources
            // when it is missing or cannot be parsed.
            resources = await tnn.LayerResourceReader.open(null);
          }
          // Built exactly once, outside the try: `reader` is consumed by the
          // graph parser, so graph errors must surface as-is rather than
          // triggering a retry on an exhausted reader.
          return new tnn.Model(metadata, reader, resources);
        }
        case 'tnn.params': {
          const name = `${context.identifier.substring(0, context.identifier.length - 9)}.tnnproto`;
          const content = await context.fetch(name, null);
          const reader = await content.read('text');
          const resources = await tnn.LayerResourceReader.open(context);
          return new tnn.Model(metadata, reader, resources);
        }
        default: {
          throw new tnn.Error(`Unsupported TNN format '${context.type}'.`);
        }
      }
    }
  };
  /**
   * Top-level TNN model: exposes a single graph built from the text graph
   * description and the (possibly empty) layer resources.
   */
  tnn.Model = class {
    /**
     * @param {object} metadata - Operator metadata forwarded to the graph.
     * @param {object} tnnproto - Text reader over the `.tnnproto` content.
     * @param {tnn.LayerResourceReader} resources - Parsed layer resources.
     */
    constructor(metadata, tnnproto, resources) {
      this.format = 'TNN';
      const graph = new tnn.Graph(metadata, tnnproto, resources);
      this.modules = [graph];
    }
  };
  /**
   * A TNN graph: inputs, outputs and nodes parsed from the `.tnnproto` text.
   */
  tnn.Graph = class {
    /**
     * @param {object} metadata - Operator metadata passed to each tnn.Node.
     * @param {object} tnnproto - Text reader over the `.tnnproto` content.
     * @param {tnn.LayerResourceReader} resources - Layer resources passed to each tnn.Node.
     * @throws {tnn.Error} When a named value is given a type or initializer
     *   after it already exists.
     */
    constructor(metadata, tnnproto, resources) {
      this.inputs = [];
      this.outputs = [];
      this.nodes = [];
      const reader = new tnn.TextProtoReader(tnnproto);
      reader.read('\n');
      const values = new Map();
      // Returns the shared tnn.Value for `name`, creating it on first use.
      // Values with an empty name (used for weight initializers) are never shared.
      values.map = (name, type, tensor) => {
        if (name.length === 0) {
          return new tnn.Value(name, type || null, tensor || null);
        }
        if (!values.has(name)) {
          values.set(name, new tnn.Value(name, type || null, tensor || null));
        } else if (type || tensor) {
          // Throw a real error object (previously a tnn.Value instance was thrown).
          throw new tnn.Error(`Duplicate value '${name}'.`);
        }
        return values.get(name);
      };
      for (const input of reader.inputs) {
        const shape = new tnn.TensorShape(input.shape);
        const type = new tnn.TensorType(input.data_type, shape);
        this.inputs.push(new tnn.Argument(input.name, [values.map(input.name, type)]));
      }
      for (const output of reader.outputs) {
        this.outputs.push(new tnn.Argument(output.name, [values.map(output.name)]));
      }
      for (const layer of reader.layers) {
        this.nodes.push(new tnn.Node(metadata, resources, layer, values));
      }
    }
  };
  /**
   * A named slot on a graph or node. For inputs/outputs `value` is a list of
   * tnn.Value objects; for attributes it is the attribute value itself.
   */
  tnn.Argument = class {
    /**
     * @param {string} name - Slot name.
     * @param {*} value - Values list or attribute value.
     * @param {?string} [type=null] - Attribute type, when known.
     * @param {boolean} [visible=true] - Whether the slot should be displayed.
     */
    constructor(name, value, type = null, visible = true) {
      Object.assign(this, { name, value, type, visible });
    }
  };
  /**
   * A named value flowing between nodes. When an initializer (constant
   * tensor) is supplied, its type takes precedence over `type`.
   */
  tnn.Value = class {
    /**
     * @param {string} name - Value name (may be empty).
     * @param {*} type - Tensor type used when there is no initializer.
     * @param {?object} [initializer=null] - Constant tensor backing this value.
     * @throws {tnn.Error} If `name` is not a string.
     */
    constructor(name, type, initializer = null) {
      const isValidName = typeof name === 'string';
      if (!isValidName) {
        throw new tnn.Error(`Invalid value identifier '${JSON.stringify(name)}'.`);
      }
      let resolvedType = type;
      if (initializer) {
        resolvedType = initializer.type;
      }
      this.name = name;
      this.type = resolvedType;
      this.initializer = initializer;
    }
  };
  /**
   * A single TNN layer rendered as a graph node.
   *
   * Positional layer parameters are turned into attributes using the
   * operator's metadata schema, layer inputs/outputs are bound to shared graph
   * values, and layer resources (weights) are attached as extra inputs whose
   * values carry constant tensors.
   */
  tnn.Node = class {
    /**
     * @param {object} metadata - Operator metadata; `metadata.type(name)` returns the layer schema.
     * @param {tnn.LayerResourceReader} resources - Resources looked up by layer name via `get`.
     * @param {object} layer - Parsed layer `{ type, name, inputs, outputs, params }`.
     * @param {Map} values - Graph value table whose `map(name, type, tensor)` returns tnn.Value objects.
     * @throws {tnn.Error} On an unsupported attribute type or a missing layer initializer.
     */
    constructor(metadata, resources, layer, values) {
      this.inputs = [];
      this.outputs = [];
      this.attributes = [];
      this.name = layer.name;
      this.type = { ...metadata.type(layer.type) };
      delete this.type.identifier;
      const schemas = Array.isArray(this.type.attributes) ? this.type.attributes : [];
      const entries = Array.from(layer.params);
      // `schemaIndex` advances by one per attribute, while `i` advances by the
      // number of parameter entries consumed (more than one for sized arrays).
      let schemaIndex = 0;
      for (let i = 0; i < entries.length;) {
        const schema = schemas[schemaIndex] || null;
        schemaIndex += 1;
        let name = '';
        let value = null;
        let type = '';
        let visible = true;
        if (schema && schema.type === 'int32[]' && schema.size) {
          // The array length is stored in another parameter, referenced by key.
          // A zero, negative or missing length consumes no entries (and can no
          // longer stall the loop, since the schema index still advances).
          const size = parseInt(layer.params.get(schema.size), 10);
          const count = Number.isInteger(size) && size > 0 ? size : 0;
          value = entries.slice(i, i + count).map(([, item]) => parseInt(item, 10));
          i += count;
        } else {
          [name, value] = entries[i];
          i += 1;
        }
        if (schema) {
          name = schema.name ? schema.name : name;
          type = schema.type ? schema.type : type;
          switch (type) {
            case '':
              break;
            case 'int32':
              value = parseInt(value, 10);
              break;
            case 'float32':
              value = parseFloat(value);
              break;
            case 'int32[]':
              // An unsized array parameter may arrive as a single scalar entry.
              value = (Array.isArray(value) ? value : [value]).map((item) => parseInt(item, 10));
              break;
            default:
              throw new tnn.Error(`Unsupported attribute type '${type}'.`);
          }
          // Hide attributes marked invisible or equal to their declared default.
          visible = (schema.visible === false) || (schema.default !== undefined && (value === schema.default || (value && value.toString() === schema.default.toString()))) ? false : visible;
        }
        this.attributes.push(new tnn.Argument(name, value, type, visible));
      }
      const inputs = layer.inputs;
      let inputIndex = 0;
      if (this.type && this.type.inputs) {
        for (const inputDef of this.type.inputs) {
          if (inputIndex < inputs.length || inputDef.option !== 'optional') {
            // A 'Tensor[]' input takes all remaining layer inputs.
            const inputCount = (inputDef.type === 'Tensor[]') ? (inputs.length - inputIndex) : 1;
            const inputArguments = inputs.slice(inputIndex, inputIndex + inputCount).filter((id) => id !== '' || inputDef.option !== 'optional').map((id) => values.map(id));
            this.inputs.push(new tnn.Argument(inputDef.name, inputArguments));
            inputIndex += inputCount;
          }
        }
      } else {
        this.inputs.push(...inputs.slice(inputIndex).map((input, index) => {
          const inputName = ((inputIndex + index) === 0) ? 'input' : (inputIndex + index).toString();
          return new tnn.Argument(inputName, [values.map(input)]);
        }));
      }
      const outputs = layer.outputs;
      let outputIndex = 0;
      if (this.type && this.type.outputs) {
        for (const outputDef of this.type.outputs) {
          if (outputIndex < outputs.length || outputDef.option !== 'optional') {
            // A 'variadic' output takes all remaining layer outputs.
            const outputCount = (outputDef.option === 'variadic') ? (outputs.length - outputIndex) : 1;
            const outputArguments = outputs.slice(outputIndex, outputIndex + outputCount).map((id) => values.map(id));
            this.outputs.push(new tnn.Argument(outputDef.name, outputArguments));
            outputIndex += outputCount;
          }
        }
      } else {
        this.outputs.push(...outputs.slice(outputIndex).map((output, index) => {
          const outputName = ((outputIndex + index) === 0) ? 'output' : (outputIndex + index).toString();
          return new tnn.Argument(outputName, [values.map(output)]);
        }));
      }
      // Attaches resource[name] as an extra input backed by a constant tensor of `shape`.
      const weight = (resource, name, shape) => {
        const initializer = resource[name];
        if (!initializer) {
          throw new tnn.Error(`Layer initializer '${resource.type}.${name}' not found.`);
        }
        const tensor = new tnn.Tensor(new tnn.TensorType(initializer.dataType, new tnn.TensorShape(shape)), initializer.value);
        this.inputs.push(new tnn.Argument(name, [values.map('', null, tensor)]));
      };
      const params = layer.params;
      switch (this.type.name) {
        case 'Convolution':
        case 'ConvolutionDepthWise':
        case 'Deconvolution':
        case 'DeconvolutionDepthWise': {
          const resource = resources.get(this.name);
          if (resource) {
            const num_output = parseInt(params.get('2') || 0, 10);
            const kernel_w = parseInt(params.get('3') || 0, 10);
            const kernel_h = parseInt(params.get('4') || kernel_w, 10);
            const weight_data_size = resource.filter.length;
            weight(resource, 'filter', [num_output, weight_data_size / (num_output * kernel_w * kernel_h), kernel_w, kernel_h]);
            if (resource.bias) {
              weight(resource, 'bias', [num_output]);
            }
            if (resource.quantized) {
              weight(resource, 'quantized', [num_output]);
            }
          }
          break;
        }
        case 'Conv3D': {
          const resource = resources.get(this.name);
          if (resource) {
            const num_output = parseInt(params.get('2') || 0, 10);
            const kernel_w = parseInt(params.get('3') || 0, 10);
            const kernel_h = parseInt(params.get('4') || kernel_w, 10);
            const kernel_d = parseInt(params.get('5') || kernel_w, 10);
            const weight_data_size = resource.filter.length;
            // Conv3D weights are stored under 'filter' (not 'weight'), and the
            // bias lives on this layer's resource, not on the resource reader.
            weight(resource, 'filter', [num_output, weight_data_size / (num_output * kernel_w * kernel_h * kernel_d), kernel_w, kernel_h, kernel_d]);
            if (resource.bias) {
              weight(resource, 'bias', [num_output]);
            }
          }
          break;
        }
        case 'InnerProduct': {
          const resource = resources.get(this.name);
          if (resource) {
            const num_output = parseInt(params.get('0') || 0, 10);
            const weight_data_size = resource.weight.length;
            weight(resource, 'weight', [num_output, weight_data_size / num_output]);
            weight(resource, 'bias', [num_output]);
            if (resource.weight.dataType === 'int8') {
              weight(resource, 'scale', [num_output]);
            }
          }
          break;
        }
        case 'PReLU': {
          const resource = resources.get(this.name);
          if (resource) {
            weight(resource, 'slope', [resource.slope.length]);
          }
          break;
        }
        case 'BatchNormCxx':
        case 'InstBatchNormCxx': {
          const resource = resources.get(this.name);
          if (resource) {
            weight(resource, 'scale', [resource.scale.length]);
            weight(resource, 'bias', [resource.bias.length]);
          }
          break;
        }
        case 'Div':
        case 'Sub':
        case 'Add':
        case 'Mul':
        case 'MatMul': {
          // Only a single-input binary op carries a constant second operand.
          if (this.inputs.length === 1) {
            const resource = resources.get(this.name);
            if (resource) {
              const num_output = resource.slope.length;
              weight(resource, 'slope', [num_output]);
            }
          }
          break;
        }
        case 'HdrGuide': {
          const resource = resources.get(this.name);
          if (resource) {
            const weight_size = resource.ccm_weight.length;
            weight(resource, 'ccm_weight', [weight_size]);
            weight(resource, 'ccm_bias', [weight_size]);
            weight(resource, 'shifts', [weight_size]);
            weight(resource, 'slopes', [weight_size]);
            weight(resource, 'projection_weight', [weight_size]);
            weight(resource, 'projection_bias', [weight_size]);
          }
          break;
        }
        case 'BlobScale': {
          const resource = resources.get(this.name);
          if (resource) {
            const scale_data_size = resource.scale.length;
            weight(resource, 'scale', [scale_data_size]);
            weight(resource, 'bias', [scale_data_size]);
          }
          break;
        }
        case 'Gather': {
          const resource = resources.get(this.name);
          if (resource) {
            if (resource.data) {
              weight(resource, 'data', [resource.data.length]);
            }
            if (resource.indices) {
              weight(resource, 'indices', [resource.indices.length]);
            }
          }
          break;
        }
        default: {
          break;
        }
      }
    }
  };
  /** A constant tensor: its tnn.TensorType plus the stored values. */
  tnn.Tensor = class {
    /**
     * @param {tnn.TensorType} type - Data type and shape.
     * @param {*} values - Stored tensor contents.
     */
    constructor(type, values) {
      Object.assign(this, { type, values });
    }
  };
  /** Tensor data type plus shape; a falsy data type is shown as '?'. */
  tnn.TensorType = class {
    /**
     * @param {?string} dataType - Element type name such as 'float32'.
     * @param {tnn.TensorShape} shape - Tensor shape.
     */
    constructor(dataType, shape) {
      this.dataType = dataType ? dataType : '?';
      this.shape = shape;
    }
    /** @returns {string} e.g. 'float32[1,3,224,224]'. */
    toString() {
      const shapeText = this.shape.toString();
      return `${this.dataType}${shapeText}`;
    }
  };
  /** Tensor shape; falsy dimensions (0, NaN, missing) are rendered as '?'. */
  tnn.TensorShape = class {
    /** @param {?Array<number>} dimensions - Dimension sizes, or null when unknown. */
    constructor(dimensions) {
      this.dimensions = dimensions;
    }
    /** @returns {string} e.g. '[1,3,?,?]', or '' when there are no dimensions. */
    toString() {
      if (!this.dimensions) {
        return '';
      }
      const parts = this.dimensions.map((dimension) => (dimension ? dimension.toString() : '?'));
      return `[${parts.join(',')}]`;
    }
  };
  /**
   * Parser for the `.tnnproto` text graph description.
   *
   * Every physical line is double-quoted; after removing quotes and carriage
   * returns, the joined text is split into comma-separated records:
   *   0  header: at least three space-separated fields plus an optional signature
   *   1  graph inputs, ':'-separated; each "name dims..." or, for signature
   *      4206624772, "name shape_size dims(x shape_size) data_type"
   *   2  skipped
   *   3  graph output names, space-separated
   *   4  skipped
   *   5+ one layer per record:
   *      "type name input_count output_count inputs... outputs... params..."
   */
  tnn.TextProtoReader = class {
    /**
     * @param {?object} reader - Text reader whose `read('\n')` returns the next
     *   line, or undefined at end of input.
     */
    constructor(reader) {
      this.reader = reader;
      this.inputs = [];
      this.outputs = [];
      this.layers = [];
    }

    /**
     * Parses the whole text once, filling `inputs`, `outputs` and `layers`,
     * then drops the reader so later calls are no-ops.
     * @throws {tnn.Error} On too few records, a short header or an unknown signature.
     */
    read() {
      if (!this.reader) {
        return;
      }
      const chunks = [];
      for (;;) {
        const line = this.reader.read('\n');
        if (line === undefined) {
          break;
        }
        chunks.push(line.replace(/\r|"/g, ''));
      }
      const split = (text, delimiter, trim, ignore_blank) => {
        return text.split(delimiter).map((v) => (trim ? v.trim() : v)).filter((v) => !ignore_blank || v);
      };
      const lines = split(chunks.join(''), ',', true, false);
      if (lines.length <= 5) {
        throw new tnn.Error('Invalid line count.');
      }
      const header = split(lines.shift(), ' ', true, false);
      if (header.length < 3) {
        throw new tnn.Error('Invalid header size.');
      } else if (header.length > 3 && (header[3] !== '4206624770' && header[3] !== '4206624772')) {
        throw new tnn.Error(`Invalid signature '${header[3]}'.`);
      }
      const signature = header[3];
      this.inputs = split(lines.shift(), ':', true, false).map((input) => {
        const array = split(input, ' ', true, false);
        const name = array.shift();
        if (signature === '4206624772') {
          const shape_size = parseInt(array.shift(), 10);
          const data_type_index = parseInt(array[shape_size], 10);
          return {
            name,
            data_type: ['float32', 'float16', 'int8', 'int32', 'bfloat16'][data_type_index],
            // Use the declared rank rather than "everything but the last token".
            shape: array.slice(0, shape_size).map((dim) => parseInt(dim, 10)),
          };
        }
        return {
          name,
          data_type: 'float32',
          shape: array.map((dim) => parseInt(dim, 10)),
        };
      });
      lines.shift();
      this.outputs = split(lines.shift(), ' ', true, false).map((output) => ({ name: output }));
      lines.shift();
      for (const record of lines) {
        const line = record.trim();
        if (line.length === 0) {
          continue;
        }
        const array = split(line, ' ', true, true);
        const type = array.shift();
        const name = array.shift();
        const input_count = parseInt(array.shift(), 10);
        const output_count = parseInt(array.shift(), 10);
        const inputs = array.splice(0, input_count);
        const outputs = array.splice(0, output_count);
        // Remaining tokens are positional parameters keyed by their index
        // ('0', '1', ...); tokens never contain spaces since the record was space-split.
        const params = new Map(array.map((value, index) => [index.toString(), value]));
        this.layers.push({ type, name, inputs, outputs, params });
      }
      delete this.reader;
    }
  };
  /**
   * Parser for the binary `.tnnmodel` layer-resource (weights) file.
   *
   * Layout as read below: uint32 magic (0xFABC0002 or 0xFABC0004), an int32
   * whose low 29 bits give the resource count, then per resource: int32
   * operator, string type, string name, followed by type-specific raw buffers.
   * An empty reader (null) yields no resources.
   */
  tnn.LayerResourceReader = class {
    /**
     * @param {?object} context - File context providing `read('binary')`, or null.
     * @returns {Promise<tnn.LayerResourceReader>}
     */
    static async open(context) {
      if (context) {
        const reader = await context.read('binary');
        return new tnn.LayerResourceReader(reader);
      }
      return new tnn.LayerResourceReader(null);
    }

    /**
     * @param {?object} reader - Binary reader (uint32/int32/string/read/position/length), or null.
     * @throws {tnn.Error} On a bad signature, unsupported type/data type or trailing bytes.
     */
    constructor(reader) {
      this.resources = new Map();
      if (!reader) {
        return;
      }
      this.reader = reader;
      const magic_number = this.reader.uint32();
      if (magic_number !== 0xFABC0002 && magic_number !== 0xFABC0004) {
        throw new tnn.Error(`Invalid blob header signature '${magic_number}'.`);
      }
      const size = this.reader.int32() & 0x1FFFFFFF;
      for (let i = 0; i < size; i++) {
        const resource = {};
        resource.operator = this.reader.int32();
        resource.type = this.reader.string();
        resource.name = this.reader.string();
        switch (resource.type) {
          case 'Convolution':
          case 'ConvolutionDepthWise':
          case 'Deconvolution':
          case 'DeconvolutionDepthWise': {
            this._expect(resource.name);
            const bias = this.reader.int32();
            resource.filter = this._read();
            if (bias) {
              resource.bias = this._read();
            }
            // `_read` returns null for empty buffers; guard before inspecting.
            if (resource.filter && resource.filter.dataType === 'int8') {
              resource.quantized = this._read();
            }
            break;
          }
          case 'Conv3D': {
            this._expect(resource.name);
            const bias = this.reader.int32();
            resource.filter = this._read();
            if (bias) {
              resource.bias = this._read();
            }
            break;
          }
          case 'InnerProduct': {
            this._expect(resource.name);
            resource.weight = this._read();
            resource.bias = this._read();
            if (resource.weight && resource.weight.dataType === 'int8') {
              resource.scale = this._read();
            }
            break;
          }
          case 'PReLU': {
            this._expect(resource.name);
            resource.slope = this._read();
            break;
          }
          case 'Add':
          case 'Div':
          case 'Mul':
          case 'Sub':
          case 'MatMul': {
            resource.slope = this._read();
            break;
          }
          case 'BatchNormCxx':
          case 'InstBatchNormCxx':
            resource.scale = this._read();
            resource.bias = this._read();
            break;
          case 'HdrGuide':
            resource.ccm_weight = this._read();
            resource.ccm_bias = this._read();
            resource.shifts = this._read();
            resource.slopes = this._read();
            resource.projection_weight = this._read();
            resource.projection_bias = this._read();
            break;
          case 'BlobScale':
            resource.scale = this._read();
            resource.bias = this._read();
            break;
          case 'Gather': {
            // No repeated-name check here, unlike the layer types above.
            const has_data = this.reader.int32();
            if (has_data) {
              resource.data = this._read();
            }
            const has_indices = this.reader.int32();
            if (has_indices) {
              resource.indices = this._read();
            }
            break;
          }
          default: {
            throw new tnn.Error(`Unsupported layer resource type '${resource.type}'.`);
          }
        }
        this.resources.set(resource.name, resource);
      }
      if (this.reader.position !== this.reader.length) {
        throw new tnn.Error('Invalid blob size.');
      }
      delete this.reader;
    }

    /**
     * Reads one raw buffer: magic, int32 data type, int32 byte length, optional
     * dims (0xFABC0004 only), then the bytes.
     * @returns {?{dataType: string, length: number, value: *, shape: *}} null for an empty buffer.
     * @throws {tnn.Error} On a bad signature or a data type outside the known table.
     */
    _read() {
      const magic_number = this.reader.uint32();
      if (magic_number !== 0xFABC0002 && magic_number !== 0xFABC0004) {
        throw new tnn.Error(`Invalid raw signature '${magic_number}'.`);
      }
      const data_types = ['float32', 'float16', 'int8', 'int32', 'bfloat16'];
      const item_sizes = [4, 2, 1, 4, 2];
      const data_type = this.reader.int32();
      // Reject negative indices too; they previously produced an undefined type and NaN length.
      if (data_type < 0 || data_type >= data_types.length) {
        throw new tnn.Error(`Unsupported data type '${data_type}'.`);
      }
      const length = this.reader.int32();
      if (length <= 0) {
        return null;
      }
      let dims = null;
      if (magic_number === 0xFABC0004) {
        const dim_size = this.reader.int32();
        dims = this.reader.read(dim_size * 4);
      }
      return {
        dataType: data_types[data_type],
        length: length / item_sizes[data_type],
        value: this.reader.read(length),
        shape: dims
      };
    }

    /** Reads a string and fails unless it equals `name`. */
    _expect(name) {
      const content = this.reader.string();
      if (name !== content) {
        throw new tnn.Error(`Invalid string '${content}' instead of '${name}'.`);
      }
    }

    /**
     * Looks up a layer's resources.
     * @returns {?object} null when no resources were loaded at all.
     * @throws {tnn.Error} When resources exist but none belong to `name`.
     */
    get(name) {
      if (this.resources.size === 0) {
        return null;
      }
      if (!this.resources.has(name)) {
        throw new tnn.Error(`Invalid blob layer name '${name}'.`);
      }
      return this.resources.get(name);
    }
  };
  /** Error raised for malformed or unsupported TNN files. */
  tnn.Error = class extends Error {
    /** @param {string} message - Description of the problem. */
    constructor(message) {
      super(message);
      Object.assign(this, { name: 'Error loading TNN model.' });
    }
  };
  // Module export: the factory that detects and opens TNN files.
  export const { ModelFactory } = tnn;