circle.js 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535
  1. import * as flatbuffers from './flatbuffers.js';
  2. import * as flexbuffers from './flexbuffers.js';
  3. import * as zip from './zip.js';
  const circle = {};

  /**
   * Host entry point that detects and opens Circle models: either the binary
   * flatbuffer form (file identifier 'CIR0') or its JSON text form.
   */
  circle.ModelFactory = class {

      /**
       * Detects whether the context holds a Circle model.
       *
       * @param {*} context - host context exposing `peek(kind)` and `set(type, value)`.
       * @returns {Promise<*>} the result of `context.set(...)` when recognized, otherwise null.
       */
      async match(context) {
          const binary = await context.peek('flatbuffers.binary');
          if (binary && binary.identifier === 'CIR0') {
              return context.set('circle.flatbuffers', binary);
          }
          const json = await context.peek('json');
          const looksLikeCircleJson = json && json.subgraphs && json.operator_codes;
          return looksLikeCircleJson ? context.set('circle.flatbuffers.json', json) : null;
      }

      /**
       * Loads the Circle schema, decodes the model for the detected `context.type`
       * and wraps it in a circle.Model.
       *
       * @param {*} context - host context (type, value, stream, require, read, metadata).
       * @returns {Promise<*>} the constructed circle.Model.
       * @throws {circle.Error} when decoding fails or the type is not a Circle format.
       */
      async open(context) {
          circle.schema = await context.require('./circle-schema');
          circle.schema = circle.schema.circle;
          // Decoder messages commonly end with '.', which is dropped before wrapping.
          const reasonOf = (error) => {
              const message = error && error.message ? error.message : error.toString();
              return message.replace(/\.$/, '');
          };
          let model = null;
          const attachments = new Map();
          if (context.type === 'circle.flatbuffers.json') {
              try {
                  const textReader = await context.read('flatbuffers.text');
                  model = circle.schema.Model.createText(textReader);
              } catch (error) {
                  throw new circle.Error(`File text format is not circle.Model (${reasonOf(error)}).`);
              }
          } else if (context.type === 'circle.flatbuffers') {
              try {
                  model = circle.schema.Model.create(context.value);
              } catch (error) {
                  throw new circle.Error(`File format is not circle.Model (${reasonOf(error)}).`);
              }
              // Best effort: try to open the same stream as a zip archive and collect its
              // entries; any failure is deliberately ignored.
              // NOTE(review): `attachments` is not used after this point in the visible code.
              try {
                  const archive = zip.Archive.open(context.stream);
                  if (archive) {
                      for (const [entryName, entryValue] of archive.entries) {
                          attachments.set(entryName, entryValue);
                      }
                  }
              } catch {
                  // continue regardless of error
              }
          } else {
              throw new circle.Error(`Unsupported Circle format '${context.type}'.`);
          }
          const stream = context.stream;
          const metadata = await context.metadata('circle-metadata.json');
          return new circle.Model(metadata, model, stream);
      }
  };
  circle.Model = class {

      /**
       * Model view: format string, model-level metadata entries and one
       * circle.Graph per subgraph.
       *
       * @param {*} metadata - operator metadata provider; forwarded to each circle.Graph.
       * @param {*} model - decoded circle.schema.Model.
       * @param {*} stream - source stream; buffers stored by offset/size are read from it.
       */
      constructor(metadata, model, stream) {
          this.format = `Circle v${model.version}`;
          this.description = model.description || '';
          this.modules = [];
          this.metadata = [];

          // Display names for builtin opcodes: split the enum key on '_' and title-case
          // each word except the listed acronyms (e.g. 'CONV_2D' -> 'Conv2D').
          const acronyms = new Set(['2D', 'LSH', 'SVDF', 'RNN', 'L2', 'LSTM']);
          const titleCase = (word) => {
              if (word.length === 0 || acronyms.has(word)) {
                  return word;
              }
              return word.charAt(0) + word.slice(1).toLowerCase();
          };
          const builtinNames = new Map();
          for (const [key, code] of Object.entries(circle.schema.BuiltinOperator)) {
              const spelled = key === 'BATCH_MATMUL' ? 'BATCH_MAT_MUL' : key;
              builtinNames.set(code, spelled.split('_').map((word) => titleCase(word)).join(''));
          }

          // One descriptor per operator code; the effective code is the larger of the
          // deprecated and current builtin code fields.
          const operators = model.operator_codes.map((entry) => {
              const builtin = Math.max(entry.deprecated_builtin_code, entry.builtin_code || 0);
              const custom = builtin === circle.schema.BuiltinOperator.CUSTOM;
              let name;
              if (custom) {
                  name = entry.custom_code ? entry.custom_code : 'Custom';
              } else {
                  name = builtinNames.has(builtin) ? builtinNames.get(builtin) : builtin.toString();
              }
              return { name, version: entry.version, custom };
          });

          // Returns inline buffer bytes, or reads offset/size-backed bytes from the stream.
          const readBufferData = (buffer) => {
              if (buffer && buffer.data && buffer.data.length > 0) {
                  return buffer.data;
              }
              if (buffer && buffer.offset !== 0n && buffer.size !== 0n) {
                  const offset = buffer.offset.toNumber();
                  const size = buffer.size.toNumber();
                  stream.seek(offset);
                  return stream.read(size);
              }
              return null;
          };

          let modelMetadata = null;
          for (const entry of model.metadata) {
              const buffer = model.buffers[entry.buffer];
              // Reads must not disturb the caller's stream position.
              const savedPosition = stream.position;
              const data = readBufferData(buffer);
              stream.seek(savedPosition);
              if (!data) {
                  continue;
              }
              if (entry.name === 'min_runtime_version') {
                  this.runtime = new TextDecoder('utf-8').decode(data);
              } else if (entry.name === 'TFLITE_METADATA') {
                  const reader = flatbuffers.BinaryReader.open(data);
                  if (!reader || !circle.schema.ModelMetadata.identifier(reader)) {
                      throw new circle.Error('Invalid TensorFlow Lite metadata.');
                  }
                  modelMetadata = circle.schema.ModelMetadata.create(reader);
                  if (modelMetadata.name) {
                      this.name = modelMetadata.name;
                  }
                  if (modelMetadata.version) {
                      this.version = modelMetadata.version;
                  }
                  if (modelMetadata.description) {
                      this.description = this.description ? `${this.description} ${modelMetadata.description}` : modelMetadata.description;
                  }
                  for (const key of ['author', 'license']) {
                      if (modelMetadata[key]) {
                          this.metadata.push(new circle.Argument(key, modelMetadata[key]));
                      }
                  }
              } else {
                  // Short printable-ASCII payloads are shown as text, anything else as '?'.
                  const printable = data.length < 256 && data.every((c) => c >= 32 && c < 128);
                  const text = printable ? String.fromCharCode(...data) : '?';
                  this.metadata.push(new circle.Argument(entry.name, text));
              }
          }

          const subgraphs = model.subgraphs;
          const subgraphsMetadata = modelMetadata ? modelMetadata.subgraph_metadata : null;
          for (let index = 0; index < subgraphs.length; index++) {
              const graphName = subgraphs.length > 1 ? index.toString() : '';
              const graphMetadata = subgraphsMetadata && index < subgraphsMetadata.length ? subgraphsMetadata[index] : null;
              const signatures = model.signature_defs.filter((signature) => signature.subgraph_index === index);
              this.modules.push(new circle.Graph(metadata, subgraphs[index], signatures, graphMetadata, graphName, operators, model, stream));
          }
      }
  };
  circle.Graph = class {

      /**
       * Graph view over one subgraph of a Circle model.
       *
       * @param {*} metadata - operator metadata provider forwarded to circle.Node.
       * @param {*} subgraph - decoded subgraph (tensors, inputs, outputs, operators). Its
       *     `operators` array is extended in place for weights-only files.
       * @param {Array} signatures - signature definitions whose subgraph_index is this subgraph.
       * @param {*} subgraphMetadata - optional per-subgraph tensor metadata, or null.
       * @param {string} name - fallback name used when the subgraph has none.
       * @param {Array} operators - opcode descriptors shared by all subgraphs. A synthetic
       *     'Weights' descriptor is appended for weights-only files.
       * @param {*} model - decoded model, used to look up tensor buffers.
       * @param {*} stream - source stream for buffers stored by offset/size.
       * @throws {circle.Error} for an unsupported image color space in tensor metadata.
       */
      constructor(metadata, subgraph, signatures, subgraphMetadata, name, operators, model, stream) {
          this.name = subgraph.name || name;
          if (subgraph.operators.length === 0 && subgraph.tensors.length > 0 && operators.length === 0) {
              // No operators and no opcodes: group tensors by their dotted name prefix
              // ('a.b.weight' -> 'a.b') into synthetic 'Weights' nodes so they stay visible.
              operators.push({ name: 'Weights', custom: true });
              const layers = new Map();
              for (let i = 0; i < subgraph.tensors.length; i++) {
                  const tensor = subgraph.tensors[i];
                  const parts = tensor.name.split('.');
                  parts.pop();
                  const key = parts.join('.');
                  if (!layers.has(key)) {
                      const operator = { opcode_index: 0, inputs: [], outputs: [] };
                      layers.set(key, operator);
                      subgraph.operators.push(operator);
                  }
                  layers.get(key).inputs.push(i);
              }
          }

          // Maps a metadata content_properties object to a denotation label ('' if unknown).
          const denotationOf = (contentProperties) => {
              if (contentProperties instanceof circle.schema.FeatureProperties) {
                  return 'Feature';
              }
              if (contentProperties instanceof circle.schema.ImageProperties) {
                  switch (contentProperties.color_space) {
                      case 0: return 'Image(Unknown)';
                      case 1: return 'Image(RGB)';
                      case 2: return 'Image(Grayscale)';
                      // `new` is required: circle.Error is a class and cannot be called directly.
                      default: throw new circle.Error(`Unsupported image color space '${contentProperties.color_space}'.`);
                  }
              }
              if (contentProperties instanceof circle.schema.BoundingBoxProperties) {
                  return 'BoundingBox';
              }
              if (contentProperties instanceof circle.schema.AudioProperties) {
                  return `Audio(${contentProperties.sample_rate},${contentProperties.channels})`;
              }
              return '';
          };

          const tensors = new Map();
          // Lazily creates and caches the circle.Value for a tensor index.
          // Index -1 means "no tensor" and yields null.
          tensors.map = (index, metadata) => {
              if (index === -1) {
                  return null;
              }
              if (!tensors.has(index)) {
                  let tensor = { name: '' };
                  let initializer = null;
                  let description = '';
                  let denotation = '';
                  if (index < subgraph.tensors.length) {
                      tensor = subgraph.tensors[index];
                      const buffer = model.buffers[tensor.buffer];
                      const is_variable = tensor.is_variable;
                      // Variables and tensors backed by inline or offset/size data get an initializer.
                      const variable = is_variable || (buffer && buffer.data && buffer.data.length > 0) || (buffer && buffer.offset !== 0n && buffer.size !== 0n);
                      initializer = variable ? new circle.Tensor(index, tensor, buffer, stream, is_variable) : null;
                  }
                  if (metadata) {
                      description = metadata.description;
                      const content = metadata.content;
                      if (content) {
                          denotation = denotationOf(content.content_properties);
                      }
                  }
                  tensors.set(index, new circle.Value(index, tensor, initializer, description, denotation));
              }
              return tensors.get(index);
          };

          // Builds graph input/output arguments. A missing metadata list is treated as empty.
          const bind = (indices, tensorMetadata) => Array.from(indices).map((tensorIndex, position) => {
              const metadata = tensorMetadata && position < tensorMetadata.length ? tensorMetadata[position] : null;
              const value = tensors.map(tensorIndex, metadata);
              const values = value ? [value] : [];
              // Value names carry '\n<index>'; only the leading tensor name is shown here.
              const argumentName = value ? value.name.split('\n')[0] : '?';
              return new circle.Argument(argumentName, values);
          });
          this.inputs = bind(subgraph.inputs, subgraphMetadata ? subgraphMetadata.input_tensor_metadata : null);
          this.outputs = bind(subgraph.outputs, subgraphMetadata ? subgraphMetadata.output_tensor_metadata : null);
          this.signatures = signatures.map((signature) => new circle.Signature(signature, tensors));
          this.nodes = Array.from(subgraph.operators).map((operator, index) => {
              const opcode_index = operator.opcode_index;
              const opcode = opcode_index < operators.length ? operators[opcode_index] : { name: `(${opcode_index})` };
              return new circle.Node(metadata, operator, opcode, index.toString(), tensors);
          });
      }
  };
  circle.Signature = class {

      /**
       * Named signature (input/output bindings) of a subgraph.
       *
       * @param {*} signature - decoded signature definition (signature_key, inputs, outputs).
       * @param {Map} tensors - the graph's tensor cache exposing `tensors.map(index)`.
       */
      constructor(signature, tensors) {
          // Each binding names a tensor index; unresolved indices produce an empty value list.
          const toArgument = (binding) => {
              const value = tensors.map(binding.tensor_index);
              return new circle.Argument(binding.name, value ? [value] : []);
          };
          this.name = signature.signature_key;
          this.inputs = signature.inputs.map((binding) => toArgument(binding));
          this.outputs = signature.outputs.map((binding) => toArgument(binding));
      }
  };
  circle.Node = class {

      /**
       * Node view for one operator, or a bare node (no inputs/outputs/attributes)
       * when `node` is null, as used for fused activations.
       *
       * @param {*} metadata - operator metadata provider (`type(name)`, `attribute(type, name)`).
       * @param {*} node - decoded operator, or null.
       * @param {{name: string, custom?: boolean}} type - opcode descriptor.
       * @param {?string} identifier - node identifier (the operator index as a string).
       * @param {Map} tensors - tensor cache exposing `tensors.map(index)`.
       */
      constructor(metadata, node, type, identifier, tensors) {
          this.name = '';
          this.identifier = identifier;
          this.type = type.custom ? { name: type.name } : metadata.type(type.name);
          this.inputs = [];
          this.outputs = [];
          this.attributes = [];
          if (node) {
              const attributes = [];
              const inputs = Array.from(node.inputs || new Int32Array(0));
              for (let i = 0; i < inputs.length;) {
                  let count = 1;
                  let name = null;
                  let visible = true;
                  const values = [];
                  if (this.type && this.type.inputs && i < this.type.inputs.length) {
                      const input = this.type.inputs[i];
                      name = input.name;
                      if (input.list) {
                          // A list input consumes every remaining tensor index.
                          count = inputs.length - i;
                      }
                      if (input.visible === false) {
                          visible = false;
                      }
                  }
                  for (const index of inputs.slice(i, i + count)) {
                      const value = tensors.map(index);
                      if (value) {
                          values.push(value);
                      }
                  }
                  name = name ? name : (i + 1).toString();
                  i += count;
                  this.inputs.push(new circle.Argument(name, values, null, visible));
              }
              const outputs = Array.from(node.outputs || new Int32Array(0));
              for (let i = 0; i < outputs.length; i++) {
                  const value = tensors.map(outputs[i]);
                  const values = value ? [value] : [];
                  let name = (i + 1).toString();
                  if (this.type && this.type.outputs && i < this.type.outputs.length) {
                      const output = this.type.outputs[i];
                      if (output && output.name) {
                          name = output.name;
                      }
                  }
                  this.outputs.push(new circle.Argument(name, values));
              }
              if (type.custom && node.custom_options && node.custom_options.length > 0) {
                  // Try to decode FlexBuffers custom options; otherwise keep the raw bytes.
                  let decoded = false;
                  if (node.custom_options_format === circle.schema.CustomOptionsFormat.FLEXBUFFERS) {
                      try {
                          const reader = flexbuffers.BinaryReader.open(node.custom_options);
                          if (reader) {
                              const custom_options = reader.read();
                              if (Array.isArray(custom_options)) {
                                  attributes.push([null, 'custom_options', custom_options]);
                                  decoded = true;
                              } else if (custom_options) {
                                  for (const [key, value] of Object.entries(custom_options)) {
                                      const schema = metadata.attribute(type.name, key);
                                      attributes.push([schema, key, value]);
                                  }
                                  decoded = true;
                              }
                          }
                      } catch {
                          // continue regardless of error
                      }
                  }
                  if (!decoded) {
                      const schema = metadata.attribute(type.name, 'custom');
                      attributes.push([schema, 'custom', Array.from(node.custom_options)]);
                  }
              }
              const options = node.builtin_options;
              if (options) {
                  for (const [name, value] of Object.entries(options)) {
                      if (name === 'fused_activation_function' && value) {
                          // A non-zero fused activation is also exposed as a chained node.
                          const ActivationFunctionType = circle.schema.ActivationFunctionType;
                          let activation = '';
                          switch (value) {
                              case ActivationFunctionType.RELU: activation = 'Relu'; break;
                              case ActivationFunctionType.RELU_N1_TO_1: activation = 'ReluN1To1'; break;
                              case ActivationFunctionType.RELU6: activation = 'Relu6'; break;
                              case ActivationFunctionType.TANH: activation = 'Tanh'; break;
                              case ActivationFunctionType.SIGN_BIT: activation = 'SignBit'; break;
                              case 6: activation = 'Sigmoid'; break;
                              default: activation = value.toString(); break;
                          }
                          this.chain = [new circle.Node(metadata, null, { name: activation }, null, [])];
                      }
                      const schema = metadata.attribute(type.name, name);
                      attributes.push([schema, name, value]);
                  }
              }
              this.attributes = attributes.map(([schema, name, value]) => {
                  const attributeType = schema && schema.type ? schema.type : null;
                  value = ArrayBuffer.isView(value) ? Array.from(value) : value;
                  let visible = name !== 'fused_activation_function';
                  if (attributeType) {
                      const enumType = circle.schema[attributeType];
                      if (enumType) {
                          // Schema enums are used as NAME -> number elsewhere in this file
                          // (e.g. ActivationFunctionType.RELU), so search by value instead
                          // of indexing with the number; this also works if reverse
                          // number -> NAME entries exist.
                          const enumName = Object.keys(enumType).find((k) => enumType[k] === value);
                          value = enumName || value;
                      }
                  }
                  if (schema) {
                      if (schema.visible === false) {
                          visible = false;
                      } else if (schema.default !== undefined) {
                          if (typeof value === 'function') {
                              value = value();
                          }
                          // Attributes equal to their metadata default are hidden.
                          if (value === schema.default) {
                              visible = false;
                          }
                      }
                  }
                  return new circle.Argument(name, value, attributeType, visible);
              });
          }
      }
  };
  circle.Argument = class {

      /**
       * Named argument: a graph/node input or output, or an attribute.
       *
       * @param {string} name - argument name.
       * @param {*} value - value list or attribute value.
       * @param {?string} [type=null] - attribute type name, if any.
       * @param {boolean} [visible=true] - whether the argument is shown by default.
       */
      constructor(name, value, type = null, visible = true) {
          Object.assign(this, { name, value, type, visible });
      }
  };
  circle.Value = class {

      /**
       * Value flowing along graph edges for a single tensor.
       *
       * @param {number} index - tensor index within the subgraph.
       * @param {*} tensor - decoded tensor, or a `{ name: '' }` placeholder.
       * @param {?circle.Tensor} initializer - attached data/variable, or null.
       * @param {string} description - description from tensor metadata, if any.
       * @param {string} denotation - semantic label forwarded to circle.TensorType.
       */
      constructor(index, tensor, initializer, description, denotation) {
          // The name is "<tensor name>\n<index>".
          this.name = `${tensor.name || ''}\n${index}`;
          this.identifier = index.toString();
          const hasTypeInfo = tensor.type !== undefined && tensor.shape !== undefined;
          this.type = hasTypeInfo ? new circle.TensorType(tensor, denotation) : null;
          this.initializer = initializer;
          this.description = description;
          const q = tensor.quantization;
          if (!q) {
              return;
          }
          // Linear quantization is reported when any parameter list is non-empty.
          const nonEmpty = (values) => values.length > 0;
          if (nonEmpty(q.scale) || nonEmpty(q.zero_point) || nonEmpty(q.min) || nonEmpty(q.max)) {
              this.quantization = {
                  type: 'linear',
                  dimension: q.quantized_dimension,
                  scale: q.scale,
                  offset: q.zero_point,
                  min: q.min,
                  max: q.max
              };
          }
      }
  };
  circle.Tensor = class {

      /**
       * Initializer (constant or variable) data attached to a tensor.
       *
       * @param {number} index - tensor index within the subgraph.
       * @param {*} tensor - decoded tensor (name, type, shape, ...).
       * @param {*} buffer - decoded buffer: inline `data`, or `offset`/`size` into the stream.
       * @param {*} stream - source stream for offset/size-backed buffers.
       * @param {boolean} is_variable - whether the tensor is a variable.
       */
      constructor(index, tensor, buffer, stream, is_variable) {
          this.identifier = index.toString();
          this.name = tensor.name;
          this.type = new circle.TensorType(tensor);
          this.category = is_variable ? 'Variable' : '';
          this.encoding = this.type.dataType === 'string' ? '|' : '<';
          if (buffer && buffer.data && buffer.data.length > 0) {
              this._data = buffer.data.slice(0);
          } else if (buffer && buffer.offset !== 0n && buffer.size !== 0n) {
              const offset = buffer.offset.toNumber();
              const size = buffer.size.toNumber();
              // Restore the position afterwards, as circle.Model does for metadata reads.
              const position = stream.position;
              stream.seek(offset);
              this._data = stream.stream(size);
              stream.seek(position);
          } else {
              this._data = null;
          }
      }

      /**
       * Tensor contents: the raw bytes for non-string types (or null when absent),
       * or an array of decoded strings for 'string' tensors (null when absent).
       */
      get values() {
          // Normalise the backing store: in-memory bytes, or a sub-stream read via peek().
          let data = null;
          if (this._data instanceof Uint8Array) {
              data = this._data;
          } else if (this._data && this._data.peek) {
              data = this._data.peek();
          }
          if (this.type.dataType !== 'string') {
              return data;
          }
          if (!data) {
              return null;
          }
          // Layout read below: int32 count, then `count` int32 start offsets, then the
          // UTF-8 bytes. Each string ends where the next begins; the last ends at the buffer end.
          const view = new DataView(data.buffer, data.byteOffset, data.byteLength);
          const count = view.getInt32(0, true);
          const offsets = [];
          for (let i = 0; i < count; i++) {
              offsets.push(view.getInt32(4 + 4 * i, true));
          }
          offsets.push(data.length);
          const decoder = new TextDecoder('utf-8');
          const strings = [];
          for (let i = 0; i < count; i++) {
              strings.push(decoder.decode(data.subarray(offsets[i], offsets[i + 1])));
          }
          return strings;
      }
  };
  circle.TensorType = class {

      /**
       * Element data type and shape of a tensor.
       *
       * @param {*} tensor - decoded tensor with `type`, `shape` and optional `shape_signature`.
       * @param {string} [denotation] - optional semantic label (e.g. 'Image(RGB)').
       */
      constructor(tensor, denotation) {
          // shape_signature is preferred whenever it is present and non-empty.
          const shape = tensor.shape_signature && tensor.shape_signature.length > 0 ? tensor.shape_signature : tensor.shape;
          const tensorTypes = circle.schema.TensorType;
          if (tensor.type === tensorTypes.BOOL) {
              this.dataType = 'boolean';
          } else {
              // The schema enum is accessed as NAME -> number (TensorType.BOOL above), so find
              // the name by value instead of indexing the enum with the number. This also
              // works if reverse number -> NAME entries are present.
              const name = Object.keys(tensorTypes).find((key) => tensorTypes[key] === tensor.type);
              this.dataType = name ? name.toLowerCase() : '?';
          }
          this.shape = new circle.TensorShape(Array.from(shape || []));
          this.denotation = denotation;
      }

      /** @returns {string} data type followed by the bracketed shape, e.g. 'float32[1,3]'. */
      toString() {
          return this.dataType + this.shape.toString();
      }
  };
  circle.TensorShape = class {

      /**
       * @param {Array} dimensions - dimension sizes.
       */
      constructor(dimensions) {
          this.dimensions = dimensions;
      }

      /**
       * @returns {string} '[d0,d1,...]', or '' for a missing or empty shape.
       */
      toString() {
          const dims = this.dimensions;
          if (dims && dims.length !== 0) {
              const parts = dims.map((size) => size.toString());
              return `[${parts.join(',')}]`;
          }
          return '';
      }
  };
  /**
   * Error raised when a Circle model cannot be loaded.
   */
  circle.Error = class extends Error {

      /**
       * @param {string} message - description of the failure.
       */
      constructor(message) {
          super(message);
          // Replace the inherited 'Error' name with a loader-specific one.
          this.name = 'Error loading Circle model.';
      }
  };

  const ModelFactory = circle.ModelFactory;
  export { ModelFactory };