tflite.js 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554
  1. import * as flatbuffers from './flatbuffers.js';
  2. import * as flexbuffers from './flexbuffers.js';
  3. import * as zip from './zip.js';
// Namespace object for this module; the loader classes below (tflite.ModelFactory,
// tflite.Model, tflite.Graph, ...) are attached to it as properties.
const tflite = {};
  5. tflite.ModelFactory = class {
  6. async match(context) {
  7. const reader = await context.peek('flatbuffers.binary');
  8. if (reader && reader.identifier === 'TFL3') {
  9. return context.set('tflite.flatbuffers', reader);
  10. }
  11. const identifier = context.identifier;
  12. const extension = identifier.lastIndexOf('.') > 0 ? identifier.split('.').pop().toLowerCase() : '';
  13. if (extension === 'tflite' && reader && reader.identifier === '') {
  14. const version = reader.uint32_(reader.root, 4, 0);
  15. if (version === 3) {
  16. return context.set('tflite.flatbuffers', reader);
  17. }
  18. }
  19. const obj = await context.peek('json');
  20. if (obj && obj.subgraphs && obj.operator_codes) {
  21. return context.set('tflite.flatbuffers.json', obj);
  22. }
  23. return null;
  24. }
  25. async open(context) {
  26. tflite.schema = await context.require('./tflite-schema');
  27. tflite.schema = tflite.schema.tflite;
  28. let model = null;
  29. const attachments = new Map();
  30. switch (context.type) {
  31. case 'tflite.flatbuffers.json': {
  32. try {
  33. const reader = await context.read('flatbuffers.text');
  34. model = tflite.schema.Model.createText(reader);
  35. } catch (error) {
  36. const message = error && error.message ? error.message : error.toString();
  37. throw new tflite.Error(`File text format is not tflite.Model (${message.replace(/\.$/, '')}).`);
  38. }
  39. break;
  40. }
  41. case 'tflite.flatbuffers': {
  42. try {
  43. const reader = context.value;
  44. model = tflite.schema.Model.create(reader);
  45. } catch (error) {
  46. const message = error && error.message ? error.message : error.toString();
  47. throw new tflite.Error(`File format is not tflite.Model (${message.replace(/\.$/, '')}).`);
  48. }
  49. try {
  50. const stream = context.stream;
  51. const archive = zip.Archive.open(stream);
  52. if (archive) {
  53. for (const [name, value] of archive.entries) {
  54. attachments.set(name, value);
  55. }
  56. }
  57. } catch {
  58. // continue regardless of error
  59. }
  60. break;
  61. }
  62. default: {
  63. throw new tflite.Error(`Unsupported TensorFlow Lite format '${context.type}'.`);
  64. }
  65. }
  66. const stream = context.stream;
  67. const metadata = await context.metadata('tflite-metadata.json');
  68. return new tflite.Model(metadata, model, stream);
  69. }
  70. };
  71. tflite.Model = class {
  72. constructor(metadata, model, stream) {
  73. this.graphs = [];
  74. this.format = 'TensorFlow Lite';
  75. this.format = `${this.format} v${model.version}`;
  76. this.description = model.description || '';
  77. this.metadata = [];
  78. const builtinOperators = new Map();
  79. const upperCase = new Set(['2D', 'LSH', 'SVDF', 'RNN', 'L2', 'LSTM']);
  80. for (const key of Object.keys(tflite.schema.BuiltinOperator)) {
  81. const value = key === 'BATCH_MATMUL' ? 'BATCH_MAT_MUL' : key;
  82. const name = value.split('_').map((s) => (s.length < 1 || upperCase.has(s)) ? s : s[0] + s.substring(1).toLowerCase()).join('');
  83. const index = tflite.schema.BuiltinOperator[key];
  84. builtinOperators.set(index, name);
  85. }
  86. const operators = model.operator_codes.map((operator) => {
  87. const code = Math.max(operator.deprecated_builtin_code, operator.builtin_code || 0);
  88. const value = {};
  89. if (code === tflite.schema.BuiltinOperator.CUSTOM) {
  90. value.name = operator.custom_code ? operator.custom_code : 'Custom';
  91. value.version = operator.version;
  92. value.custom = true;
  93. } else {
  94. value.name = builtinOperators.has(code) ? builtinOperators.get(code) : code.toString();
  95. value.version = operator.version;
  96. value.custom = false;
  97. }
  98. return value;
  99. });
  100. let modelMetadata = null;
  101. for (const metadata of model.metadata) {
  102. const buffer = model.buffers[metadata.buffer];
  103. let data = null;
  104. const position = stream.position;
  105. if (buffer && buffer.data && buffer.data.length > 0) {
  106. data = buffer.data;
  107. } else if (buffer && buffer.offset !== 0n && buffer.size !== 0n) {
  108. const offset = buffer.offset.toNumber();
  109. const size = buffer.size.toNumber();
  110. stream.seek(offset);
  111. data = stream.read(size);
  112. }
  113. stream.seek(position);
  114. if (data) {
  115. switch (metadata.name) {
  116. case 'min_runtime_version': {
  117. const decoder = new TextDecoder('utf-8');
  118. this.runtime = decoder.decode(data);
  119. break;
  120. }
  121. case 'TFLITE_METADATA': {
  122. const reader = flatbuffers.BinaryReader.open(data);
  123. if (!reader || !tflite.schema.ModelMetadata.identifier(reader)) {
  124. throw new tflite.Error('Invalid TensorFlow Lite metadata.');
  125. }
  126. modelMetadata = tflite.schema.ModelMetadata.create(reader);
  127. if (modelMetadata.name) {
  128. this.name = modelMetadata.name;
  129. }
  130. if (modelMetadata.version) {
  131. this.version = modelMetadata.version;
  132. }
  133. if (modelMetadata.description) {
  134. this.description = this.description ? [this.description, modelMetadata.description].join(' ') : modelMetadata.description;
  135. }
  136. if (modelMetadata.author) {
  137. this.metadata.push(new tflite.Argument('author', modelMetadata.author));
  138. }
  139. if (modelMetadata.license) {
  140. this.metadata.push(new tflite.Argument('license', modelMetadata.license));
  141. }
  142. break;
  143. }
  144. default: {
  145. const value = data.length < 256 && data.every((c) => c >= 32 && c < 128) ? String.fromCharCode.apply(null, data) : '?';
  146. const argument = new tflite.Argument(metadata.name, value);
  147. this.metadata.push(argument);
  148. break;
  149. }
  150. }
  151. }
  152. }
  153. const subgraphs = model.subgraphs;
  154. const subgraphsMetadata = modelMetadata ? modelMetadata.subgraph_metadata : null;
  155. for (let i = 0; i < subgraphs.length; i++) {
  156. const subgraph = subgraphs[i];
  157. const name = subgraphs.length > 1 ? i.toString() : '';
  158. const subgraphMetadata = subgraphsMetadata && i < subgraphsMetadata.length ? subgraphsMetadata[i] : null;
  159. const signatures = model.signature_defs.filter((signature) => signature.subgraph_index === i);
  160. const graph = new tflite.Graph(metadata, subgraph, signatures, subgraphMetadata, name, operators, model, stream);
  161. this.graphs.push(graph);
  162. }
  163. }
  164. };
  165. tflite.Graph = class {
  166. constructor(metadata, subgraph, signatures, subgraphMetadata, name, operators, model, stream) {
  167. this.name = subgraph.name || name;
  168. if (subgraph.operators.length === 0 && subgraph.tensors.length > 0 && operators.length === 0) {
  169. operators.push({ name: 'Weights', custom: true });
  170. const layers = new Map();
  171. for (let i = 0; i < subgraph.tensors.length; i++) {
  172. const tensor = subgraph.tensors[i];
  173. const parts = tensor.name.split('.');
  174. parts.pop();
  175. const key = parts.join('.');
  176. if (!layers.has(key)) {
  177. const operator = { opcode_index: 0, inputs: [], outputs: [] };
  178. layers.set(key, operator);
  179. subgraph.operators.push(operator);
  180. }
  181. const operator = layers.get(key);
  182. operator.inputs.push(i);
  183. }
  184. }
  185. const tensors = new Map();
  186. tensors.map = (index, metadata) => {
  187. if (index === -1) {
  188. return null;
  189. }
  190. if (!tensors.has(index)) {
  191. let tensor = { name: '' };
  192. let initializer = null;
  193. let description = '';
  194. let denotation = '';
  195. if (index < subgraph.tensors.length) {
  196. tensor = subgraph.tensors[index];
  197. const buffer = model.buffers[tensor.buffer];
  198. const is_variable = tensor.is_variable;
  199. const variable = is_variable || (buffer && buffer.data && buffer.data.length > 0) || (buffer && buffer.offset !== 0n && buffer.size !== 0n);
  200. initializer = variable ? new tflite.Tensor(index, tensor, buffer, stream, is_variable) : null;
  201. }
  202. if (metadata) {
  203. description = metadata.description;
  204. const content = metadata.content;
  205. if (content) {
  206. const contentProperties = content.content_properties;
  207. if (contentProperties instanceof tflite.schema.FeatureProperties) {
  208. denotation = 'Feature';
  209. } else if (contentProperties instanceof tflite.schema.ImageProperties) {
  210. denotation = 'Image';
  211. switch (contentProperties.color_space) {
  212. case 0: denotation += '(Unknown)'; break;
  213. case 1: denotation += '(RGB)'; break;
  214. case 2: denotation += '(Grayscale)'; break;
  215. default: throw tflite.Error(`Unsupported image color space '${contentProperties.color_space}'.`);
  216. }
  217. } else if (contentProperties instanceof tflite.schema.BoundingBoxProperties) {
  218. denotation = 'BoundingBox';
  219. } else if (contentProperties instanceof tflite.schema.AudioProperties) {
  220. denotation = `Audio(${contentProperties.sample_rate},${contentProperties.channels})`;
  221. }
  222. }
  223. }
  224. const value = new tflite.Value(index, tensor, initializer, description, denotation);
  225. tensors.set(index, value);
  226. }
  227. return tensors.get(index);
  228. };
  229. this.inputs = Array.from(subgraph.inputs).map((tensor_index, index) => {
  230. const metadata = subgraphMetadata && index < subgraphMetadata.input_tensor_metadata.length ? subgraphMetadata.input_tensor_metadata[index] : null;
  231. const value = tensors.map(tensor_index, metadata);
  232. const values = value ? [value] : [];
  233. const name = value ? value.name.split('\n')[0] : '?';
  234. return new tflite.Argument(name, values);
  235. });
  236. this.outputs = Array.from(subgraph.outputs).map((tensor_index, index) => {
  237. const metadata = subgraphMetadata && index < subgraphMetadata.output_tensor_metadata.length ? subgraphMetadata.output_tensor_metadata[index] : null;
  238. const value = tensors.map(tensor_index, metadata);
  239. const values = value ? [value] : [];
  240. const name = value ? value.name.split('\n')[0] : '?';
  241. return new tflite.Argument(name, values);
  242. });
  243. this.signatures = signatures.map((signature) => {
  244. return new tflite.Signature(signature, tensors);
  245. });
  246. this.nodes = Array.from(subgraph.operators).map((operator, index) => {
  247. const opcode_index = operator.opcode_index;
  248. const opcode = opcode_index < operators.length ? operators[opcode_index] : { name: `(${opcode_index})` };
  249. return new tflite.Node(metadata, operator, opcode, index.toString(), tensors);
  250. });
  251. }
  252. };
  253. tflite.Signature = class {
  254. constructor(signature, tensors) {
  255. this.name = signature.signature_key;
  256. this.inputs = signature.inputs.map((input) => {
  257. const value = tensors.map(input.tensor_index);
  258. const values = value ? [value] : [];
  259. return new tflite.Argument(input.name, values);
  260. });
  261. this.outputs = signature.outputs.map((output) => {
  262. const value = tensors.map(output.tensor_index);
  263. const values = value ? [value] : [];
  264. return new tflite.Argument(output.name, values);
  265. });
  266. }
  267. };
