acuity.js 29 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705
  // Namespace object; the loader's classes below are attached to it as properties.
  const acuity = {};
  // Factory that recognises Acuity JSON content and opens it as an acuity.Model.
  acuity.ModelFactory = class {

      /**
       * Checks whether the context holds Acuity JSON.
       * The content qualifies when it has truthy `MetaData` and `Layers` entries
       * and fewer than 256 top-level keys.
       * @param context - object providing `peek(kind)` and `set(type, value)`.
       * @returns the result of `context.set('acuity', content)` on a match, otherwise null.
       */
      async match(context) {
          const content = await context.peek('json');
          if (!content || !content.MetaData || !content.Layers) {
              return null;
          }
          if (Object.keys(content).length >= 256) {
              return null;
          }
          return context.set('acuity', content);
      }

      /**
       * Loads the operator metadata and wraps the matched content in a model.
       * @param context - object providing `metadata(name)` and the matched `value`.
       * @returns a new acuity.Model built from the metadata and `context.value`.
       */
      async open(context) {
          const metadata = await context.metadata('acuity-metadata.json');
          const content = context.value;
          return new acuity.Model(metadata, content);
      }
  };
  // Top-level model description: identity fields taken from MetaData plus a single graph.
  acuity.Model = class {

      /**
       * @param metadata - operator metadata, forwarded to acuity.Graph.
       * @param model - parsed model object; `model.MetaData` supplies Name, AcuityVersion and Platform.
       * @param data - forwarded unchanged to acuity.Graph.
       * @param quantization - forwarded unchanged to acuity.Graph.
       */
      constructor(metadata, model, data, quantization) {
          const version = model.MetaData.AcuityVersion;
          this.name = model.MetaData.Name;
          this.format = `Acuity v${version}`;
          this.runtime = model.MetaData.Platform;
          const graph = new acuity.Graph(metadata, model, data, quantization);
          this.modules = [graph];
      }
  };
  // Graph view of a model: named inputs/outputs, operator nodes and a FLOPs metric.
  acuity.Graph = class {

      /**
       * Builds the graph from `model.Layers`.
       *
       * Side effect: every layer's `inputs` and `outputs` arrays are replaced by shared
       * `{ name, shape }` records so that acuity.Inference.infer can fill in shapes in place.
       *
       * @param metadata - operator metadata, forwarded to acuity.Node.
       * @param model - parsed model object with a `Layers` map of name -> layer.
       */
      constructor(metadata, model) {
          this.nodes = [];
          this.inputs = [];
          this.outputs = [];
          this.metrics = [];
          const values = new Map();
          // Returns the shared record for a value name, creating it on first use.
          const value = (name) => {
              if (!values.has(name)) {
                  values.set(name, { name, shape: null });
              }
              return values.get(name);
          };
          const has = (obj, key) => Object.prototype.hasOwnProperty.call(obj, key);
          // Product of the factors, or 0 when any factor is missing or not a finite number,
          // so one incomplete layer cannot turn the total into NaN.
          const product = (factors) => {
              return factors.every((factor) => Number.isFinite(factor)) ? factors.reduce((a, b) => a * b, 1) : 0;
          };
          // FLOPs contributed by one layer; layers without the expected fields contribute 0.
          // Add other layer types (e.g., pooling, batch norm, etc.) as needed.
          const flops = (layer) => {
              if (layer.type === 'Conv2D') {
                  const { kernelShape, inputShape, outputShape } = layer;
                  if (Array.isArray(kernelShape) && Array.isArray(inputShape) && Array.isArray(outputShape)) {
                      const [kH, kW] = kernelShape;
                      const [inC] = inputShape;
                      const [outC, oH, oW] = outputShape;
                      return product([kH, kW, inC, oH, oW, outC]);
                  }
                  return 0;
              }
              if (layer.type === 'Dense') {
                  return product([layer.inputSize, layer.outputSize]);
              }
              return 0;
          };
          let totalFlops = 0;
          for (const [name, layer] of Object.entries(model.Layers)) {
              const op = layer.op.toLowerCase();
              layer.inputs = layer.inputs.map((input) => value(input));
              layer.outputs = layer.outputs.map((port) => {
                  const output = value(`@${name}:${port}`);
                  let shape = null;
                  if (op === 'input' || op === 'variable') {
                      // A layer may omit `parameters`; treat that as "no shape information".
                      const parameters = layer.parameters || {};
                      if (has(parameters, 'shape') && parameters.shape.length > 0) {
                          // Copy so the batch fix-up below never edits the caller's parameters.
                          shape = Array.isArray(parameters.shape) ? parameters.shape.slice() : parameters.shape;
                      } else if (has(parameters, 'size') && has(parameters, 'channels')) {
                          const sizes = parameters.size.split(' ');
                          shape = [0, parseInt(sizes[0], 10), parseInt(sizes[1], 10), parameters.channels];
                      } else if (has(parameters, 'is_scalar')) {
                          shape = [1];
                      }
                      // A leading 0 in a rank-4 shape is replaced by 1.
                      if (shape && shape.length === 4 && shape[0] === 0) {
                          shape[0] = 1;
                      }
                  }
                  output.shape = shape;
                  return output;
              });
              totalFlops += flops(layer);
          }
          this.metrics.push(new acuity.Argument('flops', totalFlops));
          acuity.Inference.infer(model.Layers);
          // Swap each plain record for a typed acuity.Value now that shapes are final.
          for (const [name, record] of values) {
              const type = new acuity.TensorType(null, new acuity.TensorShape(record.shape));
              values.set(name, new acuity.Value(name, type, null, null));
          }
          for (const [name, layer] of Object.entries(model.Layers)) {
              switch (layer.op.toLowerCase()) {
                  case 'input': {
                      const argument = new acuity.Argument(name, [values.get(layer.outputs[0].name)]);
                      this.inputs.push(argument);
                      break;
                  }
                  case 'output': {
                      const argument = new acuity.Argument(name, [values.get(layer.inputs[0].name)]);
                      this.outputs.push(argument);
                      break;
                  }
                  default: {
                      this.nodes.push(new acuity.Node(metadata, name, layer, values));
                      break;
                  }
              }
          }
      }
  };
  // One operator node: its type descriptor, attributes, and named input/output arguments.
  acuity.Node = class {

      /**
       * @param metadata - provides `type(op)` and `attribute(op, name)` lookups.
       * @param name - layer name, used as the node name.
       * @param layer - layer description; `inputs`/`outputs` hold records with a `name`.
       * @param values - Map from value name to the value object to attach.
       */
      constructor(metadata, name, layer, values) {
          const op = layer.op;
          this.name = name;
          this.type = metadata.type(op) || { name: op };
          this.inputs = [];
          this.outputs = [];
          this.attributes = [];
          const schema = this.type;
          // Name from the type descriptor when one exists at this position, else `<prefix><index>`.
          const label = (entries, index, prefix) => {
              return entries && index < entries.length ? entries[index].name : `${prefix}${index}`;
          };
          if (schema && layer.parameters) {
              Object.entries(layer.parameters).forEach(([key, content]) => {
                  const meta = metadata.attribute(op, key);
                  const type = meta && meta.type ? meta.type : null;
                  // An attribute is hidden only when it equals its declared default.
                  const hidden = Boolean(meta) && meta.default !== undefined && meta.default === content;
                  this.attributes.push(new acuity.Argument(key, content, type, !hidden));
              });
          }
          layer.inputs.forEach((input, index) => {
              const entry = values.get(input.name);
              const title = label(schema && schema.inputs, index, 'input');
              this.inputs.push(new acuity.Argument(title, [entry]));
          });
          if (schema && schema.constants) {
              // Each declared constant becomes an extra input backed by a placeholder tensor.
              for (const constant of schema.constants) {
                  const type = new acuity.TensorType(null, new acuity.TensorShape(null));
                  const placeholder = new acuity.Value('', type, null, new acuity.Tensor(type));
                  this.inputs.push(new acuity.Argument(constant.name, [placeholder]));
              }
          }
          layer.outputs.forEach((output, index) => {
              const entry = values.get(output.name);
              const title = label(schema && schema.outputs, index, 'output');
              this.outputs.push(new acuity.Argument(title, [entry]));
          });
      }
  };
  // Named argument or attribute; `type` and `visible` are only stored when meaningful.
  acuity.Argument = class {

      /**
       * @param name - argument name.
       * @param value - argument payload.
       * @param type - optional type; stored only when truthy.
       * @param visible - stored only when exactly `false`.
       */
      constructor(name, value, type, visible) {
          Object.assign(this, { name, value });
          if (type) {
              Object.assign(this, { type });
          }
          if (visible === false) {
              Object.assign(this, { visible: false });
          }
      }
  };
  // A value flowing through the graph, identified by a string name.
  acuity.Value = class {

      /**
       * @param name - identifier; must be a string.
       * @param type - stored as-is when truthy, otherwise null.
       * @param quantization - stored as-is when truthy, otherwise null.
       * @param initializer - stored as-is when truthy, otherwise null.
       * @throws acuity.Error when `name` is not a string.
       */
      constructor(name, type, quantization, initializer) {
          const valid = typeof name === 'string';
          if (!valid) {
              throw new acuity.Error(`Invalid value identifier '${JSON.stringify(name)}'.`);
          }
          this.name = name;
          this.type = type ? type : null;
          this.quantization = quantization ? quantization : null;
          this.initializer = initializer ? initializer : null;
      }
  };
  // Pairs a data type name with a shape object.
  acuity.TensorType = class {

      /**
       * @param dataType - data type name; a falsy value is stored as '?'.
       * @param shape - shape object; its toString() is used by this class's toString().
       */
      constructor(dataType, shape) {
          this.dataType = dataType ? dataType : '?';
          this.shape = shape;
      }

      /** @returns the data type (or '?') followed by the shape's string form. */
      toString() {
          const prefix = this.dataType || '?';
          const suffix = this.shape.toString();
          return prefix + suffix;
      }
  };
  // Shape wrapper; the single-element shape [0] is treated as "no dimensions".
  acuity.TensorShape = class {

      /**
       * @param dimensions - array of dimensions, or any other value which is stored unchanged.
       */
      constructor(dimensions) {
          const placeholder = Array.isArray(dimensions) && dimensions.length === 1 && dimensions[0] === 0;
          this.dimensions = placeholder ? [] : dimensions;
      }

      /**
       * @returns '' when there are no dimensions to show, otherwise `[d0,d1,...]`
       * with every falsy dimension rendered as '?'.
       */
      toString() {
          const dims = this.dimensions;
          if (!Array.isArray(dims) || dims.length === 0) {
              return '';
          }
          if (dims.length === 1 && dims[0] === 0) {
              return '';
          }
          const parts = dims.map((dim) => dim ? dim.toString() : '?');
          return `[${parts.join(',')}]`;
      }
  };
  // Placeholder tensor carrying only a type; its Category is always 'Constant'.
  acuity.Tensor = class {

      /**
       * @param type - tensor type stored on the instance.
       */
      constructor(type) {
          Object.assign(this, { type, Category: 'Constant' });
      }
  };
  // Best-effort static shape inference for Acuity layers.
  acuity.Inference = class {

      /**
       * Fills in the `shape` of layer output records wherever it can be derived.
       *
       * `layers` is the layer table after acuity.Graph has replaced every entry of
       * `layer.inputs` / `layer.outputs` with a shared `{ name, shape }` record.
       * Starting from each layer whose op is 'output', producers are visited
       * recursively; when all input shapes of a layer are known, the operator's shape
       * rule is applied to the input shapes and `layer.parameters`, and the result is
       * written to the layer's output records. Records that cannot be resolved keep
       * `shape === null`.
       *
       * @param {Object} layers - map of layer name to layer; output records are mutated in place.
       * @returns {void}
       */
      static infer(layers) {
          // Value name -> layer that produces it.
          const outputs = new Map();
          const outputLayers = [];
          for (const layer of Object.values(layers)) {
              if (layer.op.toLowerCase() === 'output') {
                  outputLayers.push(layer);
              }
              for (const output of layer.outputs) {
                  outputs.set(output.name, layer);
              }
          }
          // 'fllor_mod' is a historical misspelling kept for compatibility; 'floor_mod' is the intended name.
          const broadcasts = new Set([
              'add', 'equal', 'fllor_mod', 'floor_mod', 'floor_div', 'greater', 'greater_equal', 'less', 'less_equal',
              'logical_and', 'logical_or', 'minimum', 'multiply', 'not_equal', 'pow', 'real_div',
              'squared_difference', 'subtract', 'divide', 'addn', 'Divide', 'bitwise_and', 'bitwise_or',
              'bitwise_xor', 'average', 'logical_not', 'logical_xor'
          ]);
          const passthroughs = new Set([
              'LocalResponseNormalization', 'a_times_b_plus_c', 'abs', 'batchnorm_single', 'batchnormalize',
              'cast', 'cast', 'clipbyvalue', 'dequantize', 'dtype_converter', 'elu', 'exp', 'floor',
              'groupnormalize', 'hard_sigmoid', 'hard_swish', 'instancenormalize', 'l2normalize', 'l2normalizescale',
              'layernormalize', 'leakyrelu', 'log', 'log_softmax', 'mish', 'neg', 'norm_with_channel_mean',
              'norm_with_min_max', 'norm_with_scale', 'pow', 'prelu', 'quantize', 'relu', 'relu_keras',
              'relun', 'reverse', 'round', 'rsqrt', 'sigmoid', 'sin', 'softmax', 'softrelu', 'sqrt', 'square', 'tanh',
              'swish', 'gelu', 'dropout', 'eltwise', 'cos', 'l1_layernormalize', 'inverse_sigmoid', 'selu', 'mod',
              'mish', 'minimum_with_clip', 'celu', 'cumsum', 'dft', 'dropout2', 'erf', 'noop', 'squashing', 'tan', 'ceil',
              'atan', 'atan2', 'atanh', 'alpha_dropout', 'acosh', 'rmsnormalize', 'sign'
          ]);
          const reduces = new Set([
              'reduceany', 'reducemax', 'reducemean', 'reducemin', 'reduceprod', 'reducesum'
          ]);
          const poolings = new Set([
              'pooling', 'l2pooling'
          ]);
          // Maps a negative axis onto its non-negative equivalent for the given rank.
          const normalizeAxis = (axis, rank) => axis < 0 ? rank + axis : axis;
          // Unique, non-negative axes ordered from highest to lowest so that removing
          // them one by one never shifts an axis that is still to be removed.
          const descendingAxes = (axes, rank) => {
              return [...new Set(axes.map((axis) => normalizeAxis(axis, rank)))].sort((a, b) => b - a);
          };
          // Broadcasts two shapes: the shorter one is left-padded with 1s, then the larger extent wins per axis.
          const broadcast = (a, b) => {
              const longer = a.length >= b.length ? a.slice() : b.slice();
              const shorter = a.length < b.length ? a.slice() : b.slice();
              const padded = new Array(longer.length - shorter.length).fill(1).concat(shorter);
              return longer.map((dim, index) => dim > padded[index] ? dim : padded[index]);
          };
          // Each operator maps (input shapes, parameters) to an array of output shapes.
          const operators = new Map();
          // Works for any number of inputs: unary ops pass through, n-ary ops fold pairwise.
          operators.set('broadcast', (inputs) => {
              if (inputs.length === 0) {
                  return [];
              }
              return [inputs.slice(1).reduce((shape, next) => broadcast(shape, next), inputs[0].slice())];
          });
          operators.set('concat', (inputs, params) => {
              const outputShape = inputs[0].slice();
              const dim = normalizeAxis(params.dim, outputShape.length);
              outputShape[dim] = 0;
              for (const shape of inputs) {
                  outputShape[dim] += shape[dim];
              }
              return [outputShape];
          });
          operators.set('conv1d', (inputs, params) => {
              if (params.padding === 'VALID') {
                  const out_h = ~~((inputs[0][1] + params.stride - params.ksize) / params.stride);
                  return [[inputs[0][0], out_h, params.weights]];
              } else if (params.padding === 'SAME') {
                  const out_h = ~~((inputs[0][1] + params.stride - 1) / params.stride);
                  return [[inputs[0][0], out_h, params.weights]];
              }
              return null;
          });
          operators.set('convolution', (inputs, params) => {
              if (params.padding === 'VALID') {
                  const out_h = Math.floor((inputs[0][1] + params.stride_h + 2 * params.pad_h - params.ksize_h) / params.stride_h);
                  const out_w = Math.floor((inputs[0][2] + params.stride_w + 2 * params.pad_w - params.ksize_w) / params.stride_w);
                  return [[inputs[0][0], out_h, out_w, params.weights]];
              } else if (params.padding === 'SAME') {
                  const out_h = Math.floor((inputs[0][1] + params.stride_h - 1) / params.stride_h);
                  const out_w = Math.floor((inputs[0][2] + params.stride_w - 1) / params.stride_w);
                  return [[inputs[0][0], out_h, out_w, params.weights]];
              }
              return null;
          });
          operators.set('depthwise_conv1d', (inputs, params) => {
              if (params.padding === 'VALID') {
                  const out_h = ~~((inputs[0][1] + params.stride + params.pad[0] + params.pad[1] - params.ksize) / params.stride);
                  return [[inputs[0][0], out_h, inputs[0][2] * params.multiplier]];
              } else if (params.padding === 'SAME') {
                  const out_h = ~~((inputs[0][1] + params.stride - 1) / params.stride);
                  return [[inputs[0][0], out_h, inputs[0][2] * params.multiplier]];
              }
              return null;
          });
          operators.set('depthwise_convolution', (inputs, params) => {
              if (params.padding === 'VALID') {
                  const out_h = ~~((inputs[0][1] + params.stride_h + params.pad[0] + params.pad[1] - params.ksize_h) / params.stride_h);
                  const out_w = ~~((inputs[0][2] + params.stride_w + params.pad[2] + params.pad[3] - params.ksize_w) / params.stride_w);
                  return [[inputs[0][0], out_h, out_w, inputs[0][3] * params.multiplier]];
              } else if (params.padding === 'SAME') {
                  const out_h = ~~((inputs[0][1] + params.stride_h - 1) / params.stride_h);
                  const out_w = ~~((inputs[0][2] + params.stride_w - 1) / params.stride_w);
                  return [[inputs[0][0], out_h, out_w, inputs[0][3] * params.multiplier]];
              }
              return null;
          });
          // A 0 in output_shape means "take this extent from the input".
          operators.set('deconvolution', (inputs, params) => {
              return [params.output_shape.map((item, index) => item === 0 ? inputs[0][index] : item)];
          });
          operators.set('deconvolution1d', (inputs, params) => {
              return [params.output_shape.map((item, index) => item === 0 ? inputs[0][index] : item)];
          });
          operators.set('fullconnect', (inputs, params) => {
              return [inputs[0].slice(0, params.axis).concat([params.weights])];
          });
          operators.set('gather', (inputs, params) => {
              const prefix = inputs[1].slice();
              const suffix = inputs[0].slice(params.axis + 1);
              return [prefix.concat(suffix)];
          });
          operators.set('lstm', (inputs, params) => {
              const [input] = inputs;
              const [a, b] = input;
              // A missing num_proj (null or absent) falls back to the weights size.
              const projected = params.num_proj !== null && params.num_proj !== undefined;
              const output = projected ? params.num_proj : params.weights;
              const batch = params.time_major ? b : a;
              const newShape = params.return_sequences ? [a, b, output] : [batch, output];
              return [newShape, [batch, output], [batch, params.weights]];
          });
          operators.set('matmul', ([a, b], params) => {
              let newShape = a.slice(0, -2);
              if (params.transpose_a) {
                  newShape = newShape.concat(a.slice(-1));
              } else {
                  newShape = newShape.concat(a.slice(-2, -1));
              }
              if (params.transpose_b) {
                  newShape = newShape.concat(b.slice(-2, -1));
              } else {
                  newShape = newShape.concat(b.slice(-1));
              }
              return [newShape];
          });
          operators.set('pad', (inputs, params) => {
              return [inputs[0].map((item, index) => item + params.padding_value[index][0] + params.padding_value[index][1])];
          });
          operators.set('permute', (inputs, params) => {
              return [inputs[0].map((item, index) => inputs[0][params.perm[index]])];
          });
          operators.set('pooling', (inputs, params) => {
              if (params.padding === 'VALID') {
                  const out_h = ~~((inputs[0][1] + params.stride_h - params.ksize_h) / params.stride_h);
                  const out_w = ~~((inputs[0][2] + params.stride_w - params.ksize_w) / params.stride_w);
                  return [[inputs[0][0], out_h, out_w, inputs[0][3]]];
              } else if (params.padding === 'SAME') {
                  const out_h = ~~((inputs[0][1] + params.stride_h - 1) / params.stride_h);
                  const out_w = ~~((inputs[0][2] + params.stride_w - 1) / params.stride_w);
                  return [[inputs[0][0], out_h, out_w, inputs[0][3]]];
              }
              return null;
          });
          operators.set('reduce', (inputs, params) => {
              const newShape = inputs[0].slice();
              // De-duplicated so a repeated axis is not removed twice.
              const axes = descendingAxes(params.axis_list, newShape.length);
              for (const axis of axes) {
                  newShape[axis] = 1;
              }
              if (!params.keep_dims) {
                  for (const axis of axes) {
                      newShape.splice(axis, 1);
                  }
                  // A fully reduced shape is represented as [0].
                  if (!newShape.length) {
                      newShape.splice(0, 0, 0);
                  }
              }
              return [newShape];
          });
          operators.set('repeat', (inputs, params) => {
              const newShape = inputs[0].slice();
              newShape[params.axis] = params.maxlen;
              return [newShape];
          });
          operators.set('reshape', (inputs, params) => {
              const negativeIndexs = [];
              let shape = params.shape;
              if (typeof params.shape === 'string') {
                  shape = params.shape.split(/\s+/).map((item) => {
                      return parseInt(item, 10);
                  });
              }
              // 0 copies the input extent; -1 is a wildcard resolved from the element count below.
              const newShape = shape.map((item, index) => {
                  if (item === 0) {
                      return inputs[0][index];
                  }
                  if (item === -1) {
                      negativeIndexs.push(index);
                      return 1;
                  }
                  return item;
              });
              if (negativeIndexs.length > 0) {
                  newShape[negativeIndexs[0]] = inputs[0].reduce((a, c) => a * c) / newShape.reduce((a, c) => a * c);
              }
              return [newShape];
          });
          operators.set('sequence_mask', (inputs, params) => {
              return [inputs[0].slice().concat([params.maxlen])];
          });
          operators.set('slice', (inputs, params) => {
              return [params.size.map((item, index) => item === -1 ? inputs[0][index] : item)];
          });
          operators.set('squeeze', (inputs, params) => {
              const newShape = inputs[0].slice();
              // Negative axes are normalized first; otherwise removing one shifts the meaning of the next.
              for (const axis of descendingAxes(params.axis_list, newShape.length)) {
                  newShape.splice(axis, 1);
              }
              return [newShape];
          });
          operators.set('space2depth', (inputs, params) => {
              const h = inputs[0][1] / params.block_size[0];
              const w = inputs[0][2] / params.block_size[1];
              // Channels grow by the full block area (height factor times width factor).
              const c = inputs[0][3] * params.block_size[0] * params.block_size[1];
              return [[inputs[0][0], h, w, c]];
          });
          operators.set('depth2space', (inputs, params) => {
              const h = inputs[0][1] * params.block_size;
              const w = inputs[0][2] * params.block_size;
              const c = inputs[0][3] / (params.block_size * params.block_size);
              return [[inputs[0][0], h, w, c]];
          });
          operators.set('upsampling', (inputs, params) => {
              const h = inputs[0][1] * params.factor;
              const w = inputs[0][2] * params.factor;
              return [[inputs[0][0], h, w, inputs[0][3]]];
          });
          operators.set('crop_image', (inputs, params) => {
              return [[inputs[0][0], params.crop_size[0], params.crop_size[1], inputs[0][3]]];
          });
          operators.set('split', (inputs, params) => {
              const sizes = [];
              // Split points bracketed by 0 and the full extent; consecutive differences are the slice sizes.
              const slices = params.slices.slice();
              slices.splice(0, 0, 0);
              slices.push(inputs[0][params.dim]);
              slices.reduce((a, b) => {
                  sizes.push(b - a);
                  return b;
              });
              return sizes.map((item) => {
                  const shape = inputs[0].slice();
                  shape[params.dim] = item;
                  return shape;
              });
          });
          operators.set('stack', (inputs, params) => {
              const newShape = inputs[0].slice();
              if (newShape.length === 1 && newShape[0] === 0) {
                  newShape[0] = 1;
              } else {
                  newShape.splice(params.axis, 0, inputs.length);
              }
              return [newShape];
          });
          operators.set('stridedslice', (inputs, params) => {
              const input_shape = inputs[0].slice();
              const begin = params.slice_begin.slice();
              const end = params.slice_end.slice();
              // A set mask bit means "ignore the given bound"; -1 marks it for the defaults below.
              if (params.slice_begin_mask > 0) {
                  for (let i = 0; i < begin.length; i++) {
                      if ((params.slice_begin_mask >>> i) & 0x1) {
                          begin[i] = -1;
                      }
                  }
              }
              if (params.slice_end_mask > 0) {
                  for (let i = 0; i < end.length; i++) {
                      if ((params.slice_end_mask >>> i) & 0x1) {
                          end[i] = -1;
                      }
                  }
              }
              for (let i = 0; i < begin.length; i++) {
                  if (begin[i] === -1) {
                      begin[i] = 0;
                  }
              }
              if (inputs[0].length === end.length) {
                  for (let i = 0; i < end.length; i++) {
                      if (end[i] === -1 || end[i] > input_shape[i]) {
                          end[i] = input_shape[i];
                      }
                  }
              } else if (inputs[0].length < end.length) {
                  if (params.slice_new_axis_mask) {
                      const len = (params.slice_new_axis_mask >>> 0).toString(2).length;
                      for (let i = 0; i < len; i++) {
                          if ((params.slice_new_axis_mask >>> i) & 0x1) {
                              input_shape.splice(i, 0, 1);
                          }
                      }
                      for (let i = 0; i < end.length; i++) {
                          if (end[i] === -1) {
                              end[i] = input_shape[i];
                          }
                      }
                  }
              }
              const newShape = [];
              for (let i = 0; i < begin.length; i++) {
                  // Number of elements visited by the stride; round up so a partial step still counts.
                  newShape.push(Math.ceil((end[i] - begin[i]) / params.slice_strides[i]));
              }
              if (params.slice_shrink_axis_mask) {
                  const len = (params.slice_shrink_axis_mask >>> 0).toString(2).length;
                  // Remove from the highest axis down so earlier removals do not shift later indices.
                  for (let i = len - 1; i >= 0; i--) {
                      if ((params.slice_shrink_axis_mask >>> i) & 0x1) {
                          newShape.splice(i, 1);
                      }
                  }
              }
              if (params.slice_new_axis_mask) {
                  const len = (params.slice_new_axis_mask >>> 0).toString(2).length;
                  for (let i = 0; i < len; i++) {
                      if ((params.slice_new_axis_mask >>> i) & 0x1) {
                          if (inputs[0].length === begin.length) {
                              newShape.splice(i, 0, 1);
                          } else if (inputs[0].length < begin.length) {
                              newShape[i] = 1;
                          }
                      }
                  }
              }
              return [newShape];
          });
          operators.set('image_resize', (inputs, params) => {
              const newShape = inputs[0].slice();
              /* eslint-disable prefer-destructuring */
              newShape[1] = params.new_size[0];
              newShape[2] = params.new_size[1];
              /* eslint-enable prefer-destructuring */
              return [newShape];
          });
          operators.set('argmax', (inputs, params) => {
              const newShape = inputs[0].slice();
              const axis = normalizeAxis(params.axis, newShape.length);
              if (params.keepdims) {
                  newShape[axis] = 1;
              } else {
                  newShape.splice(axis, 1);
                  // A fully reduced shape is represented as [0].
                  if (!newShape.length) {
                      newShape.splice(0, 0, 0);
                  }
              }
              return [newShape];
          });
          operators.set('argmin', operators.get('argmax'));
          operators.set('shapelayer', (inputs) => {
              return [[inputs[0].length]];
          });
          operators.set('capsule_norm', (inputs) => {
              return [[inputs[0][0], inputs[0][inputs[0].length - 1]]];
          });
          operators.set('size', () => {
              return [[1]];
          });
          // Recognises a two-operand einsum equation that is a (possibly transposed) matmul.
          const identifyMatmul = (inputs, equation) => {
              if (inputs.length !== 2) {
                  return { found: false };
              }
              const parts = equation.replace(/\s+/g, '').split(/,|->/);
              if (parts.length !== 3) {
                  return { found: false };
              }
              const [first, second, output] = parts.map((p) => p.split(''));
              if (!(first.length === output.length || second.length === output.length)) {
                  return { found: false };
              }
              let a = first.slice(-2);
              const b = second.slice(-2);
              const c = output.slice(-2);
              let transpose_a = false;
              let transpose_b = false;
              if (a[0] === c[0]) {
                  transpose_a = false;
              } else if (a[1] === c[0]) {
                  transpose_a = true;
                  a = [].concat(a.reverse());
              } else {
                  return { found: false };
              }
              if (a[1] === b[0]) {
                  transpose_b = false;
              } else if (a[1] === b[1]) {
                  transpose_b = true;
              } else {
                  return { found: false };
              }
              return { found: true, op: 'matmul', params: { transpose_a, transpose_b } };
          };
          // einsum is only resolved when the equation maps onto a known operator; otherwise no shape is produced.
          operators.set('einsum', (inputs, params) => {
              for (const identify of [identifyMatmul]) {
                  const result = identify(inputs, params.equation);
                  if (result.found) {
                      return operators.has(result.op) ? operators.get(result.op)(inputs, result.params) : [];
                  }
              }
              return [];
          });
          // Picks the shape rule for an op name; unknown ops have none.
          const resolve = (op) => {
              if (operators.has(op)) {
                  return operators.get(op);
              }
              if (passthroughs.has(op)) {
                  return (inputs) => [inputs[0].slice()];
              }
              if (broadcasts.has(op)) {
                  return operators.get('broadcast');
              }
              if (reduces.has(op)) {
                  return operators.get('reduce');
              }
              if (poolings.has(op)) {
                  return operators.get('pooling');
              }
              return null;
          };
          // Layers already attempted. Each layer is tried once, which keeps the walk
          // finite on cyclic graphs and avoids re-walking unresolved producers.
          const attempted = new Set();
          const infer = (output) => {
              const layer = outputs.get(output.name);
              if (!layer || attempted.has(layer)) {
                  return;
              }
              attempted.add(layer);
              for (const input of layer.inputs) {
                  if (input.shape === null) {
                      infer(input);
                      if (input.shape === null) {
                          // An input could not be resolved, so this layer cannot be either.
                          return;
                      }
                  }
              }
              const callback = resolve(layer.op);
              if (!callback) {
                  return;
              }
              const inputs = layer.inputs.map((input) => input.shape);
              // Operators return null for unsupported configurations; treat that as "no shapes".
              const shapes = callback(inputs, layer.parameters) || [];
              const count = Math.min(shapes.length, layer.outputs.length);
              for (let i = 0; i < count; i++) {
                  layer.outputs[i].shape = shapes[i];
              }
          };
          for (const layer of outputLayers) {
              for (const output of layer.outputs) {
                  infer(output);
              }
          }
      }
  };
  // Error type of this loader; acuity.Value throws it for a non-string identifier.
  acuity.Error = class extends Error {
      constructor(message) {
          super(message);
          // Every instance gets this fixed name in place of the default 'Error'.
          this.name = 'Error loading Acuity model.';
      }
  };
  // Expose the model factory as this module's named export.
  export const ModelFactory = acuity.ModelFactory;