// tflite.js (25 KB)

  // Namespace object; every class in this file is attached to it.
  var tflite = {};

  // Sibling modules used below: FlatBuffers readers, the FlexBuffers reader
  // (custom operator options) and the ZIP archive reader.
  var flatbuffers = require('./flatbuffers');
  var flexbuffers = require('./flexbuffers');
  var zip = require('./zip');
  /**
   * Detects TensorFlow Lite model files and opens them as tflite.Model objects.
   */
  tflite.ModelFactory = class {
    /**
     * Decides whether the given context holds a TensorFlow Lite model.
     *
     * @param {object} context - Host context; tags(), identifier, stream and open() are used.
     * @returns {string|undefined} 'tflite.flatbuffers' for the binary form,
     *   'tflite.flatbuffers.json' for the JSON text form, otherwise undefined.
     */
    match(context) {
      const tags = context.tags('flatbuffers');
      if (tags.get('file_identifier') === 'TFL3') {
        return 'tflite.flatbuffers';
      }
      const identifier = context.identifier;
      const extension = identifier.split('.').pop().toLowerCase();
      const stream = context.stream;
      if (extension === 'tflite' && stream.length >= 8) {
        // Fallback for '.tflite' files without the 'TFL3' identifier: accept them
        // when the root offset is 0x18 and the uint32 read at field offset 4
        // (treated as the model version) equals 3.
        const buffer = stream.peek(Math.min(32, stream.length));
        const reader = flatbuffers.BinaryReader.open(buffer);
        if (reader.root === 0x00000018) {
          const version = reader.uint32_(reader.root, 4, 0);
          if (version === 3) {
            return 'tflite.flatbuffers';
          }
        }
      }
      const obj = context.open('json');
      if (obj && obj.subgraphs && obj.operator_codes) {
        return 'tflite.flatbuffers.json';
      }
      return undefined;
    }

    /**
     * Loads the schema, decodes the model and wraps it in a tflite.Model.
     *
     * @param {object} context - Host context; require(), open(), stream and metadata() are used.
     * @param {string} match - Value previously returned by match().
     * @returns {Promise<tflite.Model>} Resolves with the decoded model.
     * @throws {tflite.Error} (as a rejection) when decoding fails or the match value is unknown.
     */
    open(context, match) {
      // Turns whatever was thrown into message text without a trailing period.
      // String() is used so that a thrown null/undefined or a non-string message
      // cannot itself raise a TypeError while the error is being reported.
      const describe = (error) => {
        const text = error && error.message ? error.message : error;
        return String(text).replace(/\.$/, '');
      };
      return context.require('./tflite-schema').then(() => {
        tflite.schema = flatbuffers.get('tflite').tflite;
        let model = null;
        // NOTE(review): attachments are collected but not consumed anywhere in this block.
        const attachments = new Map();
        switch (match) {
          case 'tflite.flatbuffers.json': {
            try {
              const obj = context.open('json');
              const reader = new flatbuffers.TextReader(obj);
              model = tflite.schema.Model.createText(reader);
            }
            catch (error) {
              throw new tflite.Error('File text format is not tflite.Model (' + describe(error) + ').');
            }
            break;
          }
          case 'tflite.flatbuffers': {
            const stream = context.stream;
            try {
              const reader = flatbuffers.BinaryReader.open(stream);
              model = tflite.schema.Model.create(reader);
            }
            catch (error) {
              throw new tflite.Error('File format is not tflite.Model (' + describe(error) + ').');
            }
            try {
              // Best effort: the same stream may also be readable as a ZIP archive.
              const archive = zip.Archive.open(stream);
              if (archive) {
                for (const entry of archive.entries) {
                  attachments.set(entry[0], entry[1]);
                }
              }
            }
            catch (error) {
              // continue regardless of error
            }
            break;
          }
          default: {
            throw new tflite.Error("Unsupported TensorFlow Lite format '" + match + "'.");
          }
        }
        return context.metadata('tflite-metadata.json').then((metadata) => {
          return new tflite.Model(metadata, model);
        });
      });
    }
  };
  /**
   * Top-level view of a decoded TensorFlow Lite model: format label,
   * descriptive fields, extra metadata entries and one tflite.Graph per subgraph.
   */
  tflite.Model = class {
    /**
     * @param {object} metadata - Operator metadata lookup, forwarded to every tflite.Graph.
     * @param {object} model - Decoded model; version, description, operator_codes,
     *   metadata, buffers and subgraphs are read.
     */
    constructor(metadata, model) {
      this._graphs = [];
      this._format = 'TensorFlow Lite' + ' v' + model.version.toString();
      this._description = model.description || '';
      this._metadata = [];

      // Display names for builtin operators: 'CONV_2D' -> 'Conv2D'. Tokens listed
      // in keepUpperCase are left untouched.
      const keepUpperCase = new Set([ '2D', 'LSH', 'SVDF', 'RNN', 'L2', 'LSTM' ]);
      const toDisplayToken = (token) => {
        if (token.length < 1 || keepUpperCase.has(token)) {
          return token;
        }
        return token[0] + token.substring(1).toLowerCase();
      };
      const builtinNames = new Map();
      for (const [key, code] of Object.entries(tflite.schema.BuiltinOperator)) {
        const spelled = key === 'BATCH_MATMUL' ? 'BATCH_MAT_MUL' : key;
        builtinNames.set(code, spelled.split('_').map(toDisplayToken).join(''));
      }

      // One descriptor per operator code, later looked up by opcode_index.
      const operators = model.operator_codes.map((entry) => {
        const code = Math.max(entry.deprecated_builtin_code, entry.builtin_code || 0);
        const version = entry.version;
        if (code === tflite.schema.BuiltinOperator.CUSTOM) {
          const customName = entry.custom_code ? entry.custom_code : 'Custom';
          return { name: customName, version: version, custom: true };
        }
        const builtinName = builtinNames.has(code) ? builtinNames.get(code) : code.toString();
        return { name: builtinName, version: version };
      });

      // Metadata entries point at model buffers; two well-known names are decoded.
      let modelMetadata = null;
      for (const item of model.metadata) {
        const buffer = model.buffers[item.buffer];
        if (!buffer) {
          continue;
        }
        const kind = item.name;
        if (kind === 'min_runtime_version') {
          const bytes = buffer.data || new Uint8Array(0);
          this._runtime = new TextDecoder().decode(bytes);
        }
        else if (kind === 'TFLITE_METADATA') {
          const bytes = buffer.data || new Uint8Array(0);
          const reader = flatbuffers.BinaryReader.open(bytes);
          if (tflite.schema.ModelMetadata.identifier(reader)) {
            modelMetadata = tflite.schema.ModelMetadata.create(reader);
            if (modelMetadata.name) {
              this._name = modelMetadata.name;
            }
            if (modelMetadata.version) {
              this._version = modelMetadata.version;
            }
            if (modelMetadata.description) {
              this._description = this._description ?
                [ this._description, modelMetadata.description].join(' ') :
                modelMetadata.description;
            }
            for (const field of [ 'author', 'license' ]) {
              if (modelMetadata[field]) {
                this._metadata.push({ name: field, value: modelMetadata[field] });
              }
            }
          }
        }
      }

      const subgraphs = model.subgraphs;
      const subgraphsMetadata = modelMetadata ? modelMetadata.subgraph_metadata : null;
      const multiple = subgraphs.length > 1;
      for (let position = 0; position < subgraphs.length; position++) {
        // Subgraphs are labelled by index only when there is more than one.
        const label = multiple ? position.toString() : '';
        const details = subgraphsMetadata && position < subgraphsMetadata.length ? subgraphsMetadata[position] : null;
        this._graphs.push(new tflite.Graph(metadata, subgraphs[position], details, label, operators, model));
      }
    }

    get format() {
      return this._format;
    }

    get runtime() {
      return this._runtime;
    }

    get name() {
      return this._name;
    }

    get version() {
      return this._version;
    }

    get description() {
      return this._description;
    }

    get metadata() {
      return this._metadata;
    }

    get graphs() {
      return this._graphs;
    }
  };
  /**
   * One subgraph of a TensorFlow Lite model, exposed as inputs, outputs and nodes.
   */
  tflite.Graph = class {
    /**
     * @param {object} metadata - Operator metadata lookup, handed on to each tflite.Node.
     * @param {object} subgraph - Subgraph record; name, tensors, operators, inputs and outputs are read.
     * @param {object|null} subgraphMetadata - Optional record carrying input_tensor_metadata
     *   and output_tensor_metadata lists.
     * @param {string} name - Fallback name used when the subgraph has no name of its own.
     * @param {Array<object>} operators - Operator descriptors, indexed by opcode_index.
     * @param {object} model - Owning model; only model.buffers is read.
     * @throws {tflite.Error} When image metadata carries a color space other than 0, 1 or 2.
     */
    constructor(metadata, subgraph, subgraphMetadata, name, operators, model) {
      this._nodes = [];
      this._inputs = [];
      this._outputs = [];
      this._name = subgraph.name || name;

      // Resolves a tensor index to a cached tflite.Argument. Index -1 yields null;
      // an index past the tensor list yields an unnamed placeholder argument.
      const tensors = new Map();
      const args = (index) => {
        if (index === -1) {
          return null;
        }
        if (!tensors.has(index)) {
          if (index < subgraph.tensors.length) {
            const tensor = subgraph.tensors[index];
            const buffer = model.buffers[tensor.buffer];
            const is_variable = tensor.is_variable;
            const data = buffer ? buffer.data : null;
            // Only tensors with stored bytes, or variable tensors, get an initializer.
            const initializer = (data && data.length > 0) || is_variable ? new tflite.Tensor(index, tensor, buffer, is_variable) : null;
            tensors.set(index, new tflite.Argument(index, tensor, initializer));
          }
          else {
            tensors.set(index, new tflite.Argument(index, { name: '' }, null));
          }
        }
        return tensors.get(index);
      };

      for (let i = 0; i < subgraph.operators.length; i++) {
        const node = subgraph.operators[i];
        const index = node.opcode_index;
        // Unknown opcode indices are shown as '(index)'.
        const operator = index < operators.length ? operators[index] : { name: '(' + index.toString() + ')' };
        this._nodes.push(new tflite.Node(metadata, node, operator, i.toString(), args));
      }

      // Copies description and content denotation from tensor metadata onto an argument.
      const applyTensorMetadata = (argument, tensorMetadata) => {
        // An input/output slot with index -1 has no argument to annotate.
        if (!argument || !tensorMetadata) {
          return;
        }
        const description = tensorMetadata.description;
        if (description) {
          argument.description = description;
        }
        const content = tensorMetadata.content;
        if (!argument.type || !content) {
          return;
        }
        let denotation = null;
        const contentProperties = content.content_properties;
        if (contentProperties instanceof tflite.schema.FeatureProperties) {
          denotation = 'Feature';
        }
        else if (contentProperties instanceof tflite.schema.ImageProperties) {
          denotation = 'Image';
          switch (contentProperties.color_space) {
            case 0: denotation += '(Unknown)'; break;
            case 1: denotation += '(RGB)'; break;
            case 2: denotation += '(Grayscale)'; break;
            default: throw new tflite.Error("Unsupported image color space '" + contentProperties.color_space + "'.");
          }
        }
        else if (contentProperties instanceof tflite.schema.BoundingBoxProperties) {
          denotation = 'BoundingBox';
        }
        else if (contentProperties instanceof tflite.schema.AudioProperties) {
          denotation = 'Audio(' + contentProperties.sample_rate.toString() + ',' + contentProperties.channels.toString() + ')';
        }
        if (denotation) {
          argument.type.denotation = denotation;
        }
      };

      // Shared by inputs and outputs: resolve each tensor index, annotate it with
      // the metadata at the same position (if any) and append a parameter.
      const connect = (indices, metadataList, target) => {
        for (let i = 0; i < indices.length; i++) {
          const argument = args(indices[i]);
          if (metadataList && i < metadataList.length) {
            applyTensorMetadata(argument, metadataList[i]);
          }
          target.push(new tflite.Parameter(argument ? argument.name : '?', true, argument ? [ argument ] : []));
        }
      };
      connect(subgraph.inputs, subgraphMetadata ? subgraphMetadata.input_tensor_metadata : null, this._inputs);
      connect(subgraph.outputs, subgraphMetadata ? subgraphMetadata.output_tensor_metadata : null, this._outputs);
    }

    get name() {
      return this._name;
    }

    get inputs() {
      return this._inputs;
    }

    get outputs() {
      return this._outputs;
    }

    get nodes() {
      return this._nodes;
    }
  };
  /**
   * A single operator instance: its type, input/output parameters and attributes.
   */
  tflite.Node = class {
    /**
     * @param {object} metadata - Lookup offering type(name) and attribute(typeName, attributeName).
     * @param {object|null} node - Operator record, or null for a bare node
     *   (used here for the chained activation node).
     * @param {object} type - Operator descriptor; name and custom are read.
     * @param {string|null} location - Location label stored as-is.
     * @param {function} args - Maps a tensor index to an argument (or a falsy value).
     * @throws {tflite.Error} When fused_activation_function has an unknown non-zero value.
     */
    constructor(metadata, node, type, location, args) {
      this._location = location;
      this._type = type.custom ? { name: type.name, category: 'custom' } : metadata.type(type.name);
      this._inputs = [];
      this._outputs = [];
      this._attributes = [];
      if (!node) {
        return;
      }

      const inputs = Array.from(node.inputs || new Int32Array(0));
      const outputs = Array.from(node.outputs || new Int32Array(0));
      const declaredInputs = this._type && this._type.inputs ? this._type.inputs : null;
      const declaredOutputs = this._type && this._type.outputs ? this._type.outputs : null;

      // Group input tensor indices into parameters; a declared 'list' input
      // absorbs every remaining index.
      let position = 0;
      while (position < inputs.length) {
        let span = 1;
        let label = null;
        let visible = true;
        if (declaredInputs && position < declaredInputs.length) {
          const declared = declaredInputs[position];
          label = declared.name;
          if (declared.list) {
            span = inputs.length - position;
          }
          if (Object.prototype.hasOwnProperty.call(declared, 'visible') && !declared.visible) {
            visible = false;
          }
        }
        const bound = inputs
          .slice(position, position + span)
          .map((index) => args(index))
          .filter((argument) => argument);
        position += span;
        // Unnamed inputs are labelled with the index reached after this group.
        this._inputs.push(new tflite.Parameter(label ? label : position.toString(), visible, bound));
      }

      // One parameter per output; the declared name wins over the positional one.
      outputs.forEach((index, k) => {
        const argument = args(index);
        const declared = declaredOutputs && k < declaredOutputs.length ? declaredOutputs[k] : null;
        const label = declared && declared.name ? declared.name : k.toString();
        this._outputs.push(new tflite.Parameter(label, true, argument ? [ argument ] : []));
      });

      if (type.custom && node.custom_options.length > 0) {
        // Tries to decode FlexBuffers options into attributes; reports success.
        const decodeCustomOptions = () => {
          if (node.custom_options_format !== tflite.schema.CustomOptionsFormat.FLEXBUFFERS) {
            return false;
          }
          try {
            const reader = flexbuffers.BinaryReader.open(node.custom_options);
            if (!reader) {
              return false;
            }
            const custom_options = reader.read();
            if (Array.isArray(custom_options)) {
              this._attributes.push(new tflite.Attribute(null, 'custom_options', custom_options));
              return true;
            }
            if (custom_options) {
              for (const [key, value] of Object.entries(custom_options)) {
                const schema = metadata.attribute(type.name, key);
                this._attributes.push(new tflite.Attribute(schema, key, value));
              }
              return true;
            }
            return false;
          }
          catch (err) {
            // continue regardless of error
            return false;
          }
        };
        if (!decodeCustomOptions()) {
          // Fall back to exposing the raw option bytes.
          const schema = metadata.attribute(type.name, 'custom');
          this._attributes.push(new tflite.Attribute(schema, 'custom', Array.from(node.custom_options)));
        }
      }

      const options = node.builtin_options;
      if (options) {
        for (const [name, value] of Object.entries(options)) {
          if (name === 'fused_activation_function' && value) {
            const activations = { 1: 'Relu', 2: 'ReluN1To1', 3: 'Relu6', 4: 'Tanh', 5: 'SignBit' };
            const activation = activations[value];
            if (!activation) {
              throw new tflite.Error("Unsupported activation funtion index '" + JSON.stringify(value) + "'.");
            }
            // The fused activation is surfaced as a chained node.
            this._chain = [ new tflite.Node(metadata, null, { name: activation }, null, []) ];
          }
          const schema = metadata.attribute(type.name, name);
          this._attributes.push(new tflite.Attribute(schema, name, value));
        }
      }
    }

    get type() {
      return this._type;
    }

    get name() {
      return '';
    }

    get location() {
      return this._location;
    }

    get inputs() {
      return this._inputs;
    }

    get outputs() {
      return this._outputs;
    }

    get chain() {
      return this._chain;
    }

    get attributes() {
      return this._attributes;
    }
  };
  /**
   * A named operator attribute with an optional type and a visibility flag.
   */
  tflite.Attribute = class {
    /**
     * @param {object|null} metadata - Optional attribute schema; type, visible and default are read.
     * @param {string} name - Attribute name.
     * @param {*} value - Raw value; typed-array views are copied into plain arrays.
     */
    constructor(metadata, name, value) {
      const owns = (key) => Object.prototype.hasOwnProperty.call(metadata, key);

      this._name = name;
      this._value = ArrayBuffer.isView(value) ? Array.from(value) : value;
      this._type = metadata && metadata.type ? metadata.type : null;

      // The fused activation is always hidden as an attribute.
      let hidden = this._name == 'fused_activation_function';

      if (this._type) {
        // Replace enum numbers by their symbolic names when the type is an enum.
        this._value = tflite.Utility.enum(this._type, this._value);
      }

      if (metadata) {
        if (owns('visible') && !metadata.visible) {
          hidden = true;
        }
        else if (owns('default')) {
          // Attributes that still carry their default value are hidden.
          let current = this._value;
          if (typeof current == 'function') {
            current = current();
          }
          if (current == metadata.default) {
            hidden = true;
          }
        }
      }

      if (hidden) {
        this._visible = false;
      }
    }

    get name() {
      return this._name;
    }

    get type() {
      return this._type;
    }

    get value() {
      return this._value;
    }

    get visible() {
      return !(this._visible == false);
    }
  };
  /**
   * A named input or output slot holding zero or more arguments.
   */
  tflite.Parameter = class {
    /**
     * @param {string} name - Slot name.
     * @param {boolean} visible - Visibility flag, stored as given.
     * @param {Array<object>} args - Arguments bound to this slot.
     */
    constructor(name, visible, args) {
      this._name = name;
      this._visible = visible;
      this._arguments = args;
    }

    /** Slot name. */
    get name() {
      return this._name;
    }

    /** Visibility flag passed to the constructor. */
    get visible() {
      return this._visible;
    }

    /** Arguments bound to this slot. */
    get arguments() {
      return this._arguments;
    }
  };
  /**
   * A tensor reference used by graph and node parameters, including a
   * readable rendering of its quantization parameters.
   */
  tflite.Argument = class {
    /**
     * @param {number} index - Tensor index; used for the location and as a name suffix.
     * @param {object} tensor - Tensor record; name, type, shape and quantization are read.
     * @param {object|null} initializer - Initializer tensor, or null.
     */
    constructor(index, tensor, initializer) {
      const location = index.toString();
      // The index is appended after a newline so that names stay unique per tensor.
      this._name = (tensor.name || '') + '\n' + location;
      this._location = location;
      this._type = tensor.type !== undefined && tensor.shape !== undefined ? new tflite.TensorType(tensor) : null;
      this._initializer = initializer;

      const quantization = tensor.quantization;
      if (!quantization) {
        return;
      }
      const scales = quantization.scale;
      const zeroPoints = quantization.zero_point;
      const minimums = quantization.min;
      const maximums = quantization.max;
      const length = Math.max(scales.length, zeroPoints.length, minimums.length, maximums.length);

      const formulas = [];
      for (let i = 0; i < length; i++) {
        const scale = i < scales.length ? scales[i] : 0;
        const zeroPoint = (i < zeroPoints.length ? zeroPoints[i] : 0).toString();
        let formula = 'q';
        if (scale !== 0 || zeroPoint !== '0') {
          let shifted = 'q';
          if (zeroPoint !== '0') {
            // A negative zero point is rendered as an addition.
            const offset = zeroPoint.startsWith('-') ? ' + ' + zeroPoint.substring(1) : ' - ' + zeroPoint;
            shifted = '(q' + offset + ')';
          }
          formula = scale.toString() + ' * ' + shifted;
        }
        if (i < minimums.length) {
          formula = minimums[i].toString() + ' \u2264 ' + formula;
        }
        if (i < maximums.length) {
          formula = formula + ' \u2264 ' + maximums[i].toString();
        }
        formulas.push(formula);
      }

      // Entries that are all the plain 'q' carry no information and are dropped.
      if (formulas.some((formula) => formula !== 'q')) {
        this._quantization = formulas.length === 1 ? formulas[0] : formulas;
      }
    }

    get name() {
      return this._name;
    }

    get location() {
      return this._location;
    }

    get type() {
      return this._type;
    }

    get quantization() {
      return this._quantization;
    }

    set description(value) {
      this._description = value;
    }

    get description() {
      return this._description;
    }

    get initializer() {
      return this._initializer;
    }
  };
  /**
   * Tensor contents (constant data or a variable) attached to an argument.
   */
  tflite.Tensor = class {
    /**
     * @param {number} index - Tensor index, stored as the location string.
     * @param {object} tensor - Tensor record; name is read and the record is passed to tflite.TensorType.
     * @param {object|null|undefined} buffer - Buffer record whose data bytes are copied.
     *   A missing buffer or missing data yields an empty byte array.
     * @param {boolean} is_variable - Whether the tensor is a variable.
     */
    constructor(index, tensor, buffer, is_variable) {
      this._location = index.toString();
      this._type = new tflite.TensorType(tensor);
      this._is_variable = is_variable;
      this._name = tensor.name;
      // Variable tensors may arrive without a backing buffer; keep them usable.
      const data = buffer && buffer.data ? buffer.data : null;
      this._data = data ? data.slice(0) : new Uint8Array(0);
    }

    get category() {
      return this._is_variable ? 'Variable' : '';
    }

    get name() {
      return this._name;
    }

    get location() {
      return this._location;
    }

    get type() {
      return this._type;
    }

    get layout() {
      switch (this._type.dataType) {
        case 'string': return '|';
        default: return '<';
      }
    }

    /**
     * Decoded values: an array of strings for 'string' tensors, otherwise the
     * copied bytes as stored.
     *
     * String data is read as a little-endian int32 count, followed by `count`
     * int32 start offsets, followed by the UTF-8 bytes. Each string ends where
     * the next one starts, and the last one ends at the end of the data.
     */
    get values() {
      switch (this._type.dataType) {
        case 'string': {
          // No bytes at all means no strings, rather than a RangeError below.
          if (this._data.length === 0) {
            return [];
          }
          let offset = 0;
          const data = new DataView(this._data.buffer, this._data.byteOffset, this._data.byteLength);
          const count = data.getInt32(0, true);
          offset += 4;
          const offsetTable = [];
          for (let j = 0; j < count; j++) {
            offsetTable.push(data.getInt32(offset, true));
            offset += 4;
          }
          offsetTable.push(this._data.length);
          const stringTable = [];
          const utf8Decoder = new TextDecoder('utf-8');
          for (let k = 0; k < count; k++) {
            const textArray = this._data.subarray(offsetTable[k], offsetTable[k + 1]);
            stringTable.push(utf8Decoder.decode(textArray));
          }
          return stringTable;
        }
        default: {
          return this._data;
        }
      }
    }
  };
  /**
   * Element type and shape of a tensor, with an optional denotation label.
   */
  tflite.TensorType = class {
    /**
     * @param {object} tensor - Tensor record; type and shape are read.
     */
    constructor(tensor) {
      this._dataType = tflite.Utility.dataType(tensor.type);
      // A missing shape is treated as an empty dimension list.
      const dimensions = Array.from(tensor.shape || []);
      this._shape = new tflite.TensorShape(dimensions);
    }

    get dataType() {
      return this._dataType;
    }

    get shape() {
      return this._shape;
    }

    set denotation(value) {
      this._denotation = value;
    }

    get denotation() {
      return this._denotation;
    }

    /** Data type name immediately followed by the shape text. */
    toString() {
      const shapeText = this._shape.toString();
      return this.dataType + shapeText;
    }
  };
  /**
   * Ordered list of tensor dimensions.
   */
  tflite.TensorShape = class {
    /**
     * @param {Array} dimensions - Dimension values, stored as given.
     */
    constructor(dimensions) {
      this._dimensions = dimensions;
    }

    get dimensions() {
      return this._dimensions;
    }

    /** '[d0,d1,...]' for a non-empty shape, otherwise the empty string. */
    toString() {
      const dimensions = this._dimensions;
      if (dimensions && dimensions.length != 0) {
        const parts = dimensions.map((dimension) => dimension.toString());
        return '[' + parts.join(',') + ']';
      }
      return '';
    }
  };
  /**
   * Lookup helpers backed by the loaded schema (tflite.schema).
   */
  tflite.Utility = class {
    /**
     * Maps a TensorType enum value to its lower-case name.
     *
     * @param {number} type - Enum value from tflite.schema.TensorType.
     * @returns {string} Lower-case type name, 'boolean' for value 6, or '?' when unknown.
     */
    static dataType(type) {
      let names = tflite.Utility._tensorTypeMap;
      if (!names) {
        // Built once on first use and cached on the class.
        const pairs = Object.keys(tflite.schema.TensorType).map((key) => [ tflite.schema.TensorType[key], key.toLowerCase() ]);
        names = new Map(pairs);
        names.set(6, 'boolean');
        tflite.Utility._tensorTypeMap = names;
      }
      return names.has(type) ? names.get(type) : '?';
    }

    /**
     * Translates an enum value to its symbolic name.
     *
     * @param {string} name - Name of a member of tflite.schema expected to be an enum.
     * @param {*} value - Value to translate.
     * @returns {*} The matching key when one exists, otherwise the value unchanged.
     */
    static enum(name, value) {
      const type = name && tflite.schema ? tflite.schema[name] : undefined;
      if (!type) {
        return value;
      }
      // Reverse (value -> key) maps are cached per enum name.
      tflite.Utility._enums = tflite.Utility._enums || new Map();
      const cache = tflite.Utility._enums;
      if (!cache.has(name)) {
        const reverse = new Map(Object.keys(type).map((key) => [ type[key], key ]));
        cache.set(name, reverse);
      }
      const lookup = cache.get(name);
      return lookup.has(value) ? lookup.get(value) : value;
    }
  };
  /**
   * Error type thrown by this loader; its name is a fixed, human-readable caption.
   */
  tflite.Error = class extends Error {
    /**
     * @param {string} message - Description of what went wrong.
     */
    constructor(message) {
      super(message);
      this.name = 'Error loading TensorFlow Lite model.';
    }
  };
  // Publish the factory when a CommonJS-style `module.exports` object is available.
  if (typeof module !== 'undefined' && typeof module.exports === 'object') {
    module.exports.ModelFactory = tflite.ModelFactory;
  }