/**
 * A single operator (node) of a graph, or a synthetic activation node when `node` is null.
 *
 * @param metadata - Operator metadata provider; `type(name)` and `attribute(typeName, attrName)` are called on it.
 * @param node - Decoded `tflite.schema.Operator`, or null for a synthetic node.
 * @param type - Opcode descriptor (`name`, `custom`, optionally `version`).
 * @param identifier - Node identifier string, or null for synthetic nodes.
 * @param tensors - Value map of the owning graph; `tensors.map(index)` resolves a tensor index.
 * @throws {tflite.Error} If `fused_activation_function` is outside 1..5.
 */
tflite.Node = class {
    constructor(metadata, node, type, identifier, tensors) {
        this.name = '';
        this.identifier = identifier;
        // Custom operators get a bare type; builtin operators are looked up in the metadata.
        this.type = type.custom ? { name: type.name } : metadata.type(type.name);
        this.inputs = [];
        this.outputs = [];
        this.attributes = [];
        if (node) {
            // Collected as [attributeMetadata, name, value] triples and converted at the end.
            const attributes = [];
            const inputs = Array.from(node.inputs || new Int32Array(0));
            // Group tensor indices into arguments. An input declared as `list` in the metadata
            // consumes all remaining indices; otherwise each index is its own argument.
            for (let i = 0; i < inputs.length;) {
                let count = 1;
                let name = null;
                let visible = true;
                const values = [];
                if (this.type && this.type.inputs && i < this.type.inputs.length) {
                    const input = this.type.inputs[i];
                    name = input.name;
                    if (input.list) {
                        count = inputs.length - i;
                    }
                    if (input.visible === false) {
                        visible = false;
                    }
                }
                // tensors.map returns null for -1 (omitted optional input); those are skipped.
                for (const index of inputs.slice(i, i + count)) {
                    const value = tensors.map(index);
                    if (value) {
                        values.push(value);
                    }
                }
                // Fall back to a 1-based positional name when the metadata has none.
                name = name ? name : (i + 1).toString();
                i += count;
                const argument = new tflite.Argument(name, values, null, visible);
                this.inputs.push(argument);
            }
            const outputs = Array.from(node.outputs || new Int32Array(0));
            for (let i = 0; i < outputs.length; i++) {
                const index = outputs[i];
                const value = tensors.map(index);
                const values = value ? [value] : [];
                let name = (i + 1).toString();
                if (this.type && this.type.outputs && i < this.type.outputs.length) {
                    const output = this.type.outputs[i];
                    if (output && output.name) {
                        name = output.name;
                    }
                }
                const argument = new tflite.Argument(name, values);
                this.outputs.push(argument);
            }
            // Custom options: try to decode FlexBuffers; otherwise keep the raw bytes as 'custom'.
            if (type.custom && node.custom_options && node.custom_options.length > 0) {
                let decoded = false;
                if (node.custom_options_format === tflite.schema.CustomOptionsFormat.FLEXBUFFERS) {
                    try {
                        const reader = flexbuffers.BinaryReader.open(node.custom_options);
                        if (reader) {
                            const custom_options = reader.read();
                            if (Array.isArray(custom_options)) {
                                attributes.push([null, 'custom_options', custom_options]);
                                decoded = true;
                            } else if (custom_options) {
                                for (const [key, value] of Object.entries(custom_options)) {
                                    const schema = metadata.attribute(type.name, key);
                                    attributes.push([schema, key, value]);
                                }
                                decoded = true;
                            }
                        }
                    } catch {
                        // continue regardless of error
                    }
                }
                if (!decoded) {
                    const schema = metadata.attribute(type.name, 'custom');
                    attributes.push([schema, 'custom', Array.from(node.custom_options)]);
                }
            }
            const options = node.builtin_options;
            if (options) {
                for (const [name, value] of Object.entries(options)) {
                    // A non-zero fused activation is shown as a chained synthetic node.
                    if (name === 'fused_activation_function' && value) {
                        if (value < 1 || value > 5) {
                            throw new tflite.Error(`Unsupported activation function index '${value}'.`);
                        }
                        const list = ['Unknown', 'Relu', 'ReluN1To1', 'Relu6', 'Tanh', 'SignBit'];
                        // Block-scoped: shadows the `type` parameter only inside this if-block.
                        const type = list[value];
                        const node = new tflite.Node(metadata, null, { name: type }, null, []);
                        this.chain = [node];
                    }
                    // Here `type` is again the constructor parameter (opcode descriptor).
                    const schema = metadata.attribute(type.name, name);
                    attributes.push([schema, name, value]);
                }
            }
            // Convert triples to Arguments: typed arrays become plain arrays, enum-typed values are
            // mapped to their names, and values equal to the metadata default are hidden.
            this.attributes = attributes.map(([metadata, name, value]) => {
                const type = metadata && metadata.type ? metadata.type : null;
                value = ArrayBuffer.isView(value) ? Array.from(value) : value;
                let visible = true;
                // Hidden because it is represented by this.chain instead.
                if (name === 'fused_activation_function') {
                    visible = false;
                }
                if (type) {
                    value = tflite.Utility.enum(type, value);
                }
                if (metadata) {
                    if (metadata.visible === false) {
                        visible = false;
                    } else if (metadata.default !== undefined) {
                        if (typeof value === 'function') {
                            value = value();
                        }
                        if (value === metadata.default) {
                            visible = false;
                        }
                    }
                }
                return new tflite.Argument(name, value, type, visible);
            });
        }
    }
};
  390. tflite.Argument = class {
  391. constructor(name, value, type, visible) {
  392. this.name = name;
  393. this.value = value;
  394. this.type = type || null;
  395. this.visible = visible !== false;
  396. }
  397. };
  398. tflite.Value = class {
  399. constructor(index, tensor, initializer, description, denotation) {
  400. const name = tensor.name || '';
  401. this.name = `${name}\n${index}`;
  402. this.identifier = index.toString();
  403. this.type = tensor.type !== undefined && tensor.shape !== undefined ? new tflite.TensorType(tensor, denotation) : null;
  404. this.initializer = initializer;
  405. this.description = description;
  406. const quantization = tensor.quantization;
  407. if (quantization && (quantization.scale.length > 0 || quantization.zero_point.length > 0 || quantization.min.length > 0 || quantization.max.length)) {
  408. this.quantization = {
  409. type: 'linear',
  410. dimension: quantization.quantized_dimension,
  411. scale: quantization.scale,
  412. offset: quantization.zero_point,
  413. min: quantization.min,
  414. max: quantization.max
  415. };
  416. }
  417. }
  418. };
  419. tflite.Tensor = class {
  420. constructor(index, tensor, buffer, stream, is_variable) {
  421. this.identifier = index.toString();
  422. this.name = tensor.name;
  423. this.type = new tflite.TensorType(tensor);
  424. this.category = is_variable ? 'Variable' : '';
  425. this.encoding = this.type.dataType === 'string' ? '|' : '<';
  426. if (buffer && buffer.data && buffer.data.length > 0) {
  427. this._data = buffer.data.slice(0);
  428. } else if (buffer && buffer.offset !== 0n && buffer.size !== 0n) {
  429. const offset = buffer.offset.toNumber();
  430. const size = buffer.size.toNumber();
  431. stream.seek(offset);
  432. this._data = stream.stream(size);
  433. } else {
  434. this._data = null;
  435. }
  436. }
  437. get values() {
  438. switch (this.type.dataType) {
  439. case 'string': {
  440. let offset = 0;
  441. const data = new DataView(this._data.buffer, this._data.byteOffset, this._data.byteLength);
  442. const count = data.getInt32(0, true);
  443. offset += 4;
  444. const offsetTable = [];
  445. for (let j = 0; j < count; j++) {
  446. offsetTable.push(data.getInt32(offset, true));
  447. offset += 4;
  448. }
  449. offsetTable.push(this._data.length);
  450. const stringTable = [];
  451. const utf8Decoder = new TextDecoder('utf-8');
  452. for (let k = 0; k < count; k++) {
  453. const textArray = this._data.subarray(offsetTable[k], offsetTable[k + 1]);
  454. stringTable.push(utf8Decoder.decode(textArray));
  455. }
  456. return stringTable;
  457. }
  458. default: {
  459. if (this._data instanceof Uint8Array) {
  460. return this._data;
  461. }
  462. if (this._data && this._data.peek) {
  463. return this._data.peek();
  464. }
  465. return null;
  466. }
  467. }
  468. }
  469. };
  470. tflite.TensorType = class {
  471. constructor(tensor, denotation) {
  472. const shape = tensor.shape_signature && tensor.shape_signature.length > 0 ? tensor.shape_signature : tensor.shape;
  473. this.dataType = tflite.Utility.dataType(tensor.type);
  474. this.shape = new tflite.TensorShape(Array.from(shape || []));
  475. this.denotation = denotation;
  476. }
  477. toString() {
  478. return this.dataType + this.shape.toString();
  479. }
  480. };
  481. tflite.TensorShape = class {
  482. constructor(dimensions) {
  483. this.dimensions = dimensions;
  484. }
  485. toString() {
  486. if (!this.dimensions || this.dimensions.length === 0) {
  487. return '';
  488. }
  489. return `[${this.dimensions.map((dimension) => dimension.toString()).join(',')}]`;
  490. }
  491. };
  492. tflite.Utility = class {
  493. static dataType(type) {
  494. if (!tflite.Utility._tensorTypes) {
  495. tflite.Utility._tensorTypes = new Map(Object.entries(tflite.schema.TensorType).map(([key, value]) => [value, key.toLowerCase()]));
  496. tflite.Utility._tensorTypes.set(6, 'boolean');
  497. }
  498. return tflite.Utility._tensorTypes.has(type) ? tflite.Utility._tensorTypes.get(type) : '?';
  499. }
  500. static enum(name, value) {
  501. const type = name && tflite.schema ? tflite.schema[name] : undefined;
  502. if (type) {
  503. tflite.Utility._enums = tflite.Utility._enums || new Map();
  504. if (!tflite.Utility._enums.has(name)) {
  505. const entries = new Map(Object.entries(type).map(([key, value]) => [value, key]));
  506. tflite.Utility._enums.set(name, entries);
  507. }
  508. const map = tflite.Utility._enums.get(name);
  509. if (map.has(value)) {
  510. return map.get(value);
  511. }
  512. }
  513. return value;
  514. }
  515. };
  516. tflite.Error = class extends Error {
  517. constructor(message) {
  518. super(message);
  519. this.name = 'Error loading TensorFlow Lite model.';
  520. }
  521. };
// Exposes the model factory (detection via match(), loading via open()) as a named export.
export const ModelFactory = tflite.ModelFactory;