tnn.js 32 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923
  /* jshint esversion: 6 */

  // Namespaces: keep any `tnn` / `base` binding that already exists in scope;
  // otherwise start an empty `tnn` namespace and load `base` via CommonJS.
  var tnn = tnn ? tnn : {};
  var base = base ? base : require('./base');
  /**
   * Recognises and opens TNN models. A TNN model is a pair of files sharing a
   * base name: a text `.tnnproto` (graph topology) and a binary `.tnnmodel`
   * (layer weights).
   */
  tnn.ModelFactory = class {

      /**
       * Reports whether the opened file looks like part of a TNN model.
       *
       * `.tnnproto`: the first line read from the stream must be wrapped in
       * double quotes, and its first comma-separated field must hold exactly
       * 3 space-separated tokens, or 4+ tokens whose 4th is a known format
       * signature. Any error while reading the text counts as "no match".
       *
       * `.tnnmodel`: the stream must begin with one of the two blob magic
       * numbers (0xFABC0002 / 0xFABC0004 stored little-endian).
       *
       * @param {object} context - host context exposing `identifier` and `stream`.
       * @returns {boolean}
       */
      match(context) {
          const identifier = context.identifier.toLowerCase();
          if (identifier.endsWith('.tnnproto')) {
              try {
                  const text = base.TextReader.open(context.stream.peek(), 2048).read();
                  if (text !== undefined) {
                      const content = text.trim();
                      if (content.startsWith('"') && content.endsWith('"')) {
                          const tokens = content.replace(/(^")|("$)/g, '').split(',').shift().trim().split(' ');
                          const knownSignature = tokens[3] === '4206624770' || tokens[3] === '4206624772';
                          if (tokens.length === 3 || (tokens.length >= 4 && knownSignature)) {
                              return true;
                          }
                      }
                  }
              }
              catch (err) {
                  // Unreadable text is simply not a match.
              }
          }
          if (identifier.endsWith('.tnnmodel')) {
              const stream = context.stream;
              const signatures = [ [ 0x02, 0x00, 0xbc, 0xfa ], [ 0x04, 0x00, 0xbc, 0xfa ] ];
              const startsWith = (signature) => signature.length <= stream.length &&
                  stream.peek(signature.length).every((value, index) => value === signature[index]);
              return signatures.some(startsWith);
          }
          return false;
      }

      /**
       * Loads the model, requesting the sibling file (same base name, other
       * extension) from the host.
       *
       * Opened via `.tnnproto`: if the `.tnnmodel` request fails, or building
       * the model with it throws, the model is rebuilt without weights.
       * Opened via `.tnnmodel`: the `.tnnproto` is required and any failure
       * propagates to the caller.
       *
       * @param {object} context - host context (`identifier`, `stream`, `request`).
       * @returns {Promise<tnn.Model|undefined>}
       */
      open(context) {
          return tnn.Metadata.open(context).then((metadata) => {
              const identifier = context.identifier.toLowerCase();
              // '.tnnproto' and '.tnnmodel' are both 9 characters long.
              const basename = context.identifier.substring(0, context.identifier.length - 9);
              if (identifier.endsWith('.tnnproto')) {
                  const build = (weights) => new tnn.Model(metadata, context.stream.peek(), weights);
                  return context.request(basename + '.tnnmodel', null)
                      .then((stream) => build(stream.peek()))
                      .catch(() => build(null));
              }
              if (identifier.endsWith('.tnnmodel')) {
                  return context.request(basename + '.tnnproto', null).then((stream) => {
                      return new tnn.Model(metadata, stream.peek(), context.stream.peek());
                  });
              }
              return undefined;
          });
      }
  };
  /**
   * A loaded TNN model. It always exposes exactly one graph.
   */
  tnn.Model = class {

      /**
       * @param {tnn.Metadata} metadata - operator schema lookup.
       * @param {Uint8Array} tnnproto - byte array of the `.tnnproto` file.
       * @param {?Uint8Array} tnnmodel - byte array of the `.tnnmodel` file, or null.
       */
      constructor(metadata, tnnproto, tnnmodel) {
          const graph = new tnn.Graph(metadata, tnnproto, tnnmodel);
          this._graphs = [ graph ];
      }

      /** @returns {string} the format label. */
      get format() {
          return 'TNN';
      }

      /** @returns {tnn.Graph[]} */
      get graphs() {
          return this._graphs;
      }
  };
  /**
   * The graph of a TNN model. It is built from the parsed `.tnnproto`
   * topology and the (optional) `.tnnmodel` weights.
   */
  tnn.Graph = class {

      /**
       * @param {tnn.Metadata} metadata - operator schema lookup.
       * @param {Uint8Array} tnnproto - byte array of the `.tnnproto` file.
       * @param {?Uint8Array} tnnmodel - byte array of the `.tnnmodel` file, or null.
       */
      constructor(metadata, tnnproto, tnnmodel) {
          const resources = new tnn.LayerResourceReader(tnnmodel);
          const proto = new tnn.TextProtoReader(tnnproto);
          this._inputs = proto.inputs.map((input) => {
              const shape = new tnn.TensorShape(input.shape);
              const type = new tnn.TensorType(input.data_type, shape);
              return new tnn.Parameter(input.name, [ new tnn.Argument(input.name, type, null) ]);
          });
          this._outputs = proto.outputs.map((output) => {
              return new tnn.Parameter(output.name, [ new tnn.Argument(output.name, null, null) ]);
          });
          // Nodes are created strictly in layer order because the resource
          // reader hands out weight records sequentially.
          this._nodes = proto.layers.map((layer) => new tnn.Node(metadata, resources, layer));
      }

      /** @returns {tnn.Parameter[]} graph inputs, typed from the proto header. */
      get inputs() {
          return this._inputs;
      }

      /** @returns {tnn.Parameter[]} graph outputs (untyped). */
      get outputs() {
          return this._outputs;
      }

      /** @returns {tnn.Node[]} */
      get nodes() {
          return this._nodes;
      }
  };
  /**
   * A named slot on a graph or node that holds a list of arguments.
   */
  tnn.Parameter = class {

      /**
       * @param {string} name - slot name.
       * @param {tnn.Argument[]} args - arguments bound to the slot.
       */
      constructor(name, args) {
          this._name = name;
          this._arguments = args;
      }

      get name() { return this._name; }

      /** Parameters are always shown. */
      get visible() { return true; }

      get arguments() { return this._arguments; }
  };
  /**
   * A value flowing along an edge of the graph, optionally backed by a
   * constant initializer tensor.
   */
  tnn.Argument = class {

      /**
       * @param {string} name - identifier of the value; must be a string.
       * @param {?tnn.TensorType} type - declared type, if known.
       * @param {?tnn.Tensor} initializer - constant data, if any.
       * @throws {tnn.Error} when `name` is not a string.
       */
      constructor(name, type, initializer) {
          if (typeof name === 'string') {
              this._name = name;
              this._type = type || null;
              this._initializer = initializer || null;
              return;
          }
          throw new tnn.Error("Invalid argument identifier '" + JSON.stringify(name) + "'.");
      }

      get name() {
          return this._name;
      }

      /** An initializer's own type takes precedence over the declared type. */
      get type() {
          return this._initializer ? this._initializer.type : this._type;
      }

      get initializer() {
          return this._initializer;
      }
  };
  /**
   * One layer of a TNN graph: its schema type, attributes, input/output
   * parameters, and any weight tensors read from the `.tnnmodel` resources.
   */
  tnn.Node = class {

      /**
       * @param {tnn.Metadata} metadata - schema lookup via `operator(code)` and `type(name)`.
       * @param {tnn.LayerResourceReader} resources - weight reader. `read(name)` hands out
       *     records in file order and throws if the next one belongs to another layer.
       * @param {object} layer - a layer parsed by tnn.TextProtoReader
       *     (`type`, `name`, `inputs`, `outputs`, `attr`, `attributes`).
       * @throws {tnn.Error} when a required weight record is missing or out of order.
       */
      constructor(metadata, resources, layer) {
          this._inputs = [];
          this._outputs = [];
          this._attributes = [];
          this._name = layer.name;
          let type = layer.type;
          const operator = metadata.operator(type);
          if (operator) {
              // tnn.Metadata.operator() yields a schema item rather than a
              // name. Resolve it to its type name so that both the schema
              // lookup and the weight switch below see a string.
              type = typeof operator === 'string' ? operator : (operator.name || type);
          }
          this._type = metadata.type(type) || { name: type };
          this._readAttributes(layer);
          this._inputs.push(...this._parameters(this._type.inputs, layer.inputs, 'input', true));
          this._outputs.push(...this._parameters(this._type.outputs, layer.outputs, 'output', false));
          // Weights are appended as extra inputs, after the graph inputs.
          this._readWeights(type, layer, resources);
      }

      get type() {
          return this._type;
      }

      get name() {
          return this._name;
      }

      get attributes() {
          return this._attributes;
      }

      get inputs() {
          return this._inputs;
      }

      get outputs() {
          return this._outputs;
      }

      // Pairs the layer's positional attribute values with the schema's
      // attribute list, in order. An `int32[]` schema entry with a `size`
      // consumes a variable number of values; the count is taken from the
      // layer attribute at index `size`. Once the schema runs out, each
      // remaining value becomes an attribute named by its position.
      _readAttributes(layer) {
          const schemas = this._type.attributes ? this._type.attributes.slice() : [];
          const values = layer.attributes.slice();
          while (values.length > 0) {
              const schema = schemas.shift();
              if (schema && schema.type === 'int32[]' && schema.size) {
                  const items = values.splice(0, layer.attr[schema.size]).map((attribute) => parseInt(attribute.value, 10));
                  this._attributes.push(new tnn.Attribute(schema, schema.name, items));
              }
              else {
                  const attribute = values.shift();
                  this._attributes.push(new tnn.Attribute(schema, attribute.key, attribute.value));
              }
          }
      }

      // Builds the parameters for one side (inputs or outputs) of the node.
      // With schema definitions, each definition takes one id, or all
      // remaining ids when `variadic`. An `optional` definition is skipped once
      // the ids are exhausted. When `dropEmptyOptional` is set, empty ids are
      // removed from optional definitions. Without definitions, each id gets
      // its own parameter named `fallback` (first) or by its position.
      _parameters(definitions, ids, fallback, dropEmptyOptional) {
          const parameters = [];
          if (definitions) {
              let index = 0;
              for (const definition of definitions) {
                  const optional = definition.option === 'optional';
                  if (index < ids.length || !optional) {
                      const count = definition.option === 'variadic' ? ids.length - index : 1;
                      let selected = ids.slice(index, index + count);
                      if (dropEmptyOptional) {
                          selected = selected.filter((id) => id !== '' || !optional);
                      }
                      parameters.push(new tnn.Parameter(definition.name, selected.map((id) => new tnn.Argument(id, null, null))));
                      index += count;
                  }
              }
          }
          else {
              ids.forEach((id, index) => {
                  const name = index === 0 ? fallback : index.toString();
                  parameters.push(new tnn.Parameter(name, [ new tnn.Argument(id, null, null) ]));
              });
          }
          return parameters;
      }

      // Reads this layer's weight record (for layer types that own one) and
      // exposes each tensor as an extra input. Shapes are derived from the
      // positional layer attributes and the record's element counts.
      _readWeights(type, layer, resources) {
          switch (type) {
              case 'Convolution':
              case 'ConvolutionDepthWise':
              case 'Deconvolution':
              case 'DeconvolutionDepthWise': {
                  const resource = resources.read(this._name);
                  if (resource) {
                      const num_output = parseInt(layer.attr['2'] || 0, 10);
                      const kernel_w = parseInt(layer.attr['3'] || 0, 10);
                      const kernel_h = parseInt(layer.attr['4'] || kernel_w, 10);
                      const weight_data_size = resource.filter.length;
                      this._weight(resource, 'filter', [ num_output, weight_data_size / (num_output * kernel_w * kernel_h), kernel_w, kernel_h ]);
                      if (resource.bias) {
                          this._weight(resource, 'bias', [ num_output ]);
                      }
                      if (resource.quantized) {
                          this._weight(resource, 'quantized', [ num_output ]);
                      }
                  }
                  break;
              }
              case 'Conv3D': {
                  const resource = resources.read(this._name);
                  if (resource) {
                      const num_output = parseInt(layer.attr['2'] || 0, 10);
                      const kernel_w = parseInt(layer.attr['3'] || 0, 10);
                      const kernel_h = parseInt(layer.attr['4'] || kernel_w, 10);
                      const kernel_d = parseInt(layer.attr['5'] || kernel_w, 10);
                      const weight_data_size = resource.filter.length;
                      // The resource reader stores 3D kernels under `filter`, like 2D ones.
                      this._weight(resource, 'filter', [ num_output, weight_data_size / (num_output * kernel_w * kernel_h * kernel_d), kernel_w, kernel_h, kernel_d ]);
                      if (resource.bias) {
                          this._weight(resource, 'bias', [ num_output ]);
                      }
                  }
                  break;
              }
              case 'InnerProduct': {
                  const resource = resources.read(this._name);
                  if (resource) {
                      const num_output = parseInt(layer.attr['0'] || 0, 10);
                      const weight_data_size = resource.weight.length;
                      this._weight(resource, 'weight', [ num_output, weight_data_size / num_output ]);
                      // The reader yields null for an empty bias record; treat it as absent.
                      if (resource.bias) {
                          this._weight(resource, 'bias', [ num_output ]);
                      }
                      if (resource.weight.dataType === 'int8') {
                          this._weight(resource, 'scale', [ num_output ]);
                      }
                  }
                  break;
              }
              case 'PReLU': {
                  const resource = resources.read(this._name);
                  if (resource) {
                      this._weight(resource, 'slope', [ resource.slope.length ]);
                  }
                  break;
              }
              case 'BatchNormCxx':
              case 'InstBatchNormCxx': {
                  const resource = resources.read(this._name);
                  if (resource) {
                      this._weight(resource, 'scale', [ resource.scale.length ]);
                      this._weight(resource, 'bias', [ resource.bias.length ]);
                  }
                  break;
              }
              case 'Div':
              case 'Sub':
              case 'Add':
              case 'Mul':
              case 'MatMul': {
                  // Only single-input binary layers carry a constant operand record.
                  if (this._inputs.length === 1) {
                      const resource = resources.read(this._name);
                      if (resource) {
                          this._weight(resource, 'slope', [ resource.slope.length ]);
                      }
                  }
                  break;
              }
              case 'HdrGuide': {
                  const resource = resources.read(this._name);
                  if (resource) {
                      const weight_size = resource.ccm_weight.length;
                      for (const name of [ 'ccm_weight', 'ccm_bias', 'shifts', 'slopes', 'projection_weight', 'projection_bias' ]) {
                          this._weight(resource, name, [ weight_size ]);
                      }
                  }
                  break;
              }
              case 'BlobScale': {
                  const resource = resources.read(this._name);
                  if (resource) {
                      const scale_data_size = resource.scale.length;
                      this._weight(resource, 'scale', [ scale_data_size ]);
                      this._weight(resource, 'bias', [ scale_data_size ]);
                  }
                  break;
              }
              case 'Gather': {
                  const resource = resources.read(this._name);
                  if (resource) {
                      if (resource.data) {
                          this._weight(resource, 'data', [ resource.data.length ]);
                      }
                      if (resource.indices) {
                          this._weight(resource, 'indices', [ resource.indices.length ]);
                      }
                  }
                  break;
              }
              default:
                  break;
          }
      }

      /**
       * Appends `resource[name]` as a constant input named `name`.
       * @param {object} resource - a record from tnn.LayerResourceReader.
       * @param {string} name - key of the raw buffer inside `resource`.
       * @param {number[]} shape - dimensions to display for the tensor.
       * @throws {tnn.Error} when the record has no such buffer.
       */
      _weight(resource, name, shape) {
          const initializer = resource[name];
          if (!initializer) {
              throw new tnn.Error("Layer initializer '" + resource.type + "." + name + "' not found.");
          }
          const type = new tnn.TensorType(initializer.dataType, new tnn.TensorShape(shape));
          const tensor = new tnn.Tensor(type, initializer.value);
          this._inputs.push(new tnn.Parameter(name, [ new tnn.Argument('', null, tensor) ]));
      }
  };
  /**
   * A layer attribute. When a schema entry is available, it supplies the
   * name and type, drives numeric conversion, and decides visibility.
   */
  tnn.Attribute = class {

      /**
       * @param {?object} schema - metadata attribute entry (`name`, `type`,
       *     `visible`, `default`), or null/undefined for unknown attributes.
       * @param {string|number} key - positional key, used as the name when there is no schema.
       * @param {*} value - raw value (a string, or an array for list attributes).
       */
      constructor(schema, key, value) {
          this._type = '';
          this._name = key.toString();
          this._value = value;
          if (!schema) {
              return;
          }
          this._name = schema.name;
          if (schema.type) {
              this._type = schema.type;
          }
          switch (this._type) {
              case 'int32':
                  this._value = parseInt(this._value, 10);
                  break;
              case 'float32':
                  this._value = parseFloat(this._value);
                  break;
              case 'float32[]':
                  // Only list values can be converted element-wise; a scalar
                  // value is kept as-is rather than throwing on `.map`.
                  if (Array.isArray(this._value)) {
                      this._value = this._value.map((item) => parseFloat(item));
                  }
                  break;
              default:
                  break;
          }
          if (Object.prototype.hasOwnProperty.call(schema, 'visible') && !schema.visible) {
              this._visible = false;
          }
          else if (Object.prototype.hasOwnProperty.call(schema, 'default')) {
              const defaultValue = schema.default;
              const hasDefault = defaultValue !== null && defaultValue !== undefined;
              // Loose comparison on purpose: untyped values arrive as strings
              // from the proto text and may need to match non-string defaults.
              // Lists are compared by their textual form.
              if (this._value == defaultValue ||
                  (this._value && hasDefault && this._value.toString() == defaultValue.toString())) {
                  this._visible = false;
              }
          }
      }

      get type() {
          return this._type;
      }

      get name() {
          return this._name;
      }

      get value() {
          return this._value;
      }

      /** Hidden when the schema says so or when the value equals the schema default. */
      get visible() {
          return this._visible !== false;
      }
  };
  /**
   * A constant weight tensor backed by raw little-endian bytes.
   * Decoding supports float32, float16, int8 and int32 data. Any other data
   * type reports a state instead of a value.
   */
  tnn.Tensor = class {

      /**
       * @param {tnn.TensorType} type - data type and shape.
       * @param {?Uint8Array} data - typed-array view over the raw bytes, or null.
       */
      constructor(type, data) {
          this._type = type;
          this._data = data;
      }

      get kind() {
          return 'Weight';
      }

      get type() {
          return this._type;
      }

      /** @returns {?string} why the tensor cannot be decoded, or null. */
      get state() {
          return this._context().state || null;
      }

      /** @returns {?Array} the fully decoded values, or null when undecodable. */
      get value() {
          const context = this._context();
          if (context.state) {
              return null;
          }
          context.limit = Number.MAX_SAFE_INTEGER;
          return this._decode(context, 0);
      }

      /** @returns {string} JSON of the first ~10000 values, '' when undecodable. */
      toString() {
          const context = this._context();
          if (context.state) {
              return '';
          }
          context.limit = 10000;
          const value = this._decode(context, 0);
          return JSON.stringify(value, null, 4);
      }

      // Builds the decoding cursor, or sets `state` when decoding is impossible.
      _context() {
          const context = {};
          context.index = 0;
          context.count = 0;
          context.state = null;
          if (this._type.dataType == '?') {
              context.state = 'Tensor has unknown data type.';
              return context;
          }
          if (!this._type.shape) {
              context.state = 'Tensor has no dimensions.';
              return context;
          }
          if (!this._data) {
              context.state = 'Tensor data is empty.';
              return context;
          }
          switch (this._type.dataType) {
              case 'float16':
              case 'float32':
              case 'int8':
              case 'int32':
                  context.data = new DataView(this._data.buffer, this._data.byteOffset, this._data.byteLength);
                  break;
              default:
                  context.state = 'Tensor data type is not implemented.';
                  break;
          }
          context.dataType = this._type.dataType;
          context.shape = this._type.shape.dimensions;
          return context;
      }

      // Recursively decodes one dimension; stops with '...' past `context.limit`.
      // A rank-0 shape is decoded as a single scalar.
      _decode(context, dimension) {
          const shape = context.shape.length !== 0 ? context.shape : [ 1 ];
          const results = [];
          const size = shape[dimension];
          if (dimension === shape.length - 1) {
              for (let i = 0; i < size; i++) {
                  if (context.count > context.limit) {
                      results.push('...');
                      return results;
                  }
                  switch (this._type.dataType) {
                      case 'float32':
                          results.push(context.data.getFloat32(context.index, true));
                          context.index += 4;
                          context.count++;
                          break;
                      case 'float16':
                          // NOTE(review): relies on DataView.prototype.getFloat16 being
                          // available (native or polyfilled elsewhere) — confirm.
                          results.push(context.data.getFloat16(context.index, true));
                          context.index += 2;
                          context.count++;
                          break;
                      case 'int8':
                          results.push(context.data.getInt8(context.index));
                          context.index += 1;
                          context.count++;
                          break;
                      case 'int32':
                          results.push(context.data.getInt32(context.index, true));
                          context.index += 4;
                          context.count++;
                          break;
                      default:
                          break;
                  }
              }
          }
          else {
              for (let j = 0; j < size; j++) {
                  if (context.count > context.limit) {
                      results.push('...');
                      return results;
                  }
                  results.push(this._decode(context, dimension + 1));
              }
          }
          if (context.shape.length === 0) {
              return results[0];
          }
          return results;
      }
  };
  /**
   * Data type plus shape of a tensor. A missing data type is recorded as '?'.
   */
  tnn.TensorType = class {

      /**
       * @param {?string} dataType - e.g. 'float32'. Falsy values become '?'.
       * @param {tnn.TensorShape} shape
       */
      constructor(dataType, shape) {
          this._dataType = dataType ? dataType : '?';
          this._shape = shape;
      }

      get dataType() {
          return this._dataType;
      }

      get shape() {
          return this._shape;
      }

      /** @returns {string} e.g. 'float32[1,3,224,224]'. */
      toString() {
          return `${this._dataType}${this._shape.toString()}`;
      }
  };
  /**
   * Tensor dimensions. Falsy dimensions (including 0) are printed as '?'.
   */
  tnn.TensorShape = class {

      /** @param {?number[]} dimensions */
      constructor(dimensions) {
          this._dimensions = dimensions;
      }

      get dimensions() {
          return this._dimensions;
      }

      /** @returns {string} e.g. '[1,3,?,?]', or '' when there are no dimensions. */
      toString() {
          if (!this._dimensions) {
              return '';
          }
          const parts = this._dimensions.map((dimension) => (dimension ? dimension.toString() : '?'));
          return '[' + parts.join(',') + ']';
      }
  };
  /**
   * Operator schemas loaded from `tnn-metadata.json`, cached on the class
   * after the first load.
   */
  tnn.Metadata = class {

      /**
       * Returns the cached metadata, loading it on first use. If the file
       * cannot be fetched or parsed, an empty Metadata is cached and returned
       * so that model loading can continue without schemas.
       * @param {object} context - host context exposing `request`.
       * @returns {Promise<tnn.Metadata>}
       */
      static open(context) {
          if (tnn.Metadata._metadata) {
              return Promise.resolve(tnn.Metadata._metadata);
          }
          return context.request('tnn-metadata.json', 'utf-8', null).then((data) => {
              tnn.Metadata._metadata = new tnn.Metadata(data);
              return tnn.Metadata._metadata;
          }).catch(() => {
              tnn.Metadata._metadata = new tnn.Metadata(null);
              return tnn.Metadata._metadata;
          });
      }

      /**
       * @param {?string} data - JSON text: an array of schema items with a
       *     `name` and optionally an `operator` code.
       */
      constructor(data) {
          this._operatorMap = new Map();
          this._map = new Map();
          this._attributeCache = new Map();
          if (data) {
              const items = JSON.parse(data);
              for (const item of items) {
                  this._map.set(item.name, item);
                  // Only items that declare an operator code join the reverse lookup.
                  if (item.operator !== undefined) {
                      this._operatorMap.set(item.operator, item);
                  }
              }
          }
      }

      /** @returns {object|undefined} the schema item registered for operator `code`. */
      operator(code) {
          return this._operatorMap.get(code);
      }

      /** @returns {object|undefined} the schema item named `operator`. */
      type(operator) {
          return this._map.get(operator);
      }

      /**
       * @returns {?object} the attribute schema `name` of type `operator`, or
       *     null when unknown. Results (including misses) are memoised.
       */
      attribute(operator, name) {
          const key = operator + ':' + name;
          if (!this._attributeCache.has(key)) {
              const schema = this.type(operator);
              if (schema && schema.attributes && schema.attributes.length > 0) {
                  for (const attribute of schema.attributes) {
                      this._attributeCache.set(operator + ':' + attribute.name, attribute);
                  }
              }
              if (!this._attributeCache.has(key)) {
                  this._attributeCache.set(key, null);
              }
          }
          return this._attributeCache.get(key);
      }
  };
  /**
   * Parses a `.tnnproto` file. Each physical line is a quoted fragment; the
   * fragments are concatenated and split on ',' into logical lines:
   *   1. header: at least 3 space-separated tokens, optionally a 4th signature
   *      ('4206624770' or '4206624772');
   *   2. inputs, separated by ':';
   *   3. (skipped);
   *   4. output names, separated by spaces;
   *   5. (skipped);
   *   6+. one layer per line: `type name inputCount outputCount inputs... outputs... attributes...`.
   */
  tnn.TextProtoReader = class {

      /**
       * @param {Uint8Array} buffer - `.tnnproto` file content.
       * @throws {tnn.Error} on too few lines, a short header, or an unknown signature.
       */
      constructor(buffer) {
          const reader = base.TextReader.open(buffer);
          const split = (text, delimiter, trim, ignore_blank) => {
              return text.split(delimiter).map((v) => trim ? v.trim() : v).filter((v) => !ignore_blank || v);
          };
          const fragments = [];
          for (;;) {
              const line = reader.read();
              if (line === undefined) {
                  break;
              }
              fragments.push(line.replace(/\r|"/g, ''));
          }
          const lines = split(fragments.join(''), ',', true, false);
          if (lines.length <= 5) {
              throw new tnn.Error('Invalid line count.');
          }
          const header = split(lines.shift(), ' ', true, false);
          if (header.length < 3) {
              throw new tnn.Error('Invalid header size.');
          }
          else if (header.length > 3 && (header[3] !== '4206624770' && header[3] !== '4206624772')) {
              throw new tnn.Error("Invalid signature '" + header[3] + "'.");
          }
          // With signature 4206624772 each input is `name rank dim... dataTypeIndex`;
          // otherwise it is `name dim...` and the data type is float32.
          const typedInputs = header[3] === '4206624772';
          // Blank segments (e.g. from a trailing ':') and blank tokens are ignored.
          this._inputs = split(lines.shift(), ':', true, true).map((input) => {
              const array = split(input, ' ', true, true);
              const name = array.shift();
              if (typedInputs) {
                  const shape_size = parseInt(array.shift(), 10);
                  const data_type_index = parseInt(array[shape_size], 10);
                  return {
                      name: name,
                      data_type: [ 'float32', 'float16', 'int8', 'int32', 'bfloat16' ][data_type_index],
                      shape: array.slice(0, shape_size).map((dim) => parseInt(dim, 10))
                  };
              }
              return {
                  name: name,
                  data_type: 'float32',
                  shape: array.map((dim) => parseInt(dim, 10))
              };
          });
          lines.shift();
          this._outputs = split(lines.shift(), ' ', true, true).map((output) => {
              return { name: output };
          });
          lines.shift();
          const parseLayer = (line) => {
              const array = split(line, ' ', true, true);
              const layer = {};
              layer.type = array.shift();
              layer.name = array.shift();
              const inputCount = parseInt(array.shift(), 10);
              const outputCount = parseInt(array.shift(), 10);
              layer.inputs = array.splice(0, inputCount);
              layer.outputs = array.splice(0, outputCount);
              // The remaining tokens are positional attributes keyed by index.
              layer.attr = {};
              layer.attributes = [];
              array.forEach((value, key) => {
                  layer.attr[key] = value;
                  layer.attributes.push({ key: key, value: value });
              });
              return layer;
          };
          this._layers = lines.filter((line) => line.trim().length > 0).map(parseLayer);
      }

      /** @returns {Array<{name: string, data_type: string, shape: number[]}>} */
      get inputs() {
          return this._inputs;
      }

      /** @returns {Array<{name: string}>} */
      get outputs() {
          return this._outputs;
      }

      /** @returns {object[]} parsed layers, in file order. */
      get layers() {
          return this._layers;
      }
  };
  /**
   * Parses a `.tnnmodel` weight blob into per-layer resource records and
   * hands them out in file order.
   */
  tnn.LayerResourceReader = class {

      /**
       * @param {?Uint8Array} buffer - `.tnnmodel` content. Null/undefined yields no resources.
       * @throws {tnn.Error} on bad magic numbers, unknown data types or layer
       *     types, name mismatches, or trailing bytes.
       */
      constructor(buffer) {
          this._layerResources = [];
          if (buffer) {
              const reader = new tnn.BinaryReader(buffer);
              const magic_number = reader.uint32();
              if (magic_number !== 0xFABC0002 && magic_number !== 0xFABC0004) {
                  throw new tnn.Error("Invalid blob header signature '" + magic_number.toString() + "'.");
              }
              // Only the low 29 bits hold the layer count.
              const layerCount = reader.int32() & 0x1FFFFFFF;
              for (let i = 0; i < layerCount; i++) {
                  this._layerResources.push(tnn.LayerResourceReader._resource(reader));
              }
              if (!reader.end()) {
                  throw new tnn.Error("Invalid blob size.");
              }
          }
      }

      /**
       * Returns the next resource record, which must belong to layer `name`.
       * @returns {object|undefined} undefined once all records are consumed.
       * @throws {tnn.Error} when the next record belongs to a different layer.
       */
      read(name) {
          const resource = this._layerResources.shift();
          if (resource && resource.name !== name) {
              throw new tnn.Error("Invalid blob layer name '" + name + "'.");
          }
          return resource;
      }

      // Reads one raw buffer record: magic, data type, byte length, dims
      // (magic 0xFABC0004 only), then the payload. Returns null when the byte
      // length is not positive; no further bytes are read in that case.
      // NOTE(review): dims are skipped for empty 0xFABC0004 records — confirm
      // against the TNN serializer that none are written there.
      static _raw(reader) {
          const magic_number = reader.uint32();
          if (magic_number !== 0xFABC0002 && magic_number !== 0xFABC0004) {
              throw new tnn.Error("Invalid raw signature '" + magic_number.toString() + "'.");
          }
          const data_type = reader.int32();
          if (data_type < 0 || data_type > 4) {
              throw new tnn.Error("Unknown data type '" + data_type + "'.");
          }
          const length = reader.int32();
          if (length <= 0) {
              return null;
          }
          let dims = null;
          if (magic_number === 0xFABC0004) {
              const dim_size = reader.int32();
              dims = reader.bytes(dim_size * 4);
          }
          return {
              dataType: [ 'float32', 'float16', 'int8', 'int32', 'bfloat16' ][data_type],
              // Element count: byte length divided by the element size.
              length: length / [ 4, 2, 1, 4, 2 ][data_type],
              value: reader.bytes(length),
              shape: dims
          };
      }

      // Reads one layer record: operator code, type, name, then a
      // type-specific sequence of raw buffers.
      static _resource(reader) {
          const raw = () => tnn.LayerResourceReader._raw(reader);
          const resource = {};
          resource.operator = reader.int32();
          resource.type = reader.string();
          resource.name = reader.string();
          switch (resource.type) {
              case 'Convolution':
              case 'ConvolutionDepthWise':
              case 'Deconvolution':
              case 'DeconvolutionDepthWise': {
                  reader.expect(resource.name);
                  const bias = reader.int32();
                  resource.filter = raw();
                  if (bias) {
                      resource.bias = raw();
                  }
                  // int8 filters are followed by their quantization buffer.
                  if (resource.filter && resource.filter.dataType === 'int8') {
                      resource.quantized = raw();
                  }
                  break;
              }
              case 'Conv3D': {
                  reader.expect(resource.name);
                  const bias = reader.int32();
                  resource.filter = raw();
                  if (bias) {
                      resource.bias = raw();
                  }
                  break;
              }
              case 'InnerProduct': {
                  reader.expect(resource.name);
                  resource.weight = raw();
                  resource.bias = raw();
                  if (resource.weight && resource.weight.dataType === 'int8') {
                      resource.scale = raw();
                  }
                  break;
              }
              case 'PReLU': {
                  reader.expect(resource.name);
                  resource.slope = raw();
                  break;
              }
              case 'Add':
              case 'Div':
              case 'Mul':
              case 'Sub':
              case 'MatMul': {
                  resource.slope = raw();
                  break;
              }
              case 'BatchNormCxx':
              case 'InstBatchNormCxx':
                  resource.scale = raw();
                  resource.bias = raw();
                  break;
              case 'HdrGuide':
                  resource.ccm_weight = raw();
                  resource.ccm_bias = raw();
                  resource.shifts = raw();
                  resource.slopes = raw();
                  resource.projection_weight = raw();
                  resource.projection_bias = raw();
                  break;
              case 'BlobScale':
                  resource.scale = raw();
                  resource.bias = raw();
                  break;
              case 'Gather': {
                  // No repeated layer name is read for Gather (unlike the layers above).
                  const has_data = reader.int32();
                  if (has_data) {
                      resource.data = raw();
                  }
                  const has_indices = reader.int32();
                  if (has_indices) {
                      resource.indices = raw();
                  }
                  break;
              }
              default:
                  throw new tnn.Error("Unknown layer resource type '" + resource.type + "'.");
          }
          return resource;
      }
  };
  /**
   * Little-endian cursor over a byte array, used to parse `.tnnmodel` blobs.
   */
  tnn.BinaryReader = class {

      /** @param {Uint8Array} buffer - bytes to read. */
      constructor(buffer) {
          this._buffer = buffer;
          this._dataView = new DataView(buffer.buffer, buffer.byteOffset, buffer.byteLength);
          this._position = 0;
      }

      /** @returns {boolean} true when every byte has been consumed. */
      end() {
          return this._position === this._buffer.length;
      }

      /**
       * Advances the cursor by `size` bytes.
       * @throws {tnn.Error} when this moves past the end of the buffer.
       */
      skip(size) {
          this._position += size;
          const overflow = this._position - this._buffer.length;
          if (overflow > 0) {
              throw new tnn.Error('Expected ' + overflow + ' more bytes. The file might be corrupted. Unexpected end of file.');
          }
      }

      /** @returns {Uint8Array} a view (not a copy) of the next `size` bytes. */
      bytes(size) {
          const start = this._position;
          this.skip(size);
          return this._buffer.subarray(start, this._position);
      }

      uint32() {
          const start = this._position;
          this.skip(4);
          return this._dataView.getUint32(start, true);
      }

      int32() {
          const start = this._position;
          this.skip(4);
          return this._dataView.getInt32(start, true);
      }

      /** Reads an int32 byte length followed by that many UTF-8 bytes. */
      string() {
          const data = this.bytes(this.int32());
          return new TextDecoder('utf-8').decode(data);
      }

      /**
       * Reads a string and checks that it equals `name`.
       * @throws {tnn.Error} on mismatch.
       */
      expect(name) {
          const text = this.string();
          if (name === text) {
              return;
          }
          throw new tnn.Error("Invalid string '" + text + "' instead of '" + name + "'.");
      }
  };
  /**
   * Error raised for malformed or unsupported TNN files.
   */
  tnn.Error = class extends Error {

      /** @param {string} message - description of the problem. */
      constructor(message) {
          super(message);
          this.name = 'Error loading TNN model.';
      }
  };
  // Export the factory when a CommonJS `module` object is present; otherwise
  // nothing is exported.
  if (typeof module !== 'undefined' && typeof module.exports === 'object') {
      Object.assign(module.exports, { ModelFactory: tnn.ModelFactory });
  }