tnn.js 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643
  1. import * as base from './base.js';
  2. import * as text from './text.js';
  // Namespace object that collects every TNN class defined in this module.
  const tnn = {};

  /**
   * Factory that recognizes TNN files and opens them as a tnn.Model.
   *
   * A TNN model is a pair of files sharing a base name: a text '.tnnproto'
   * file and a binary '.tnnmodel' file.
   */
  tnn.ModelFactory = class {

      /**
       * Inspects the context and, when it looks like a TNN file, sets
       * context.type to 'tnn.model' (.tnnproto) or 'tnn.params' (.tnnmodel).
       * Leaves context.type untouched otherwise.
       * @param {object} context - Provides identifier, stream and receives type.
       */
      match(context) {
          const identifier = context.identifier.toLowerCase();
          const stream = context.stream;
          if (stream && identifier.endsWith('.tnnproto')) {
              try {
                  const buffer = stream.peek();
                  const reader = text.Reader.open(buffer, 2048);
                  const content = reader.read();
                  if (content !== undefined) {
                      const line = content.trim();
                      // The first line is a quoted, comma-terminated header of space-separated fields.
                      if (line.startsWith('"') && line.endsWith('"')) {
                          const header = line.replace(/(^")|("$)/g, '').split(',').shift().trim().split(' ');
                          // Header fields are strings, so the optional fourth field is compared strictly.
                          const signed = header.length >= 4 && (header[3] === '4206624770' || header[3] === '4206624772');
                          if (header.length === 3 || signed) {
                              context.type = 'tnn.model';
                              return;
                          }
                      }
                  }
              } catch (err) {
                  // continue regardless of error: detection is best-effort
              }
          }
          if (stream && identifier.endsWith('.tnnmodel')) {
              // Accepted leading byte sequences of the binary file.
              const signatures = [
                  [ 0x02, 0x00, 0xbc, 0xfa ],
                  [ 0x04, 0x00, 0xbc, 0xfa ]
              ];
              for (const signature of signatures) {
                  if (signature.length <= stream.length && stream.peek(signature.length).every((value, index) => value === signature[index])) {
                      context.type = 'tnn.params';
                      return;
                  }
              }
          }
      }

      /**
       * Opens the matched file together with its companion file.
       * @param {object} context - Context previously tagged by match().
       * @returns {Promise<tnn.Model>} The loaded model.
       * @throws {tnn.Error} When context.type is not a TNN type.
       */
      async open(context) {
          const metadata = await context.metadata('tnn-metadata.json');
          // Both extensions are 9 characters long, so the same cut yields the base name.
          const basename = context.identifier.substring(0, context.identifier.length - 9);
          switch (context.type) {
              case 'tnn.model': {
                  const name = `${basename}.tnnmodel`;
                  try {
                      const content = await context.fetch(name);
                      const buffer = content.stream.peek();
                      return new tnn.Model(metadata, context.stream.peek(), buffer);
                  } catch (error) {
                      // The weights file is optional: fall back to the structure only.
                      return new tnn.Model(metadata, context.stream.peek(), null);
                  }
              }
              case 'tnn.params': {
                  const name = `${basename}.tnnproto`;
                  const content = await context.fetch(name, null);
                  const buffer = content.stream.peek();
                  return new tnn.Model(metadata, buffer, context.stream.peek());
              }
              default: {
                  throw new tnn.Error(`Unsupported TNN format '${context.type}'.`);
              }
          }
      }
  };
  /**
   * Top-level model: reports the 'TNN' format and holds a single graph
   * built from the text structure and the optional binary weights.
   */
  tnn.Model = class {

      /**
       * @param {object} metadata - Operator metadata passed through to the graph.
       * @param {*} tnnproto - Content of the .tnnproto file.
       * @param {*} tnnmodel - Content of the .tnnmodel file, or null.
       */
      constructor(metadata, tnnproto, tnnmodel) {
          this.format = 'TNN';
          const graph = new tnn.Graph(metadata, tnnproto, tnnmodel);
          this.graphs = [ graph ];
      }
  };
  /**
   * Graph assembled from the parsed .tnnproto text (inputs, outputs, layers)
   * and the layer resources read from the .tnnmodel buffer.
   */
  tnn.Graph = class {

      /**
       * @param {object} metadata - Operator metadata used when creating nodes.
       * @param {*} tnnproto - Content of the .tnnproto file.
       * @param {*} tnnmodel - Content of the .tnnmodel file, or null.
       * @throws {tnn.Error} When a named value is declared twice with a type or tensor.
       */
      constructor(metadata, tnnproto, tnnmodel) {
          this.inputs = [];
          this.outputs = [];
          this.nodes = [];
          const resources = new tnn.LayerResourceReader(tnnmodel);
          const reader = new tnn.TextProtoReader(tnnproto);
          // Registry of named values so every reference to a name shares one tnn.Value.
          const values = new Map();
          values.map = (name, type, tensor) => {
              // Unnamed values (e.g. initializers) are never shared.
              if (name.length === 0) {
                  return new tnn.Value(name, type || null, tensor || null);
              }
              if (!values.has(name)) {
                  values.set(name, new tnn.Value(name, type || null, tensor || null));
              } else if (type || tensor) {
                  // Raise a real error here; the original threw a tnn.Value instance by mistake.
                  throw new tnn.Error(`Duplicate value '${name}'.`);
              }
              return values.get(name);
          };
          for (const input of reader.inputs) {
              const shape = new tnn.TensorShape(input.shape);
              const type = new tnn.TensorType(input.data_type, shape);
              this.inputs.push(new tnn.Argument(input.name, [ values.map(input.name, type) ]));
          }
          for (const output of reader.outputs) {
              this.outputs.push(new tnn.Argument(output.name, [ values.map(output.name) ]));
          }
          for (const layer of reader.layers) {
              this.nodes.push(new tnn.Node(metadata, resources, layer, values));
          }
      }
  };
  /**
   * Named slot of a graph or node, pairing a name with its value list.
   */
  tnn.Argument = class {

      /**
       * @param {string} name - Slot name.
       * @param {Array} value - Values bound to the slot.
       */
      constructor(name, value) {
          Object.assign(this, { name: name, value: value });
      }
  };
  /**
   * A value flowing through the graph, optionally typed and optionally
   * backed by an initializer tensor.
   */
  tnn.Value = class {

      /**
       * @param {string} name - Value identifier; must be a string.
       * @param {object} [type] - Type of the value, stored as null when falsy.
       * @param {object} [initializer] - Tensor content, stored as null when falsy.
       * @throws {tnn.Error} When name is not a string.
       */
      constructor(name, type, initializer) {
          const isString = typeof name === 'string';
          if (!isString) {
              throw new tnn.Error(`Invalid value identifier '${JSON.stringify(name)}'.`);
          }
          this._name = name;
          this._type = type ? type : null;
          this._initializer = initializer ? initializer : null;
      }

      get name() {
          return this._name;
      }

      // The initializer's own type takes precedence over the declared type.
      get type() {
          return this._initializer ? this._initializer.type : this._type;
      }

      get initializer() {
          return this._initializer;
      }
  };
  /**
   * A graph node built from one parsed layer: resolves its operator type,
   * attributes, inputs and outputs, and appends the weights read from the
   * layer resources as extra inputs.
   */
  tnn.Node = class {

      /**
       * @param {object} metadata - Operator metadata; metadata.type(name) returns the schema.
       * @param {tnn.LayerResourceReader} resources - Sequential reader of layer resources.
       * @param {object} layer - Parsed layer (type, name, inputs, outputs, attr, attributes).
       * @param {Map} values - Value registry exposing values.map(name, type, tensor).
       * @throws {tnn.Error} When an expected initializer is missing from a resource.
       */
      constructor(metadata, resources, layer, values) {
          this.inputs = [];
          this.outputs = [];
          this.attributes = [];
          this.name = layer.name;
          // Fall back to a bare type so that operators without metadata do not crash below.
          this.type = metadata.type(layer.type) || { name: layer.type };
          const attributeSchemas = this.type.attributes ? this.type.attributes.slice() : [];
          const attributes = layer.attributes.slice();
          while (attributes.length > 0) {
              const attributeSchema = attributeSchemas.shift();
              let value = null;
              let name = '';
              if (attributeSchema && attributeSchema.type === 'int32[]' && attributeSchema.size) {
                  // Array attributes span several columns; the count is stored in the attribute named by 'size'.
                  name = attributeSchema.name;
                  value = attributes.splice(0, layer.attr[attributeSchema.size]).map((attribute) => parseInt(attribute.value, 10));
              } else {
                  const attribute = attributes.shift();
                  name = attribute.key;
                  value = attribute.value;
              }
              this.attributes.push(new tnn.Attribute(attributeSchema, name, value));
          }
          const inputs = layer.inputs;
          let inputIndex = 0;
          if (this.type.inputs) {
              for (const inputDef of this.type.inputs) {
                  if (inputIndex < inputs.length || inputDef.option !== 'optional') {
                      // A variadic input consumes every remaining identifier.
                      const inputCount = (inputDef.option === 'variadic') ? (inputs.length - inputIndex) : 1;
                      const inputArguments = inputs.slice(inputIndex, inputIndex + inputCount).filter((id) => id !== '' || inputDef.option !== 'optional').map((id) => values.map(id));
                      this.inputs.push(new tnn.Argument(inputDef.name, inputArguments));
                      inputIndex += inputCount;
                  }
              }
          } else {
              this.inputs.push(...inputs.slice(inputIndex).map((input, index) => {
                  const inputName = ((inputIndex + index) === 0) ? 'input' : (inputIndex + index).toString();
                  return new tnn.Argument(inputName, [ values.map(input) ]);
              }));
          }
          const outputs = layer.outputs;
          let outputIndex = 0;
          if (this.type.outputs) {
              for (const outputDef of this.type.outputs) {
                  if (outputIndex < outputs.length || outputDef.option !== 'optional') {
                      const outputCount = (outputDef.option === 'variadic') ? (outputs.length - outputIndex) : 1;
                      const outputArguments = outputs.slice(outputIndex, outputIndex + outputCount).map((id) => values.map(id));
                      this.outputs.push(new tnn.Argument(outputDef.name, outputArguments));
                      outputIndex += outputCount;
                  }
              }
          } else {
              this.outputs.push(...outputs.slice(outputIndex).map((output, index) => {
                  const outputName = ((outputIndex + index) === 0) ? 'output' : (outputIndex + index).toString();
                  return new tnn.Argument(outputName, [ values.map(output) ]);
              }));
          }
          // Appends resource[name] as an initializer input with the given shape.
          const weight = (resource, name, shape) => {
              const initializer = resource[name];
              if (!initializer) {
                  throw new tnn.Error(`Layer initializer '${resource.type}.${name}' not found.`);
              }
              const tensor = new tnn.Tensor(new tnn.TensorType(initializer.dataType, new tnn.TensorShape(shape)), initializer.value);
              this.inputs.push(new tnn.Argument(name, [ values.map('', null, tensor) ]));
          };
          switch (this.type.name) {
              case 'Convolution':
              case 'ConvolutionDepthWise':
              case 'Deconvolution':
              case 'DeconvolutionDepthWise': {
                  const resource = resources.read(this.name);
                  if (resource) {
                      const num_output = parseInt(layer.attr['2'] || 0, 10);
                      const kernel_w = parseInt(layer.attr['3'] || 0, 10);
                      const kernel_h = parseInt(layer.attr['4'] || kernel_w, 10);
                      const weight_data_size = resource.filter.length;
                      weight(resource, 'filter', [ num_output, weight_data_size / (num_output * kernel_w * kernel_h), kernel_w, kernel_h ]);
                      if (resource.bias) {
                          weight(resource, 'bias', [ num_output ]);
                      }
                      if (resource.quantized) {
                          weight(resource, 'quantized', [ num_output ]);
                      }
                  }
                  break;
              }
              case 'Conv3D': {
                  const resource = resources.read(this.name);
                  if (resource) {
                      const num_output = parseInt(layer.attr['2'] || 0, 10);
                      const kernel_w = parseInt(layer.attr['3'] || 0, 10);
                      const kernel_h = parseInt(layer.attr['4'] || kernel_w, 10);
                      const kernel_d = parseInt(layer.attr['5'] || kernel_w, 10);
                      const weight_data_size = resource.filter.length;
                      // The resource reader stores the Conv3D kernel under 'filter'.
                      weight(resource, 'filter', [ num_output, weight_data_size / (num_output * kernel_w * kernel_h * kernel_d), kernel_w, kernel_h, kernel_d ]);
                      if (resource.bias) {
                          weight(resource, 'bias', [ num_output ]);
                      }
                  }
                  break;
              }
              case 'InnerProduct': {
                  const resource = resources.read(this.name);
                  if (resource) {
                      const num_output = parseInt(layer.attr['0'] || 0, 10);
                      const weight_data_size = resource.weight.length;
                      weight(resource, 'weight', [ num_output, weight_data_size / num_output ]);
                      weight(resource, 'bias', [ num_output ]);
                      if (resource.weight.dataType === 'int8') {
                          weight(resource, 'scale', [ num_output ]);
                      }
                  }
                  break;
              }
              case 'PReLU': {
                  const resource = resources.read(this.name);
                  if (resource) {
                      weight(resource, 'slope', [ resource.slope.length ]);
                  }
                  break;
              }
              case 'BatchNormCxx':
              case 'InstBatchNormCxx': {
                  const resource = resources.read(this.name);
                  if (resource) {
                      weight(resource, 'scale', [ resource.scale.length ]);
                      weight(resource, 'bias', [ resource.bias.length ]);
                  }
                  break;
              }
              case 'Div':
              case 'Sub':
              case 'Add':
              case 'Mul':
              case 'MatMul': {
                  // A resource is only consumed when the layer has a single graph input.
                  if (this.inputs.length === 1) {
                      const resource = resources.read(this.name);
                      if (resource) {
                          const num_output = resource.slope.length;
                          weight(resource, 'slope', [ num_output ]);
                      }
                  }
                  break;
              }
              case 'HdrGuide': {
                  const resource = resources.read(this.name);
                  if (resource) {
                      const weight_size = resource.ccm_weight.length;
                      weight(resource, 'ccm_weight', [ weight_size ]);
                      weight(resource, 'ccm_bias', [ weight_size ]);
                      weight(resource, 'shifts', [ weight_size ]);
                      weight(resource, 'slopes', [ weight_size ]);
                      weight(resource, 'projection_weight', [ weight_size ]);
                      weight(resource, 'projection_bias', [ weight_size ]);
                  }
                  break;
              }
              case 'BlobScale': {
                  const resource = resources.read(this.name);
                  if (resource) {
                      const scale_data_size = resource.scale.length;
                      weight(resource, 'scale', [ scale_data_size ]);
                      weight(resource, 'bias', [ scale_data_size ]);
                  }
                  break;
              }
              case 'Gather': {
                  const resource = resources.read(this.name);
                  if (resource) {
                      if (resource.data) {
                          weight(resource, 'data', [ resource.data.length ]);
                      }
                      if (resource.indices) {
                          weight(resource, 'indices', [ resource.indices.length ]);
                      }
                  }
                  break;
              }
              default: {
                  break;
              }
          }
      }
  };
  /**
   * Node attribute. Without a schema the raw key and value are kept; with a
   * schema the name and type come from it, the value is converted to that
   * type, and the attribute may be flagged as not visible.
   */
  tnn.Attribute = class {

      /**
       * @param {object} [metadata] - Attribute schema (name, type, visible, default).
       * @param {*} key - Fallback name, converted with toString().
       * @param {*} value - Raw value (string or array).
       * @throws {tnn.Error} When the schema declares an unsupported type.
       */
      constructor(metadata, key, value) {
          this.type = '';
          this.name = key.toString();
          this.value = value;
          if (!metadata) {
              return;
          }
          this.name = metadata.name;
          if (metadata.type) {
              this.type = metadata.type;
          }
          const toInteger = (item) => parseInt(item, 10);
          const toFloat = (item) => parseFloat(item);
          // Conversion applied per declared type; the empty type keeps the value as is.
          const converters = new Map([
              [ '', (current) => current ],
              [ 'int32', toInteger ],
              [ 'float32', toFloat ],
              [ 'int32[]', (items) => items.map((item) => toInteger(item)) ],
              [ 'float32[]', (items) => items.map((item) => toFloat(item)) ]
          ]);
          if (!converters.has(this.type)) {
              throw new tnn.Error(`Unsupported attribute type '${this.type}'.`);
          }
          this.value = converters.get(this.type)(this.value);
          if (metadata.visible === false) {
              this.visible = false;
              return;
          }
          if (!Object.prototype.hasOwnProperty.call(metadata, 'default')) {
              return;
          }
          // Hide attributes left at their default; loose comparison is intentional here.
          const sameValue = this.value == metadata.default;
          if (sameValue || (this.value && this.value.toString() == metadata.default.toString())) {
              this.visible = false;
          }
      }
  };
  /**
   * Tensor content: a type descriptor together with its stored values.
   */
  tnn.Tensor = class {

      /**
       * @param {tnn.TensorType} type - Data type and shape of the tensor.
       * @param {*} values - Tensor data as provided by the caller.
       */
      constructor(type, values) {
          [ this.type, this.values ] = [ type, values ];
      }
  };
  /**
   * Tensor type: an element data type plus a shape.
   */
  tnn.TensorType = class {

      /**
       * @param {string} [dataType] - Element type name; '?' is used when falsy.
       * @param {tnn.TensorShape} shape - Shape of the tensor.
       */
      constructor(dataType, shape) {
          this.dataType = dataType ? dataType : '?';
          this.shape = shape;
      }

      /** @returns {string} The data type immediately followed by the shape text. */
      toString() {
          const suffix = this.shape.toString();
          return this.dataType + suffix;
      }
  };
  /**
   * Tensor shape given as a list of dimensions.
   */
  tnn.TensorShape = class {

      /**
       * @param {Array} [dimensions] - Dimension sizes.
       */
      constructor(dimensions) {
          this.dimensions = dimensions;
      }

      /**
       * @returns {string} '[d0,d1,...]' with '?' for falsy dimensions, or ''
       * when there are no dimensions.
       */
      toString() {
          if (!this.dimensions) {
              return '';
          }
          const parts = this.dimensions.map((dimension) => {
              if (dimension) {
                  return dimension.toString();
              }
              return '?';
          });
          return `[${parts.join(',')}]`;
      }
  };
  /**
   * Parser for the .tnnproto text format. Exposes:
   *   inputs  - [{ name, data_type, shape }]
   *   outputs - [{ name }]
   *   layers  - [{ type, name, inputs, outputs, attr, attributes }]
   */
  tnn.TextProtoReader = class {

      /**
       * @param {*} buffer - Content of the .tnnproto file, handed to text.Reader.open.
       * @throws {tnn.Error} On too few records, a short header or an unknown signature.
       */
      constructor(buffer) {
          const reader = text.Reader.open(buffer);
          const rows = [];
          for (let line = reader.read(); line !== undefined; line = reader.read()) {
              // Drop carriage returns and the quotes wrapping every line.
              rows.push(line.replace(/\r|"/g, ''));
          }
          const split = (line, delimiter, trim, ignore_blank) => {
              return line.split(delimiter).map((v) => trim ? v.trim() : v).filter((v) => !ignore_blank || v);
          };
          // Records are comma separated, independent of the physical line breaks.
          const lines = split(rows.join(''), ',', true, false);
          if (lines.length <= 5) {
              throw new tnn.Error('Invalid line count.');
          }
          const header = split(lines.shift(), ' ', true, false);
          if (header.length < 3) {
              throw new tnn.Error('Invalid header size.');
          } else if (header.length > 3 && (header[3] !== '4206624770' && header[3] !== '4206624772')) {
              throw new tnn.Error(`Invalid signature '${header[3]}'.`);
          }
          this.inputs = split(lines.shift(), ':', true, false).map((input) => {
              const array = split(input, ' ', true, false);
              const name = array.shift();
              if (header[3] === '4206624772') {
                  // This signature stores: dimension count, dimensions, then a data type index.
                  const shape_size = parseInt(array.shift(), 10);
                  const data_type_index = parseInt(array[shape_size], 10);
                  return {
                      name: name,
                      data_type: [ 'float32', 'float16', 'int8', 'int32', 'bfloat16' ][data_type_index],
                      shape: array.slice(0, -1).map((dim) => parseInt(dim, 10)),
                  };
              }
              return {
                  name: name,
                  data_type: 'float32',
                  shape: array.map((dim) => parseInt(dim, 10))
              };
          });
          // The record between inputs and outputs is not used by this reader.
          lines.shift();
          this.outputs = split(lines.shift(), ' ', true, false).map((output) => {
              return { name: output };
          });
          // The record between outputs and the first layer is not used either.
          lines.shift();
          this.layers = [];
          while (lines.length > 0) {
              const line = lines.shift().trim();
              if (line.length > 0) {
                  const array = split(line, ' ', true, true);
                  const layer = {};
                  layer.type = array.shift();
                  layer.name = array.shift();
                  const inputCount = parseInt(array.shift(), 10);
                  const outputCount = parseInt(array.shift(), 10);
                  layer.inputs = array.splice(0, inputCount);
                  layer.outputs = array.splice(0, outputCount);
                  layer.attr = {};
                  layer.attributes = [];
                  // Remaining columns are positional attributes keyed by their index.
                  // Columns never contain spaces and indexes are never negative, so
                  // no further splitting or key remapping is required.
                  array.forEach((value, key) => {
                      layer.attr[key] = value;
                      layer.attributes.push({ key: key, value: value });
                  });
                  this.layers.push(layer);
              }
          }
      }
  };
  /**
   * Reader for the binary .tnnmodel content. Decodes every layer resource up
   * front and hands them out one at a time, in file order, through read().
   */
  tnn.LayerResourceReader = class {

      /**
       * @param {*} [buffer] - Binary content; when falsy the reader stays empty.
       * @throws {tnn.Error} On bad signatures, unsupported data or resource types,
       * mismatching name strings, or trailing bytes.
       */
      constructor(buffer) {
          this.layerResources = [];
          if (!buffer) {
              return;
          }
          const reader = new base.BinaryReader(buffer);
          const isSignature = (value) => value === 0xFABC0002 || value === 0xFABC0004;
          const header = reader.uint32();
          if (!isSignature(header)) {
              throw new tnn.Error(`Invalid blob header signature '${header}'.`);
          }
          const count = reader.int32() & 0x1FFFFFFF;
          this.layerResources = new Array(count);
          const typeNames = [ 'float32', 'float16', 'int8', 'int32', 'bfloat16' ];
          const itemSizes = [ 4, 2, 1, 4, 2 ];
          // Reads one raw buffer record; returns null when its byte length is not positive.
          const raw = () => {
              const signature = reader.uint32();
              if (!isSignature(signature)) {
                  throw new tnn.Error(`Invalid raw signature '${signature}'.`);
              }
              const code = reader.int32();
              if (code > 4) {
                  throw new tnn.Error(`Unsupported data type '${code}'.`);
              }
              const size = reader.int32();
              if (size <= 0) {
                  return null;
              }
              let dims = null;
              if (signature === 0xFABC0004) {
                  // This signature additionally stores a dimension block before the data.
                  const rank = reader.int32();
                  dims = reader.read(rank * 4);
              }
              const value = reader.read(size);
              return {
                  dataType: typeNames[code],
                  length: size / itemSizes[code],
                  value: value,
                  shape: dims
              };
          };
          // Consumes a string and verifies that it equals the expected name.
          const expect = (name) => {
              const content = reader.string();
              if (name !== content) {
                  throw new tnn.Error(`Invalid string '${content}' instead of '${name}'.`);
              }
          };
          // Reads consecutive raw buffers into the listed fields, in order.
          const fill = (target, fields) => {
              for (const field of fields) {
                  target[field] = raw();
              }
          };
          for (let index = 0; index < count; index++) {
              const entry = {};
              entry.operator = reader.int32();
              entry.type = reader.string();
              entry.name = reader.string();
              switch (entry.type) {
                  case 'Convolution':
                  case 'ConvolutionDepthWise':
                  case 'Deconvolution':
                  case 'DeconvolutionDepthWise': {
                      expect(entry.name);
                      const hasBias = reader.int32();
                      entry.filter = raw();
                      if (hasBias) {
                          entry.bias = raw();
                      }
                      if (entry.filter.dataType === 'int8') {
                          entry.quantized = raw();
                      }
                      break;
                  }
                  case 'Conv3D': {
                      expect(entry.name);
                      const hasBias = reader.int32();
                      entry.filter = raw();
                      if (hasBias) {
                          entry.bias = raw();
                      }
                      break;
                  }
                  case 'InnerProduct': {
                      expect(entry.name);
                      fill(entry, [ 'weight', 'bias' ]);
                      if (entry.weight.dataType === 'int8') {
                          entry.scale = raw();
                      }
                      break;
                  }
                  case 'PReLU': {
                      expect(entry.name);
                      entry.slope = raw();
                      break;
                  }
                  case 'Add':
                  case 'Div':
                  case 'Mul':
                  case 'Sub':
                  case 'MatMul': {
                      entry.slope = raw();
                      break;
                  }
                  case 'BatchNormCxx':
                  case 'InstBatchNormCxx':
                  case 'BlobScale': {
                      fill(entry, [ 'scale', 'bias' ]);
                      break;
                  }
                  case 'HdrGuide': {
                      fill(entry, [ 'ccm_weight', 'ccm_bias', 'shifts', 'slopes', 'projection_weight', 'projection_bias' ]);
                      break;
                  }
                  case 'Gather': {
                      // No name string is checked for this type; each buffer is preceded by a presence flag.
                      const hasData = reader.int32();
                      if (hasData) {
                          entry.data = raw();
                      }
                      const hasIndices = reader.int32();
                      if (hasIndices) {
                          entry.indices = raw();
                      }
                      break;
                  }
                  default: {
                      throw new tnn.Error(`Unsupported layer resource type '${entry.type}'.`);
                  }
              }
              this.layerResources[index] = entry;
          }
          // Every byte must have been consumed.
          if (reader.position !== reader.length) {
              throw new tnn.Error("Invalid blob size.");
          }
      }

      /**
       * Removes and returns the next resource.
       * @param {string} name - Layer name the next resource is expected to carry.
       * @returns {object|undefined} The resource, or undefined when none are left.
       * @throws {tnn.Error} When the next resource belongs to another layer.
       */
      read(name) {
          const next = this.layerResources.shift();
          const mismatch = next && next.name !== name;
          if (mismatch) {
              throw new tnn.Error(`Invalid blob layer name '${name}'.`);
          }
          return next;
      }
  };
  /**
   * Error type raised by the TNN loader; carries a fixed, user-facing name.
   */
  tnn.Error = class extends Error {

      /**
       * @param {string} message - Description of the failure.
       */
      constructor(message) {
          super(message);
          const label = 'Error loading TNN model.';
          this.name = label;
      }
  };
  604. export const ModelFactory = tnn.ModelFactory;