acuity.js 26 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774
  // Reuse an `acuity` namespace object that was already defined (e.g. by another script),
  // otherwise start from an empty one.
  var acuity = acuity ? acuity : {};
  // `json` comes from an existing global when present, otherwise from CommonJS `require`.
  // NOTE(review): `json` is not referenced anywhere else in this file — confirm before removing.
  var json = json ? json : require('./json');
  /**
   * Entry point used by the host to recognise and open Acuity JSON models.
   */
  acuity.ModelFactory = class {

      /**
       * Detects an Acuity model.
       * @param {object} context - host context exposing `identifier` and `open(kind)`.
       * @returns {string|undefined} 'acuity.json' when the identifier ends in '.json' (any case)
       *     and the parsed JSON has both `MetaData` and `Layers`; otherwise undefined.
       */
      match(context) {
          const parts = context.identifier.split('.');
          const extension = parts[parts.length - 1].toLowerCase();
          if (extension !== 'json') {
              return undefined;
          }
          const content = context.open('json');
          const isAcuity = content && content.MetaData && content.Layers;
          return isAcuity ? 'acuity.json' : undefined;
      }

      /**
       * Loads the operator metadata, then builds an acuity.Model from the context's JSON content.
       * @param {object} context - host context.
       * @returns {Promise<acuity.Model>}
       */
      open(context) {
          const createModel = (metadata) => new acuity.Model(metadata, context.open('json'));
          return acuity.Metadata.open(context).then(createModel);
      }
  };
  /**
   * Model wrapper exposing the name, format and runtime from `MetaData` and a single graph.
   */
  acuity.Model = class {

      /**
       * @param {object} metadata - operator metadata, forwarded to acuity.Graph.
       * @param {object} model - parsed Acuity JSON; reads `MetaData.Name`, `MetaData.AcuityVersion`
       *     and `MetaData.Platform`, and is handed to acuity.Graph.
       * @param {*} data - forwarded to acuity.Graph (its constructor in this file ignores it).
       * @param {*} quantization - forwarded to acuity.Graph (likewise ignored there).
       */
      constructor(metadata, model, data, quantization) {
          const meta = model.MetaData;
          this._name = meta.Name;
          this._format = 'Acuity v' + meta.AcuityVersion;
          this._runtime = meta.Platform;
          const graph = new acuity.Graph(metadata, model, data, quantization);
          this._graphs = [ graph ];
      }

      get format() {
          return this._format;
      }

      get name() {
          return this._name;
      }

      get runtime() {
          return this._runtime;
      }

      get graphs() {
          return this._graphs;
      }
  };
  /**
   * Builds the graph (inputs, outputs, nodes) from the `Layers` dictionary of an Acuity model.
   *
   * Every tensor is keyed as '@<layerName>:<port>'. Layer `inputs` entries in the JSON are looked
   * up with the same key, so they are presumably already in that form — TODO confirm.
   * NOTE: each layer's `inputs` / `outputs` arrays are replaced in place by `{ name, shape }`
   * records, which acuity.Inference.infer then reads and fills in.
   */
  acuity.Graph = class {

      /**
       * @param {object} metadata - operator metadata forwarded to acuity.Node.
       * @param {object} model - parsed Acuity JSON with a `Layers` dictionary.
       */
      constructor(metadata, model) {
          this._nodes = [];
          this._inputs = [];
          this._outputs = [];
          const args = new Map();
          // Returns the unique `{ name, shape }` record for a tensor name, creating it on first use.
          const arg = (name) => {
              if (!args.has(name)) {
                  args.set(name, { name: name, shape: null });
              }
              return args.get(name);
          };
          const hasOwn = (object, key) => Object.prototype.hasOwnProperty.call(object, key);
          // Shape declared by an 'input' / 'variable' layer through either `shape` or
          // `size` ('<h> <w>') plus `channels`; null for any other layer or when nothing is declared.
          const declaredShape = (layer) => {
              const op = layer.op.toLowerCase();
              const parameters = layer.parameters;
              if ((op !== 'input' && op !== 'variable') || !parameters) {
                  return null;
              }
              let shape = null;
              if (hasOwn(parameters, 'shape') && parameters.shape.length > 0) {
                  // Copy, so normalizing the first dimension below does not rewrite the layer's
                  // own `shape` parameter (acuity.Node shows parameters as attributes).
                  shape = Array.isArray(parameters.shape) ? parameters.shape.slice() : parameters.shape;
              }
              else if (hasOwn(parameters, 'size') && hasOwn(parameters, 'channels')) {
                  const sizes = parameters.size.split(' ');
                  shape = [ 0, parseInt(sizes[0], 10), parseInt(sizes[1], 10), parameters.channels ];
              }
              // A leading 0 in a 4-D shape (presumably an unspecified batch size) is shown as 1.
              if (shape && shape.length === 4 && shape[0] === 0) {
                  shape[0] = 1;
              }
              return shape;
          };
          for (const layerName of Object.keys(model.Layers)) {
              const layer = model.Layers[layerName];
              layer.inputs = layer.inputs.map((input) => arg(input));
              layer.outputs = layer.outputs.map((port) => {
                  const argument = arg('@' + layerName + ':' + port);
                  argument.shape = declaredShape(layer);
                  return argument;
              });
          }
          acuity.Inference.infer(model.Layers);
          // Replace every intermediate record by an acuity.Argument carrying the (inferred) shape.
          // Overwriting existing keys while iterating a Map does not add or revisit entries.
          for (const [ name, record ] of args) {
              const type = new acuity.TensorType(null, new acuity.TensorShape(record.shape));
              args.set(name, new acuity.Argument(name, type, null, null));
          }
          for (const layerName of Object.keys(model.Layers)) {
              const layer = model.Layers[layerName];
              switch (layer.op.toLowerCase()) {
                  case 'input': {
                      this._inputs.push(new acuity.Parameter(layerName, true, [ args.get(layer.outputs[0].name) ]));
                      break;
                  }
                  case 'output': {
                      this._outputs.push(new acuity.Parameter(layerName, true, [ args.get(layer.inputs[0].name) ]));
                      break;
                  }
                  default: {
                      this._nodes.push(new acuity.Node(metadata, layerName, layer, args));
                      break;
                  }
              }
          }
      }

      get inputs() {
          return this._inputs;
      }

      get outputs() {
          return this._outputs;
      }

      get nodes() {
          return this._nodes;
      }
  };
  /**
   * A computational node built from one Acuity layer (any op other than 'input' / 'output').
   */
  acuity.Node = class {

      /**
       * @param {object} metadata - acuity.Metadata supplying operator and attribute schemas.
       * @param {string} name - layer name.
       * @param {object} layer - layer record with `op`, optional `parameters`, and `inputs` /
       *     `outputs` arrays of records carrying a `name`.
       * @param {Map<string, object>} args - tensor name -> acuity.Argument.
       */
      constructor(metadata, name, layer, args) {
          this._name = name;
          // Without a schema, fall back to a bare type that only carries the op name.
          this._type = metadata.type(layer.op) || { name: layer.op };
          this._inputs = [];
          this._outputs = [];
          this._attributes = [];
          this._layer = layer;
          const type = this._type;
          if (layer.parameters) {
              for (const key of Object.keys(layer.parameters)) {
                  // acuity.Metadata keys its schemas by operator name, so look attributes up by
                  // `layer.op` (a string) rather than by the type object.
                  const attributeMetadata = metadata.attribute(layer.op, key);
                  this._attributes.push(new acuity.Attribute(attributeMetadata, key, layer.parameters[key]));
              }
          }
          for (let i = 0; i < layer.inputs.length; i++) {
              const argument = args.get(layer.inputs[i].name);
              const inputName = type.inputs && i < type.inputs.length ? type.inputs[i].name : 'input' + i.toString();
              this._inputs.push(new acuity.Parameter(inputName, true, [ argument ]));
          }
          // Constant inputs declared by the schema get a placeholder argument: empty name,
          // unknown shape and an acuity.Tensor initializer without data.
          if (type.constants) {
              for (const constant of type.constants) {
                  const constantType = new acuity.TensorType(null, new acuity.TensorShape(null));
                  const argument = new acuity.Argument('', constantType, null, new acuity.Tensor(constantType));
                  this._inputs.push(new acuity.Parameter(constant.name, true, [ argument ]));
              }
          }
          for (let i = 0; i < layer.outputs.length; i++) {
              const argument = args.get(layer.outputs[i].name);
              const outputName = type.outputs && i < type.outputs.length ? type.outputs[i].name : 'output' + i.toString();
              this._outputs.push(new acuity.Parameter(outputName, true, [ argument ]));
          }
      }

      get type() {
          return this._type;
      }

      get name() {
          return this._name;
      }

      get inputs() {
          return this._inputs;
      }

      get outputs() {
          return this._outputs;
      }

      get attributes() {
          return this._attributes;
      }
  };
  /**
   * A layer parameter shown as a node attribute. It is hidden when its value strictly equals
   * the schema's `default`.
   */
  acuity.Attribute = class {

      /**
       * @param {object|null} metadata - attribute schema (`type`, optional `default`) or null.
       * @param {string} name - parameter name.
       * @param {*} value - parameter value from the layer.
       */
      constructor(metadata, name, value) {
          this._type = metadata ? (metadata.type || null) : null;
          this._name = name;
          this._value = value;
          const hasDefault = metadata && Object.prototype.hasOwnProperty.call(metadata, 'default');
          if (hasDefault && metadata.default === value) {
              this._visible = false;
          }
      }

      get name() {
          return this._name;
      }

      get type() {
          return this._type;
      }

      get value() {
          return this._value;
      }

      get visible() {
          return this._visible != false;
      }
  };
  /**
   * A named input/output slot of a graph or node, holding a list of arguments.
   */
  acuity.Parameter = class {

      /**
       * @param {string} name - slot name.
       * @param {boolean} visible - whether the slot is shown.
       * @param {Array<object>} args - arguments bound to the slot; every entry must be truthy.
       * @throws {acuity.Error} when any argument is missing (e.g. an unresolved tensor name).
       */
      constructor(name, visible, args) {
          this._name = name;
          this._visible = visible;
          this._arguments = args;
          if (this._arguments.some((arg) => !arg)) {
              // Previously a bare empty string was thrown, which carries no message or stack.
              throw new acuity.Error("Invalid argument list for parameter '" + name + "'.");
          }
      }

      get name() {
          return this._name;
      }

      get visible() {
          return this._visible;
      }

      get arguments() {
          return this._arguments;
      }
  };
  /**
   * A named tensor value: its type plus optional quantization and initializer data.
   */
  acuity.Argument = class {

      /**
       * @param {string} name - tensor name.
       * @param {object} [type] - acuity.TensorType; stored as null when falsy.
       * @param {*} [quantization] - stored as null when falsy.
       * @param {object} [initializer] - constant tensor; stored as null when falsy.
       * @throws {acuity.Error} when `name` is not a string.
       */
      constructor(name, type, quantization, initializer) {
          const isValidName = typeof name === 'string';
          if (!isValidName) {
              throw new acuity.Error("Invalid argument identifier '" + JSON.stringify(name) + "'.");
          }
          this._name = name;
          this._type = type ? type : null;
          this._quantization = quantization ? quantization : null;
          this._initializer = initializer ? initializer : null;
      }

      get name() {
          return this._name;
      }

      get type() {
          return this._type;
      }

      get quantization() {
          return this._quantization;
      }

      set quantization(value) {
          this._quantization = value;
      }

      get initializer() {
          return this._initializer;
      }

      set initializer(value) {
          this._initializer = value;
      }
  };
  /**
   * Tensor type made of a data type string ('?' when unknown) and an acuity.TensorShape.
   */
  acuity.TensorType = class {

      /**
       * @param {string|null} dataType - element type; falsy values become '?'.
       * @param {object} shape - shape object whose `toString()` renders the dimensions.
       */
      constructor(dataType, shape) {
          this._dataType = dataType ? dataType : '?';
          this._shape = shape;
      }

      get dataType() {
          return this._dataType;
      }

      set dataType(value) {
          this._dataType = value;
      }

      get shape() {
          return this._shape;
      }

      set shape(value) {
          this._shape = value;
      }

      toString() {
          const dataType = this.dataType || '?';
          const dimensions = this._shape.toString();
          return dataType + dimensions;
      }
  };
  /**
   * Tensor shape. A single-element `[0]` dimension list is treated as a scalar: it reports no
   * dimensions and renders as ''.
   */
  acuity.TensorShape = class {

      /**
       * @param {Array<number>|null} dimensions - dimension sizes; falsy values become null.
       */
      constructor(dimensions) {
          this._dimensions = dimensions ? dimensions : null;
      }

      get dimensions() {
          const dims = this._dimensions;
          const isScalar = Array.isArray(dims) && dims.length == 1 && dims[0] == 0;
          return isScalar ? [] : dims;
      }

      toString() {
          const dims = this._dimensions;
          const hasDimensions = Array.isArray(dims) && dims.length != 0 && !(dims.length == 1 && dims[0] == 0);
          if (hasDimensions) {
              return '[' + dims.map((size) => size.toString()).join(',') + ']';
          }
          return '';
      }
  };
  /**
   * Placeholder constant tensor: carries a type but no data (its `state` says so).
   */
  acuity.Tensor = class {

      /**
       * @param {object} type - acuity.TensorType of the constant.
       */
      constructor(type) {
          this._type = type;
      }

      get type() {
          return this._type;
      }

      get kind() {
          return 'Constant';
      }

      get state() {
          return 'Tensor data not implemented.';
      }

      toString() {
          return '';
      }
  };
  /**
   * Operator schemas parsed from 'acuity-metadata.json' (an array of entries with a `name` and
   * optional `attributes`), keyed by operator name.
   */
  acuity.Metadata = class {

      /**
       * Loads the metadata once and caches it on `acuity.Metadata._metadata`. If the request
       * rejects or the content cannot be parsed, an empty instance is cached and returned instead.
       * @param {object} context - host context; `request('acuity-metadata.json', 'utf-8', null)` is called.
       * @returns {Promise<acuity.Metadata>}
       */
      static open(context) {
          if (acuity.Metadata._metadata) {
              return Promise.resolve(acuity.Metadata._metadata);
          }
          return context.request('acuity-metadata.json', 'utf-8', null).then((data) => {
              acuity.Metadata._metadata = new acuity.Metadata(data);
              return acuity.Metadata._metadata;
          }).catch(() => {
              // Best effort: models still load without schemas (acuity.Node falls back to the op name).
              acuity.Metadata._metadata = new acuity.Metadata(null);
              return acuity.Metadata._metadata;
          });
      }

      /**
       * @param {string|null} data - JSON text of the metadata array, or null for an empty instance.
       */
      constructor(data) {
          this._map = new Map();
          if (data) {
              const metadata = JSON.parse(data);
              this._map = new Map(metadata.map((item) => [ item.name, item ]));
          }
      }

      /**
       * @param {string} name - operator name.
       * @returns {object|undefined} the schema, or undefined when unknown.
       */
      type(name) {
          return this._map.get(name);
      }

      /**
       * Looks up an attribute schema.
       * @param {string|object} type - operator name, or a type object whose `name` is the operator name.
       * @param {string} name - attribute name.
       * @returns {object|null} the attribute schema, or null when the operator or attribute is unknown.
       */
      attribute(type, name) {
          const typeName = type && typeof type === 'object' ? type.name : type;
          const schema = this.type(typeName);
          if (!schema) {
              return null;
          }
          if (!schema.attributeMap) {
              // Build the name -> attribute index lazily and cache it on the schema.
              const attributeMap = {};
              for (const attribute of schema.attributes || []) {
                  attributeMap[attribute.name] = attribute;
              }
              schema.attributeMap = attributeMap;
          }
          // Own-property check so names like 'constructor' do not resolve to Object.prototype members.
          const attributeMap = schema.attributeMap;
          if (Object.prototype.hasOwnProperty.call(attributeMap, name) && attributeMap[name]) {
              return attributeMap[name];
          }
          return null;
      }
  };
  /**
   * Best-effort static shape inference over the layers of an Acuity model.
   *
   * Starting from every 'output' layer, each tensor whose shape is unknown asks the layer that
   * produces it to compute its outputs, after recursively resolving that layer's inputs.
   * Tensors whose shape cannot be derived keep `shape === null`.
   */
  acuity.Inference = class {

      /**
       * @param {object} layers - layer name -> layer, where each layer has `op`, `parameters`, and
       *     `inputs` / `outputs` arrays of `{ name, shape }` records (as prepared by acuity.Graph).
       *     Inferred shapes are written into the `shape` field of the output records.
       */
      static infer(layers) {
          // Tensor name -> producing layer, and the 'output' layers used as roots.
          const producers = new Map();
          const outputLayers = [];
          for (const layerName of Object.keys(layers)) {
              const layer = layers[layerName];
              if (layer.op.toLowerCase() == 'output') {
                  outputLayers.push(layer);
              }
              for (const output of layer.outputs) {
                  producers.set(output.name, layer);
              }
          }
          // Element-wise binary ops whose output shape is the broadcast of both inputs.
          const broadcasts = new Set([
              'add', 'equal', 'floor_mod', 'floor_div', 'greater', 'greater_equal', 'less', 'less_equal',
              'logical_and', 'logical_or', 'minimum', 'multiply', 'not_equal', 'pow', 'real_div',
              'squared_difference', 'subtract'
          ]);
          // Ops whose single output has the shape of their first input.
          // NOTE(review): 'pow' is also listed under broadcasts; passthroughs are checked first, so
          // it is treated as a passthrough — confirm which is intended.
          const passthroughs = new Set([
              'LocalResponseNormalization', 'a_times_b_plus_c', 'abs', 'batchnorm_single', 'batchnormalize',
              'cast', 'clipbyvalue', 'dequantize', 'dtype_converter', 'elu', 'exp', 'floor',
              'groupnormalize', 'hard_sigmoid', 'hard_swish', 'instancenormalize', 'l2normalize', 'l2normalizescale',
              'layernormalize', 'leakyrelu', 'log', 'log_softmax', 'mish', 'neg', 'norm_with_channel_mean',
              'norm_with_min_max', 'norm_with_scale', 'pow', 'prelu', 'quantize', 'relu', 'relu_keras',
              'relun', 'reverse', 'round', 'rsqrt', 'sigmoid', 'sin', 'softmax', 'softrelu', 'sqrt', 'square', 'tanh'
          ]);
          const reduces = new Set([
              'reduceany', 'reducemax', 'reducemean', 'reducemin', 'reduceprod', 'reducesum'
          ]);
          // Maps a (possibly negative) axis onto [0, rank).
          const normalizeAxis = (axis, rank) => axis < 0 ? rank + axis : axis;
          // Each operator receives the input shapes and the layer parameters, and returns the list
          // of output shapes (or null when the parameters are not understood).
          const operators = new Map();
          operators.set('broadcast', (inputs) => {
              const a = inputs[0];
              const b = inputs[1];
              const longer = a.length >= b.length ? a.slice() : b.slice();
              const shorter = a.length < b.length ? a.slice() : b.slice();
              const remain = longer.length - shorter.length;
              for (let i = 0; i < remain; i++) {
                  shorter.splice(0, 0, 1);
              }
              for (let i = 0; i < longer.length; i++) {
                  longer[i] = longer[i] > shorter[i] ? longer[i] : shorter[i];
              }
              return [ longer ];
          });
          operators.set('concat', (inputs, params) => {
              const outputShape = inputs[0].slice();
              outputShape[params.dim] = 0;
              for (const shape of inputs) {
                  outputShape[params.dim] += shape[params.dim];
              }
              return [ outputShape ];
          });
          operators.set('conv1d', (inputs, params) => {
              if (params.padding == 'VALID') {
                  const out_h = ~~((inputs[0][1] + params.stride - params.ksize) / params.stride);
                  return [ [ inputs[0][0], out_h, params.weights ] ];
              }
              else if (params.padding == 'SAME') {
                  const out_h = ~~((inputs[0][1] + params.stride - 1) / params.stride);
                  return [ [ inputs[0][0], out_h, params.weights ] ];
              }
              return null;
          });
          operators.set('convolution', (inputs, params) => {
              if (params.padding == 'VALID') {
                  const out_h = ~~((inputs[0][1] + params.stride_h + params.pad[0] + params.pad[1] - params.ksize_h) / params.stride_h);
                  const out_w = ~~((inputs[0][2] + params.stride_w + params.pad[2] + params.pad[3] - params.ksize_w) / params.stride_w);
                  return [ [ inputs[0][0], out_h, out_w, params.weights ] ];
              }
              else if (params.padding == 'SAME') {
                  const out_h = ~~((inputs[0][1] + params.stride_h - 1) / params.stride_h);
                  const out_w = ~~((inputs[0][2] + params.stride_w - 1) / params.stride_w);
                  return [ [ inputs[0][0], out_h, out_w, params.weights ] ];
              }
              return null;
          });
          operators.set('deconvolution', (inputs, params) => {
              return [ params.output_shape.map((item, index) => item == 0 ? inputs[0][index] : item) ];
          });
          operators.set('fullconnect', (inputs, params) => {
              return [ inputs[0].slice(0, params.axis).concat([ params.weights ]) ];
          });
          operators.set('gather', (inputs, params) => {
              // Presumably inputs[0] = data and inputs[1] = indices; the output is
              // data[:axis] + indices + data[axis + 1:] (the leading dims used to be dropped).
              const axis = normalizeAxis(params.axis, inputs[0].length);
              return [ inputs[0].slice(0, axis).concat(inputs[1], inputs[0].slice(axis + 1)) ];
          });
          operators.set('lstm', (inputs, params) => {
              let batch = inputs[0][0];
              const output = params.num_proj != null ? params.num_proj : params.weights;
              if (params.time_major) {
                  batch = inputs[0][1];
              }
              const newShape = params.return_sequences ? [ inputs[0][0], inputs[0][1], output ] : [ batch, output ];
              return [ newShape, [ batch, output ], [ batch, params.weights ] ];
          });
          operators.set('matmul', (inputs, params) => {
              const a = inputs[0];
              const b = inputs[1];
              let newShape = a.slice(0, -2);
              if (params.transpose_a) {
                  newShape = newShape.concat(a.slice(-1));
              }
              else {
                  newShape = newShape.concat(a.slice(-2, -1));
              }
              if (params.transpose_b) {
                  newShape = newShape.concat(b.slice(-2, -1));
              }
              else {
                  newShape = newShape.concat(b.slice(-1));
              }
              return [ newShape ];
          });
          operators.set('pad', (inputs, params) => {
              return [ inputs[0].map((item, index) => item + params.padding_value[index][0] + params.padding_value[index][1]) ];
          });
          operators.set('permute', (inputs, params) => {
              return [ inputs[0].map((item, index) => inputs[0][params.perm[index]]) ];
          });
          operators.set('pooling', (inputs, params) => {
              if (params.padding == 'VALID') {
                  const out_h = ~~((inputs[0][1] + params.stride_h - params.ksize_h) / params.stride_h);
                  const out_w = ~~((inputs[0][2] + params.stride_w - params.ksize_w) / params.stride_w);
                  return [ [ inputs[0][0], out_h, out_w, inputs[0][3] ] ];
              }
              else if (params.padding == 'SAME') {
                  const out_h = ~~((inputs[0][1] + params.stride_h - 1) / params.stride_h);
                  const out_w = ~~((inputs[0][2] + params.stride_w - 1) / params.stride_w);
                  return [ [ inputs[0][0], out_h, out_w, inputs[0][3] ] ];
              }
              return null;
          });
          operators.set('reduce', (inputs, params) => {
              const newShape = inputs[0].slice();
              const axis_list = [ ...new Set(params.axis_list.map((item) => normalizeAxis(item, newShape.length))) ];
              if (params.keep_dims) {
                  // Reduced axes stay with size 1 (iterate axis values, not array indices).
                  for (const axis of axis_list) {
                      newShape[axis] = 1;
                  }
              }
              else {
                  // Remove from the highest axis down so earlier removals do not shift later ones.
                  axis_list.sort((a, b) => b - a);
                  for (const axis of axis_list) {
                      newShape.splice(axis, 1);
                  }
                  // [0] is how this file encodes a scalar shape (acuity.TensorShape maps it to []).
                  if (!newShape.length) {
                      newShape.splice(0, 0, 0);
                  }
              }
              return [ newShape ];
          });
          operators.set('repeat', (inputs, params) => {
              const newShape = inputs[0].slice();
              newShape[params.axis] = params.maxlen;
              return [ newShape ];
          });
          operators.set('reshape', (inputs, params) => {
              const negativeIndexs = [];
              let shape = params.shape;
              if (typeof params.shape === 'string') {
                  shape = params.shape.trim().split(/\s+/).map((item) => parseInt(item, 10));
              }
              // 0 copies the input dimension at the same position; -1 is solved for below.
              const newShape = shape.map((item, index) => {
                  if (item == 0) {
                      return inputs[0][index];
                  }
                  if (item == -1) {
                      negativeIndexs.push(index);
                      return 1;
                  }
                  return item;
              });
              if (negativeIndexs.length > 0) {
                  const total = inputs[0].reduce((a, c) => a * c, 1);
                  const known = newShape.reduce((a, c) => a * c, 1);
                  newShape[negativeIndexs[0]] = total / known;
              }
              return [ newShape ];
          });
          operators.set('sequence_mask', (inputs, params) => {
              return [ inputs[0].slice().concat([ params.maxlen ]) ];
          });
          operators.set('slice', (inputs, params) => {
              return [ params.size.map((item, index) => item == -1 ? inputs[0][index] : item) ];
          });
          operators.set('squeeze', (inputs, params) => {
              const newShape = inputs[0].slice();
              const rank = newShape.length;
              // Normalize negative axes first so the descending sort removes the right dims.
              const axis_list = [ ...new Set(params.axis_list.map((item) => normalizeAxis(item, rank))) ].sort((a, b) => b - a);
              for (const item of axis_list) {
                  newShape.splice(item, 1);
              }
              return [ newShape ];
          });
          operators.set('space2depth', (inputs, params) => {
              const h = inputs[0][1] / params.block_size[0];
              const w = inputs[0][2] / params.block_size[1];
              // Each output position gathers a block_size[0] x block_size[1] patch of channels.
              const c = inputs[0][3] * params.block_size[0] * params.block_size[1];
              return [ [ inputs[0][0], h, w, c ] ];
          });
          operators.set('split', (inputs, params) => {
              const sizes = [];
              const slices = params.slices.slice();
              slices.splice(0, 0, 0);
              slices.push(inputs[0][params.dim]);
              slices.reduce((a, b) => {
                  sizes.push(b - a);
                  return b;
              });
              return sizes.map((item) => {
                  const shape = inputs[0].slice();
                  shape[params.dim] = item;
                  return shape;
              });
          });
          operators.set('stack', (inputs, params) => {
              const newShape = inputs[0].slice();
              if (newShape.length == 1 && newShape[0] == 0) {
                  newShape[0] = 1;
              }
              else {
                  newShape.splice(params.axis, 0, inputs.length);
              }
              return [ newShape ];
          });
          operators.set('stridedslice', (inputs, params) => {
              const input_shape = inputs[0].slice();
              const begin = params.slice_begin.slice();
              const end = params.slice_end.slice();
              if (params.slice_begin_mask > 0) {
                  for (let i = 0; i < begin.length; i++) {
                      if ((params.slice_begin_mask >>> i) & 0x1) {
                          begin[i] = -1;
                      }
                  }
              }
              if (params.slice_end_mask > 0) {
                  for (let i = 0; i < end.length; i++) {
                      if ((params.slice_end_mask >>> i) & 0x1) {
                          end[i] = -1;
                      }
                  }
              }
              for (let i = 0; i < begin.length; i++) {
                  if (begin[i] == -1) {
                      begin[i] = 0;
                  }
              }
              if (inputs[0].length == end.length) {
                  for (let i = 0; i < end.length; i++) {
                      if (end[i] == -1 || end[i] > input_shape[i]) {
                          end[i] = input_shape[i];
                      }
                  }
              }
              else if (inputs[0].length < end.length) {
                  if (params.slice_new_axis_mask) {
                      const len = (params.slice_new_axis_mask >>> 0).toString(2).length;
                      for (let i = 0; i < len; i++) {
                          if ((params.slice_new_axis_mask >>> i) & 0x1) {
                              input_shape.splice(i, 0, 1);
                          }
                      }
                      for (let i = 0; i < end.length; i++) {
                          if (end[i] == -1) {
                              end[i] = input_shape[i];
                          }
                      }
                  }
              }
              let newShape = [];
              for (let i = 0; i < begin.length; i++) {
                  // A partial last step still yields an element, hence the ceiling.
                  newShape = newShape.concat([ Math.ceil((end[i] - begin[i]) / params.slice_strides[i]) ]);
              }
              if (params.slice_shrink_axis_mask) {
                  const len = (params.slice_shrink_axis_mask >>> 0).toString(2).length;
                  // Remove from the highest bit down so each index still refers to the original axis.
                  for (let i = len - 1; i >= 0; i--) {
                      if ((params.slice_shrink_axis_mask >>> i) & 0x1) {
                          newShape.splice(i, 1);
                      }
                  }
              }
              if (params.slice_new_axis_mask) {
                  const len = (params.slice_new_axis_mask >>> 0).toString(2).length;
                  for (let i = 0; i < len; i++) {
                      if ((params.slice_new_axis_mask >>> i) & 0x1) {
                          if (inputs[0].length == begin.length) {
                              newShape.splice(i, 0, 1);
                          }
                          else if (inputs[0].length < begin.length) {
                              newShape[i] = 1;
                          }
                      }
                  }
              }
              return [ newShape ];
          });
          // Picks the shape function for an op (exact, case-sensitive match on `layer.op`).
          const callbackFor = (op) => {
              if (operators.has(op)) {
                  return operators.get(op);
              }
              if (passthroughs.has(op)) {
                  return (inputs) => [ inputs[0].slice() ];
              }
              if (broadcasts.has(op)) {
                  return operators.get('broadcast');
              }
              if (reduces.has(op)) {
                  return operators.get('reduce');
              }
              return () => [];
          };
          const infer = (output) => {
              if (!producers.has(output.name)) {
                  return;
              }
              const layer = producers.get(output.name);
              for (const input of layer.inputs) {
                  if (input.shape === null) {
                      infer(input);
                      if (input.shape === null) {
                          // An input shape is still unknown, so this layer cannot be inferred.
                          return;
                      }
                  }
              }
              const callback = callbackFor(layer.op);
              const shapes = callback(layer.inputs.map((input) => input.shape), layer.parameters);
              // Some operators return null for parameters they do not understand (e.g. an unknown
              // padding mode); keep the output shapes unknown instead of failing.
              if (!shapes) {
                  return;
              }
              for (let i = 0; i < shapes.length && i < layer.outputs.length; i++) {
                  layer.outputs[i].shape = shapes[i];
              }
          };
          for (const layer of outputLayers) {
              for (const output of layer.outputs) {
                  infer(output);
              }
          }
      }
  };
  /**
   * Error raised while loading an Acuity model. The `name` holds a fixed summary line and the
   * `message` holds the specific reason.
   */
  acuity.Error = class extends Error {

      /**
       * @param {string} message - description of what went wrong.
       */
      constructor(message) {
          super(message);
          const summary = 'Error loading Acuity model.';
          this.name = summary;
      }
  };
  // Export the factory when loaded as a CommonJS module; where `module` is not defined
  // (e.g. a plain browser script), nothing is exported.
  if (typeof module !== 'undefined' && typeof module.exports === 'object') {
      const exported = module.exports;
      exported.ModelFactory = acuity.ModelFactory;
  }