tflite.js 27 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808
  1. /* jshint esversion: 6 */
  2. var tflite = tflite || {};
  3. var flatbuffers = flatbuffers || require('./flatbuffers');
  /**
   * Detects and opens TensorFlow Lite models, stored either as FlatBuffer
   * binaries or in their JSON text form.
   */
  tflite.ModelFactory = class {

      /**
       * Cheap format probe.
       * @param {object} context - provides `identifier` (file name), `buffer`
       *     (bytes) and `tags(kind)` (set-like view of top-level document keys).
       * @returns {boolean} true when a file with one of the binary extensions
       *     below has 'TFL3' in bytes 4..7, or when a '.json' file has both the
       *     'subgraphs' and 'operator_codes' keys.
       */
      match(context) {
          const extension = context.identifier.split('.').pop().toLowerCase();
          const binaryExtensions = new Set([ 'tflite', 'lite', 'tfl', 'bin', 'pb', 'model', 'tmfile', 'h5' ]);
          if (binaryExtensions.has(extension)) {
              const buffer = context.buffer;
              const signature = 'TFL3';
              // Bytes 4..7 must spell the 'TFL3' identifier.
              const header = buffer && buffer.length > 8 ? buffer.subarray(4, 8) : null;
              if (header && Array.from(signature).every((character, i) => header[i] === character.charCodeAt(0))) {
                  return true;
              }
          }
          if (extension !== 'json') {
              return false;
          }
          const tags = context.tags('json');
          if (tags.has('subgraphs') && tags.has('operator_codes')) {
              return true;
          }
          return false;
      }

      /**
       * Loads the generated schema and the operator metadata, then decodes the
       * model ('.json' files as text, everything else as a binary FlatBuffer).
       * @param {object} context - see `match`.
       * @param {object} host - provides `require(id)` (returns a promise) and is
       *     handed to `tflite.Metadata.open`.
       * @returns {Promise<tflite.Model>} rejects with a `tflite.Error` whose
       *     message ends with the file identifier when decoding fails.
       */
      open(context, host) {
          return host.require('./tflite-schema').then(() => {
              tflite.schema = flatbuffers.get('tflite').tflite;
              return tflite.Metadata.open(host);
          }).then((metadata) => {
              const identifier = context.identifier;
              try {
                  const extension = identifier.split('.').pop().toLowerCase();
                  let model;
                  if (extension === 'json') {
                      model = tflite.schema.Model.createText(new flatbuffers.TextReader(context.buffer));
                  }
                  else {
                      const reader = new flatbuffers.Reader(context.buffer);
                      if (!tflite.schema.Model.identifier(reader)) {
                          throw new tflite.Error("File format is not tflite.Model.");
                      }
                      model = tflite.schema.Model.create(reader);
                  }
                  return new tflite.Model(metadata, model);
              }
              catch (error) {
                  // Re-wrap every failure so the message names the offending file.
                  const message = error && error.message ? error.message : error.toString();
                  throw new tflite.Error(message.replace(/\.$/, '') + " in '" + identifier + "'.");
              }
          });
      }
  };
  /**
   * Decoded TensorFlow Lite model.
   *
   * Resolves operator codes to display names, reads the optional
   * 'min_runtime_version' and 'TFLITE_METADATA' metadata entries, and builds
   * one `tflite.Graph` per subgraph.
   */
  tflite.Model = class {

      /**
       * @param {tflite.Metadata} metadata - operator schema lookup, forwarded to graphs.
       * @param {object} model - decoded `tflite.schema.Model` table.
       * @throws {tflite.Error} when an operator code has no known built-in name
       *     (or a custom operator has no `custom_code`).
       */
      constructor(metadata, model) {
          this._graphs = [];
          this._format = 'TensorFlow Lite v' + model.version.toString();
          this._description = model.description || '';
          // Reverse lookup: BuiltinOperator enum value -> display name.
          const builtinOperatorMap = {};
          for (const key of Object.keys(tflite.schema.BuiltinOperator)) {
              builtinOperatorMap[tflite.schema.BuiltinOperator[key]] = tflite.Utility.type(key);
          }
          const operatorList = [];
          for (let i = 0; i < model.operator_codes.length; i++) {
              const operatorCode = model.operator_codes[i];
              const code = operatorCode.builtin_code;
              const version = operatorCode.version;
              const custom = code === tflite.schema.BuiltinOperator.CUSTOM;
              const name = custom ? operatorCode.custom_code : builtinOperatorMap[code];
              if (!name) {
                  throw new tflite.Error("Invalid built-in code '" + code.toString() + "' at '" + i.toString() + "'.");
              }
              operatorList.push(custom ? { name: name, version: version, custom: true } : { name: name, version: version });
          }
          let modelMetadata = null;
          // `entry` (not `metadata`) so the constructor parameter stays visible below.
          for (const entry of model.metadata || []) {
              switch (entry.name) {
                  case 'min_runtime_version': {
                      const data = model.buffers[entry.buffer].data;
                      this._runtime = data ? new TextDecoder().decode(data) : undefined;
                      break;
                  }
                  case 'TFLITE_METADATA': {
                      const data = model.buffers[entry.buffer].data || new Uint8Array(0);
                      const reader = new flatbuffers.Reader(data);
                      if (tflite.schema.ModelMetadata.identifier(reader)) {
                          modelMetadata = tflite.schema.ModelMetadata.create(reader);
                          this._name = modelMetadata.name || '';
                          this._version = modelMetadata.version || '';
                          // Skip empty parts so no stray leading/trailing space is produced.
                          this._description = [ this._description, modelMetadata.description ].filter((text) => text).join(' ');
                          this._author = modelMetadata.author || '';
                          this._license = modelMetadata.license || '';
                      }
                      break;
                  }
              }
          }
          const subgraphs = model.subgraphs;
          const subgraphsMetadata = modelMetadata ? modelMetadata.subgraph_metadata : null;
          for (let i = 0; i < subgraphs.length; i++) {
              const subgraph = subgraphs[i];
              // Subgraphs are only numbered when there is more than one.
              const name = subgraphs.length > 1 ? i.toString() : '';
              const subgraphMetadata = subgraphsMetadata && i < subgraphsMetadata.length ? subgraphsMetadata[i] : null;
              this._graphs.push(new tflite.Graph(metadata, subgraph, subgraphMetadata, name, operatorList, model));
          }
      }

      get format() {
          return this._format;
      }

      get runtime() {
          return this._runtime;
      }

      get name() {
          return this._name;
      }

      get version() {
          return this._version;
      }

      get description() {
          return this._description;
      }

      get author() {
          return this._author;
      }

      get license() {
          return this._license;
      }

      get graphs() {
          return this._graphs;
      }
  };
  /**
   * One TensorFlow Lite subgraph: tensors wrapped as arguments, operators as
   * nodes, and the graph-level inputs/outputs, optionally annotated from the
   * subgraph's TFLite metadata.
   */
  tflite.Graph = class {

      /**
       * @param {tflite.Metadata} metadata - operator schema lookup, forwarded to nodes.
       * @param {object} subgraph - decoded `SubGraph` table.
       * @param {object|null} subgraphMetadata - matching `SubGraphMetadata`, or null.
       * @param {string} name - fallback name used when the subgraph has none.
       * @param {Array<object>} operatorList - operator descriptors indexed by `opcode_index`.
       * @param {object} model - decoded `Model` table; only `buffers` is read.
       */
      constructor(metadata, subgraph, subgraphMetadata, name, operatorList, model) {
          this._nodes = [];
          this._inputs = [];
          this._outputs = [];
          this._name = subgraph.name || name;
          const args = [];
          const tensorNames = [];
          for (let i = 0; i < subgraph.tensors.length; i++) {
              const tensor = subgraph.tensors[i];
              const buffer = model.buffers[tensor.buffer];
              const is_variable = tensor.is_variable;
              const data = buffer.data;
              // Tensors with non-empty buffer data, or marked variable, get an initializer.
              const initializer = (data && data.length > 0) || is_variable ? new tflite.Tensor(i, tensor, buffer, is_variable) : null;
              args.push(new tflite.Argument(i, tensor, initializer));
              tensorNames.push(tensor.name);
          }
          const operators = subgraph.operators;
          for (let i = 0; i < operators.length; i++) {
              const node = operators[i];
              const index = node.opcode_index;
              // An out-of-range opcode index still yields a node, named after the raw index.
              const operator = index < operatorList.length ? operatorList[index] : { name: '(' + index.toString() + ')' };
              this._nodes.push(new tflite.Node(metadata, node, operator, i.toString(), args));
          }
          // Per-tensor metadata vectors are optional; a missing vector behaves like an empty one.
          const inputMetadata = subgraphMetadata && subgraphMetadata.input_tensor_metadata ? subgraphMetadata.input_tensor_metadata : [];
          const outputMetadata = subgraphMetadata && subgraphMetadata.output_tensor_metadata ? subgraphMetadata.output_tensor_metadata : [];
          const inputs = subgraph.inputs;
          for (let i = 0; i < inputs.length; i++) {
              const input = inputs[i];
              const argument = args[input];
              if (i < inputMetadata.length) {
                  tflite.Graph._applyTensorMetadata(argument, inputMetadata[i]);
              }
              this._inputs.push(new tflite.Parameter(tensorNames[input], true, [ argument ]));
          }
          const outputs = subgraph.outputs;
          for (let i = 0; i < outputs.length; i++) {
              const output = outputs[i];
              const argument = args[output];
              if (i < outputMetadata.length) {
                  tflite.Graph._applyTensorMetadata(argument, outputMetadata[i]);
              }
              this._outputs.push(new tflite.Parameter(tensorNames[output], true, [ argument ]));
          }
      }

      /**
       * Copies the description of `tensorMetadata` onto `argument` and derives a
       * type denotation ('Feature', 'Image', 'Image(RGB)', 'Image(Grayscale)',
       * 'BoundingBox') from its content properties. A null `tensorMetadata` is ignored.
       */
      static _applyTensorMetadata(argument, tensorMetadata) {
          if (!tensorMetadata) {
              return;
          }
          const description = tensorMetadata.description;
          if (description) {
              argument.description = description;
          }
          const content = tensorMetadata.content;
          if (argument.type && content) {
              let denotation = null;
              const contentProperties = content.content_properties;
              if (contentProperties instanceof tflite.schema.FeatureProperties) {
                  denotation = 'Feature';
              }
              else if (contentProperties instanceof tflite.schema.ImageProperties) {
                  denotation = 'Image';
                  switch (contentProperties.color_space) {
                      case 1: denotation += '(RGB)'; break;
                      case 2: denotation += '(Grayscale)'; break;
                  }
              }
              else if (contentProperties instanceof tflite.schema.BoundingBoxProperties) {
                  denotation = 'BoundingBox';
              }
              if (denotation) {
                  argument.type.denotation = denotation;
              }
          }
      }

      get name() {
          return this._name;
      }

      get groups() {
          return false;
      }

      get inputs() {
          return this._inputs;
      }

      get outputs() {
          return this._outputs;
      }

      get nodes() {
          return this._nodes;
      }
  };
  /**
   * One operator invocation. A synthetic node without an operator table is used
   * to represent a fused activation function via `chain`.
   */
  tflite.Node = class {

      /**
       * @param {tflite.Metadata} metadata - operator schema lookup.
       * @param {object|null} node - decoded `Operator` table; when null the node
       *     gets no inputs, outputs or attributes.
       * @param {{name: string, version: number, custom: boolean}} type - operator
       *     descriptor (`custom` is set only for custom operators).
       * @param {string|null} location - node position within its graph.
       * @param {Array} args - subgraph arguments indexed by tensor index.
       * @throws {tflite.Error} for an unknown `fused_activation_function` value.
       */
      constructor(metadata, node, type, location, args) {
          this._metadata = metadata;
          this._location = location;
          this._type = type;
          this._inputs = [];
          this._outputs = [];
          this._attributes = [];
          if (node) {
              const inputs = Array.from(node.inputs || new Int32Array(0));
              const outputs = Array.from(node.outputs || new Int32Array(0));
              const schema = this._metadata.type(this.type);
              // Group input tensors per schema input; a 'variadic' schema input
              // absorbs every remaining tensor.
              let inputIndex = 0;
              while (inputIndex < inputs.length) {
                  let count = 1;
                  let inputName = null;
                  let inputVisible = true;
                  if (schema && schema.inputs && inputIndex < schema.inputs.length) {
                      const input = schema.inputs[inputIndex];
                      inputName = input.name;
                      if (input.option === 'variadic') {
                          count = inputs.length - inputIndex;
                      }
                      if (Object.prototype.hasOwnProperty.call(input, 'visible') && !input.visible) {
                          inputVisible = false;
                      }
                  }
                  // Tensor index -1 marks an omitted optional input and yields no argument.
                  const inputArguments = inputs.slice(inputIndex, inputIndex + count)
                      .filter((tensorIndex) => tensorIndex !== -1)
                      .map((tensorIndex) => args[tensorIndex]);
                  // Unnamed inputs are labelled by their zero-based position, like outputs below.
                  this._inputs.push(new tflite.Parameter(inputName ? inputName : inputIndex.toString(), inputVisible, inputArguments));
                  inputIndex += count;
              }
              for (let k = 0; k < outputs.length; k++) {
                  const argument = args[outputs[k]];
                  let outputName = k.toString();
                  if (schema && schema.outputs && k < schema.outputs.length) {
                      const output = schema.outputs[k];
                      // Variadic outputs keep their positional name.
                      if (output && output.option !== 'variadic' && output.name) {
                          outputName = output.name;
                      }
                  }
                  this._outputs.push(new tflite.Parameter(outputName, true, [ argument ]));
              }
              if (type.custom && node.custom_options && node.custom_options.length > 0) {
                  const schema = metadata.attribute(this.type, 'custom');
                  this._attributes.push(new tflite.Attribute(schema, 'custom', Array.from(node.custom_options)));
              }
              const options = node.builtin_options;
              if (options) {
                  const activationFunctionMap = { 1: 'Relu', 2: 'ReluN1To1', 3: 'Relu6', 4: 'Tanh', 5: 'SignBit' };
                  for (const name of Object.keys(options)) {
                      const value = options[name];
                      // A non-zero fused activation is shown as a chained node.
                      if (name === 'fused_activation_function' && value !== 0) {
                          const activation = activationFunctionMap[value];
                          if (!activation) {
                              throw new tflite.Error("Unknown activation function index '" + JSON.stringify(value) + "'.");
                          }
                          this._chain = [ new tflite.Node(metadata, null, { name: activation }, null, []) ];
                      }
                      const schema = metadata.attribute(this.type, name);
                      this._attributes.push(new tflite.Attribute(schema, name, value));
                  }
              }
          }
      }

      get type() {
          return this._type.name;
      }

      get name() {
          return '';
      }

      get location() {
          return this._location;
      }

      get domain() {
          return null;
      }

      get metadata() {
          if (this._type.custom) {
              return { name: this.type, category: 'custom' };
          }
          return this._metadata.type(this.type);
      }

      get group() {
          return null;
      }

      get inputs() {
          return this._inputs;
      }

      get outputs() {
          return this._outputs;
      }

      get chain() {
          return this._chain;
      }

      get attributes() {
          return this._attributes;
      }
  };
  /**
   * A node attribute (builtin option or custom option bytes). When the schema
   * names a `type`, the raw value is converted into a display value.
   */
  tflite.Attribute = class {

      /**
       * @param {object|null} schema - attribute schema from `tflite.Metadata`, or null.
       * @param {string} name - option name.
       * @param {*} value - raw option value.
       */
      constructor(schema, name, value) {
          this._type = null;
          this._name = name;
          this._value = value;
          // The fused activation is shown as a chained node, so the raw option is hidden.
          let hidden = this._name == 'fused_activation_function';
          if (schema) {
              if (schema.type) {
                  this._type = schema.type;
              }
              const type = this._type;
              if (type === 'shape') {
                  this._value = new tflite.TensorShape(value);
              }
              else if (type === 'TensorType') {
                  this._value = tflite.Utility.dataType(this._value);
              }
              else if (type) {
                  this._value = tflite.Utility.enum(type, this._value);
              }
              const own = (key) => Object.prototype.hasOwnProperty.call(schema, key);
              if (own('visible') && !schema.visible) {
                  hidden = true;
              }
              else if (own('default')) {
                  // A function-valued option is invoked and its result compared;
                  // loose equality on purpose (e.g. 0 matches false).
                  let current = this._value;
                  if (typeof current === 'function') {
                      current = current();
                  }
                  if (current == schema.default) {
                      hidden = true;
                  }
              }
          }
          if (hidden) {
              this._visible = false;
          }
      }

      get name() {
          return this._name;
      }

      get type() {
          return this._type;
      }

      get value() {
          return this._value;
      }

      /** False only when the attribute was explicitly hidden above. */
      get visible() {
          return this._visible != false;
      }
  };
  /**
   * Named, optionally hidden group of arguments attached to a node or graph.
   */
  tflite.Parameter = class {

      /**
       * @param {string} name - display name.
       * @param {boolean} visible - whether the group should be shown.
       * @param {Array} args - the arguments in this group.
       */
      constructor(name, visible, args) {
          Object.assign(this, { _name: name, _visible: visible, _arguments: args });
      }

      get name() {
          return this._name;
      }

      get visible() {
          return this._visible;
      }

      get arguments() {
          return this._arguments;
      }
  };
  /**
   * A tensor as seen from the graph: name, type, optional initializer and a
   * human-readable quantization formula.
   */
  tflite.Argument = class {

      /**
       * @param {number} index - tensor position within its subgraph.
       * @param {object} tensor - decoded `Tensor` table.
       * @param {tflite.Tensor|null} initializer - constant/variable payload, if any.
       */
      constructor(index, tensor, initializer) {
          this._location = index.toString();
          this._type = new tflite.TensorType(tensor);
          this._initializer = initializer;
          this._name = tensor.name;
          const quantization = tensor.quantization;
          if (quantization) {
              // Only single-valued (per-tensor) parameters are rendered.
              const single = (values) => values.length == 1;
              const scale = single(quantization.scale) ? quantization.scale[0] : 0;
              const zeroPoint = single(quantization.zero_point) ? quantization.zero_point[0] : 0;
              let formula = 'q';
              if (scale != 0 || zeroPoint != 0) {
                  const shifted = zeroPoint == 0 ? 'q' : '(q - ' + zeroPoint.toString() + ')';
                  formula = scale.toString() + ' * ' + shifted;
              }
              if (single(quantization.min)) {
                  formula = quantization.min[0].toString() + ' \u2264 ' + formula;
              }
              if (single(quantization.max)) {
                  formula = formula + ' \u2264 ' + quantization.max[0].toString();
              }
              // A bare 'q' carries no information and is not stored.
              if (formula != 'q') {
                  this._quantization = formula;
              }
          }
      }

      get name() {
          return this._name;
      }

      get location() {
          return this._location;
      }

      get type() {
          return this._type;
      }

      get quantization() {
          return this._quantization;
      }

      set description(value) {
          this._description = value;
      }

      get description() {
          return this._description;
      }

      get initializer() {
          return this._initializer;
      }
  };
  /**
   * Tensor payload (constant or variable) with on-demand, size-limited decoding.
   */
  tflite.Tensor = class {

      /**
       * @param {number} index - tensor position within its subgraph.
       * @param {object} tensor - decoded `Tensor` table (`name`, `type`, `shape`).
       * @param {object} buffer - decoded `Buffer` table; `data` may be missing.
       * @param {boolean} is_variable - whether the tensor is a variable.
       */
      constructor(index, tensor, buffer, is_variable) {
          this._location = index.toString();
          this._type = new tflite.TensorType(tensor);
          this._is_variable = is_variable;
          this._name = tensor.name;
          // A variable tensor may have no backing data; keep null so `_context()`
          // reports it as empty instead of throwing here.
          this._data = buffer.data ? buffer.data.slice(0) : null;
      }

      get kind() {
          return this._is_variable ? 'Variable' : '';
      }

      get name() {
          return this._name;
      }

      get location() {
          return this._location;
      }

      get type() {
          return this._type;
      }

      /** Reason the payload cannot be decoded, or null when it can. */
      get state() {
          return this._context().state;
      }

      /** Fully decoded value (nested arrays following the shape), or null when undecodable. */
      get value() {
          const context = this._context();
          if (context.state) {
              return null;
          }
          context.limit = Number.MAX_SAFE_INTEGER;
          return this._decode(context, 0);
      }

      /** JSON rendering that stops after about 10000 elements; '' when undecodable. */
      toString() {
          const context = this._context();
          if (context.state) {
              return '';
          }
          context.limit = 10000;
          const value = this._decode(context, 0);
          return JSON.stringify(value, null, 4);
      }

      /**
       * Builds the decoding cursor. String payloads are pre-split: a leading int32
       * count, then `count` int32 start offsets; each string ends where the next
       * one starts, the last one at the end of the data.
       */
      _context() {
          const context = {};
          context.state = null;
          context.index = 0;
          context.count = 0;
          if (this._data == null || this._data.length === 0) {
              context.state = 'Tensor data is empty.';
              return context;
          }
          context.dataType = this._type.dataType;
          context.shape = this._type.shape.dimensions;
          context.data = new DataView(this._data.buffer, this._data.byteOffset, this._data.byteLength);
          if (this._type.dataType == 'string') {
              let offset = 0;
              const count = context.data.getInt32(0, true);
              offset += 4;
              const offsetTable = [];
              for (let j = 0; j < count; j++) {
                  offsetTable.push(context.data.getInt32(offset, true));
                  offset += 4;
              }
              offsetTable.push(this._data.length);
              const stringTable = [];
              const utf8Decoder = new TextDecoder('utf-8');
              for (let k = 0; k < count; k++) {
                  const textArray = this._data.subarray(offsetTable[k], offsetTable[k + 1]);
                  stringTable.push(utf8Decoder.decode(textArray));
              }
              context.data = stringTable;
          }
          return context;
      }

      /**
       * Recursively decodes from `dimension` inward, advancing `context.index`
       * (a byte offset, or a string index for string tensors) and `context.count`.
       * Multi-byte values are read little-endian. Once `context.count` exceeds
       * `context.limit`, a '...' marker is appended and decoding of that level stops.
       * A rank-0 shape is decoded as a single scalar.
       * NOTE(review): getInt64/getFloat16 are not available on every DataView
       * implementation; presumably polyfilled elsewhere in the project — confirm.
       */
      _decode(context, dimension) {
          const shape = (context.shape.length == 0) ? [ 1 ] : context.shape;
          const size = shape[dimension];
          const results = [];
          if (dimension == shape.length - 1) {
              for (let i = 0; i < size; i++) {
                  if (context.count > context.limit) {
                      results.push('...');
                      return results;
                  }
                  switch (context.dataType) {
                      case 'boolean':
                          results.push(context.data.getUint8(context.index) !== 0);
                          context.index += 1;
                          context.count++;
                          break;
                      case 'uint8':
                          results.push(context.data.getUint8(context.index));
                          context.index += 1;
                          context.count++;
                          break;
                      case 'int8':
                          results.push(context.data.getInt8(context.index));
                          context.index += 1;
                          context.count++;
                          break;
                      case 'int16':
                          // Little-endian like every other multi-byte type.
                          results.push(context.data.getInt16(context.index, true));
                          context.index += 2;
                          context.count++;
                          break;
                      case 'int32':
                          results.push(context.data.getInt32(context.index, true));
                          context.index += 4;
                          context.count++;
                          break;
                      case 'int64':
                          results.push(context.data.getInt64(context.index, true));
                          context.index += 8;
                          context.count++;
                          break;
                      case 'float16':
                          results.push(context.data.getFloat16(context.index, true));
                          context.index += 2;
                          context.count++;
                          break;
                      case 'float32':
                          results.push(context.data.getFloat32(context.index, true));
                          context.index += 4;
                          context.count++;
                          break;
                      case 'float64':
                          results.push(context.data.getFloat64(context.index, true));
                          context.index += 8;
                          context.count++;
                          break;
                      case 'string':
                          results.push(context.data[context.index++]);
                          context.count++;
                          break;
                      default:
                          break;
                  }
              }
          }
          else {
              for (let j = 0; j < size; j++) {
                  if (context.count > context.limit) {
                      results.push('...');
                      return results;
                  }
                  results.push(this._decode(context, dimension + 1));
              }
          }
          if (context.shape.length == 0) {
              return results[0];
          }
          return results;
      }
  };
  /**
   * Element type plus shape of a tensor, with an optional semantic denotation
   * (e.g. 'Image') attached later from model metadata.
   */
  tflite.TensorType = class {

      /** @param {object} tensor - decoded `Tensor` table (`type`, optional `shape`). */
      constructor(tensor) {
          this._dataType = tflite.Utility.dataType(tensor.type);
          const dimensions = tensor.shape ? Array.from(tensor.shape) : [];
          this._shape = new tflite.TensorShape(dimensions);
      }

      get dataType() {
          return this._dataType;
      }

      get shape() {
          return this._shape;
      }

      set denotation(value) {
          this._denotation = value;
      }

      get denotation() {
          return this._denotation;
      }

      /** Data type immediately followed by the formatted shape, e.g. 'float32[1,3]'. */
      toString() {
          return `${this.dataType}${this._shape.toString()}`;
      }
  };
  /**
   * Ordered list of tensor dimension sizes.
   */
  tflite.TensorShape = class {

      /** @param {Array} dimensions - dimension sizes; may be empty. */
      constructor(dimensions) {
          this._dimensions = dimensions;
      }

      get dimensions() {
          return this._dimensions;
      }

      /** Formats as '[d0,d1,...]'; a missing or empty dimension list yields ''. */
      toString() {
          const dimensions = this._dimensions;
          if (dimensions && dimensions.length != 0) {
              const parts = dimensions.map((dimension) => dimension.toString());
              return '[' + parts.join(',') + ']';
          }
          return '';
      }
  };
  /**
   * Operator schema lookup backed by 'tflite-metadata.json'.
   */
  tflite.Metadata = class {

      /**
       * Returns the shared instance, loading it through `host.request` on first
       * use. A failed request (or unparsable data) yields an empty instance, which
       * is cached as well.
       * @param {object} host - provides `request(base, file, encoding)` returning a promise.
       * @returns {Promise<tflite.Metadata>}
       */
      static open(host) {
          const cached = tflite.Metadata._metadata;
          if (cached) {
              return Promise.resolve(cached);
          }
          const remember = (data) => {
              tflite.Metadata._metadata = new tflite.Metadata(data);
              return tflite.Metadata._metadata;
          };
          return host.request(null, 'tflite-metadata.json', 'utf-8').then(remember).catch(() => remember(null));
      }

      /**
       * @param {string|null} data - JSON array of `{ name, schema }` entries; each
       *     schema is indexed by name and receives a copy of that name.
       */
      constructor(data) {
          this._map = new Map();
          const items = data ? JSON.parse(data) : null;
          for (const item of items || []) {
              item.schema.name = item.name;
              this._map.set(item.name, item.schema);
          }
      }

      /** Schema for operator `name`, or null when unknown. */
      type(name) {
          if (!this._map.has(name)) {
              return null;
          }
          return this._map.get(name);
      }

      /**
       * Schema of attribute `name` on operator `type`, or null. The per-operator
       * name -> attribute index is built once and cached on the schema object.
       */
      attribute(type, name) {
          const schema = this.type(type);
          if (!schema) {
              return null;
          }
          if (!schema.attributeMap) {
              const attributeMap = {};
              for (const attribute of schema.attributes || []) {
                  attributeMap[attribute.name] = attribute;
              }
              schema.attributeMap = attributeMap;
          }
          return schema.attributeMap[name] || null;
      }
  };
  /**
   * Helpers translating `tflite.schema` enum values into display strings.
   */
  tflite.Utility = class {

      /**
       * Maps a `tflite.schema.TensorType` value to its lower-cased key name.
       * Value 6 is always reported as 'boolean'; unknown values yield '?'.
       * The lookup table is built on first call and cached.
       */
      static dataType(type) {
          if (!tflite.Utility._tensorTypeMap) {
              tflite.Utility._tensorTypeMap = new Map();
              for (const name of Object.keys(tflite.schema.TensorType)) {
                  tflite.Utility._tensorTypeMap.set(tflite.schema.TensorType[name], name.toLowerCase());
              }
              tflite.Utility._tensorTypeMap.set(6, 'boolean');
          }
          return tflite.Utility._tensorTypeMap.has(type) ? tflite.Utility._tensorTypeMap.get(type) : '?';
      }

      /**
       * Resolves `value` to its key in the `tflite.schema[name]` enum. Returns
       * `value` unchanged when that enum or the value is unknown. Reverse maps
       * are cached per enum name.
       */
      static enum(name, value) {
          const type = name && tflite.schema ? tflite.schema[name] : undefined;
          if (type) {
              tflite.Utility._enumKeyMap = tflite.Utility._enumKeyMap || new Map();
              if (!tflite.Utility._enumKeyMap.has(name)) {
                  const map = new Map();
                  for (const key of Object.keys(type)) {
                      map.set(type[key], key);
                  }
                  tflite.Utility._enumKeyMap.set(name, map);
              }
              const map = tflite.Utility._enumKeyMap.get(name);
              if (map.has(value)) {
                  return map.get(value);
              }
          }
          return value;
      }

      /**
       * Converts a BuiltinOperator key such as 'CONV_2D' into a display name such
       * as 'Conv2D'. Each '_'-separated token is capitalised unless it is one of
       * the tokens kept fully upper-case below. 'BATCH_MATMUL' is first rewritten
       * to 'BATCH_MAT_MUL' so it renders as 'BatchMatMul'.
       */
      static type(name) {
          const upperCase = new Set([ '2D', 'LSH', 'SVDF', 'RNN', 'L2', 'LSTM' ]);
          const key = name === 'BATCH_MATMUL' ? 'BATCH_MAT_MUL' : name;
          return key.split('_').map((s) => (s.length < 1 || upperCase.has(s)) ? s : s[0] + s.substring(1).toLowerCase()).join('');
      }
  };
  /**
   * Error raised when a TensorFlow Lite model cannot be read; `name` carries a
   * fixed user-facing title.
   */
  tflite.Error = class extends Error {

      /** @param {string} text - error message. */
      constructor(text) {
          super(text);
          this.name = 'Error loading TensorFlow Lite model.';
      }
  };
  // Publish the factory through CommonJS when a `module` object with an
  // `exports` object is present; otherwise nothing is exported.
  if (typeof module !== 'undefined' && typeof module.exports === 'object') {
      const target = module.exports;
      target.ModelFactory = tflite.ModelFactory;
  }