mxnet.js 36 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817
  1. import * as json from './json.js';
  2. const mxnet = {};
  3. mxnet.ModelFactory = class {
  4. async match(context) {
  5. const identifier = context.identifier;
  6. const extension = identifier.lastIndexOf('.') > 0 ? identifier.split('.').pop().toLowerCase() : '';
  7. if (extension === 'json') {
  8. const obj = await context.peek('json');
  9. if (obj && Array.isArray(obj.nodes) && Array.isArray(obj.arg_nodes) && Array.isArray(obj.heads) && !obj.nodes.some((node) => node && node.op === 'tvm_op')) {
  10. return context.set('mxnet.json', obj);
  11. }
  12. }
  13. const stream = context.stream;
  14. const signature = [0x12, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00];
  15. if (stream && stream.length > signature.length && stream.peek(signature.length).every((value, index) => value === signature[index])) {
  16. return context.set('mxnet.params');
  17. }
  18. return null;
  19. }
  20. filter(context, match) {
  21. return context.type !== 'mxnet.json' || match.type !== 'mxnet.params';
  22. }
  23. async open(context) {
  24. const metadata = await context.metadata('mxnet-metadata.json');
  25. const basename = (base, identifier, extension, suffix, append) => {
  26. if (!base) {
  27. if (identifier.toLowerCase().endsWith(extension)) {
  28. const items = identifier.substring(0, identifier.length - extension.length).split('-');
  29. if (items.length >= 2) {
  30. const token = items.pop();
  31. if ((suffix && token === suffix) || /[a-zA-Z0-9]*/.exec(token)) {
  32. return items.join('-') + append;
  33. }
  34. }
  35. }
  36. }
  37. return base;
  38. };
  39. const convertVersion = (value) => {
  40. if (Array.isArray(value)) {
  41. if (value.length === 2 && value[0] === 'int') {
  42. const major = Math.floor(value[1] / 10000) % 100;
  43. const minor = Math.floor(value[1] / 100) % 100;
  44. const patch = Math.floor(value[1]) % 100;
  45. return [major.toString(), minor.toString(), patch.toString()].join('.');
  46. }
  47. }
  48. return null;
  49. };
  50. const requestManifest = async () => {
  51. const parse = async (stream) => {
  52. try {
  53. const manifest = {};
  54. if (stream) {
  55. const reader = json.TextReader.open(stream);
  56. const obj = reader.read();
  57. if (obj.Model) {
  58. const modelFormat = obj.Model['Model-Format'];
  59. if (modelFormat && modelFormat !== 'MXNet-Symbolic') {
  60. throw new mxnet.Error(`Model format '${modelFormat}' not supported.`);
  61. }
  62. manifest.format = 'MXNet Model Server';
  63. if (obj['Model-Archive-Version']) {
  64. manifest.format += ` v${obj['Model-Archive-Version']}`;
  65. }
  66. if (!obj.Model.Symbol) {
  67. throw new mxnet.Error('Manifest does not contain symbol entry.');
  68. }
  69. manifest.symbol = obj.Model.Symbol;
  70. if (obj.Model.Signature) {
  71. manifest.signature = obj.Model.Signature;
  72. }
  73. if (obj.Model.Parameters) {
  74. manifest.params = obj.Model.Parameters;
  75. }
  76. if (obj.Model['Model-Name']) {
  77. manifest.name = obj.Model['Model-Name'];
  78. }
  79. if (obj.Model.Description && manifest.name !== obj.Model.Description) {
  80. manifest.description = obj.Model.Description;
  81. }
  82. } else if (obj.model) {
  83. manifest.format = 'MXNet Model Archive';
  84. if (obj.specificationVersion) {
  85. manifest.format += ` v${obj.specificationVersion}`;
  86. }
  87. if (obj.model.modelName) {
  88. manifest.symbol = `${obj.model.modelName}-symbol.json`;
  89. }
  90. if (obj.model.modelName) {
  91. manifest.name = obj.model.modelName;
  92. }
  93. if (manifest.model && obj.model.modelVersion) {
  94. manifest.version = obj.model.modelVersion;
  95. }
  96. if (manifest.model && manifest.model.modelName && manifest.name !== obj.model.description) {
  97. manifest.description = obj.model.description;
  98. }
  99. } else {
  100. throw new mxnet.Error('Manifest does not contain model.');
  101. }
  102. if (obj.Engine && obj.Engine.MXNet) {
  103. const version = convertVersion(obj.Engine.MXNet);
  104. manifest.runtime = `MXNet v${version ? version : obj.Engine.MXNet}`;
  105. }
  106. if (obj.License) {
  107. manifest.license = obj.License;
  108. }
  109. if (obj.runtime) {
  110. manifest.runtime = obj.runtime;
  111. }
  112. if (obj.engine && obj.engine.engineName) {
  113. const engine = obj.engine.engineVersion ? `${obj.engine.engineName} ${obj.engine.engineVersion}` : obj.engine.engineName;
  114. manifest.runtime = manifest.runtime ? (`${manifest.runtime} (${engine})`) : engine;
  115. }
  116. if (obj.publisher && obj.publisher.author) {
  117. manifest.author = obj.publisher.author;
  118. if (obj.publisher.email) {
  119. manifest.author = `${manifest.author} <${obj.publisher.email}>`;
  120. }
  121. }
  122. if (obj.license) {
  123. manifest.license = obj.license;
  124. }
  125. if (obj.Model && obj.Model.Signature) {
  126. try {
  127. const content = await context.fetch(obj.Model.Signature);
  128. manifest.signature = await content.read('json');
  129. return manifest;
  130. } catch {
  131. return manifest;
  132. }
  133. }
  134. }
  135. return manifest;
  136. } catch (error) {
  137. throw new mxnet.Error(`Failed to read manifest. ${error.message}`);
  138. }
  139. };
  140. try {
  141. const content = await context.fetch('MANIFEST.json');
  142. return parse(content.stream);
  143. } catch {
  144. try {
  145. const content = await context.fetch('MAR-INF/MANIFEST.json');
  146. return parse(content.stream);
  147. } catch {
  148. return parse(null);
  149. }
  150. }
  151. };
  152. const createModel = (metadata, manifest, symbol, params) => {
  153. const parameters = new Map();
  154. if (params) {
  155. try {
  156. for (const [key, array] of mxnet.ndarray.load(params)) {
  157. const name = (key.startsWith('arg:') || key.startsWith('aux:')) ? key.substring(4) : key;
  158. parameters.set(name, array);
  159. }
  160. } catch {
  161. // continue regardless of error
  162. }
  163. }
  164. if (symbol) {
  165. if (!manifest.format) {
  166. const version = convertVersion(symbol.attrs && symbol.attrs.mxnet_version ? symbol.attrs.mxnet_version : null);
  167. manifest.format = `MXNet${version ? ` v${version}` : ''}`;
  168. }
  169. if (symbol.nodes && symbol.nodes.some((node) => node && node.op === 'tvm_op')) {
  170. manifest.format = 'TVM';
  171. }
  172. }
  173. return new mxnet.Model(metadata, manifest, symbol, parameters);
  174. };
  175. const identifier = context.identifier;
  176. switch (context.type) {
  177. case 'mxnet.json': {
  178. let symbol = null;
  179. try {
  180. symbol = context.value;
  181. } catch (error) {
  182. const message = error && error.message ? error.message : error.toString();
  183. throw new mxnet.Error(`Failed to load symbol entry (${message.replace(/\.$/, '')}).`);
  184. }
  185. const requestParams = async (manifest) => {
  186. const file = basename(manifest.params, identifier, '.json', 'symbol', '-0000.params');
  187. if (file) {
  188. try {
  189. const content = await context.fetch(file);
  190. const reader = await content.read('binary');
  191. return createModel(metadata, manifest, symbol, reader);
  192. } catch {
  193. return createModel(metadata, manifest, symbol, null);
  194. }
  195. }
  196. return createModel(metadata, manifest, symbol, null);
  197. };
  198. const manifest = await requestManifest();
  199. return requestParams(manifest);
  200. }
  201. case 'mxnet.params': {
  202. const params = await context.read('binary');
  203. const requestSymbol = async (manifest) => {
  204. const name = basename(manifest.symbol, identifier, '.params', null, '-symbol.json');
  205. if (name) {
  206. try {
  207. const content = await context.fetch(name);
  208. const symbol = await content.read('json');
  209. return createModel(metadata, manifest, symbol, params);
  210. } catch {
  211. return createModel(metadata, manifest, null, params);
  212. }
  213. }
  214. return createModel(metadata, manifest, null, params);
  215. };
  216. const manifest = await requestManifest();
  217. return requestSymbol(manifest);
  218. }
  219. default: {
  220. throw new mxnet.Error(`Unsupported MXNet format '${context.type}'.`);
  221. }
  222. }
  223. }
  224. };
  225. mxnet.Model = class {
  226. constructor(metadata, manifest, symbol, params) {
  227. if (!symbol && !params) {
  228. throw new mxnet.Error('JSON symbol data not available.');
  229. }
  230. if (symbol) {
  231. if (!Object.prototype.hasOwnProperty.call(symbol, 'nodes')) {
  232. throw new mxnet.Error('JSON file does not contain an MXNet \'nodes\' property.');
  233. }
  234. if (!Object.prototype.hasOwnProperty.call(symbol, 'arg_nodes')) {
  235. throw new mxnet.Error('JSON file does not contain an MXNet \'arg_nodes\' property.');
  236. }
  237. if (!Object.prototype.hasOwnProperty.call(symbol, 'heads')) {
  238. throw new mxnet.Error('JSON file does not contain an MXNet \'heads\' property.');
  239. }
  240. }
  241. this.format = manifest.format || 'MXNet';
  242. this.producer = manifest.producer || '';
  243. this.name = manifest.name || '';
  244. this.version = manifest.version;
  245. this.description = manifest.description || '';
  246. this.runtime = manifest.runtime || '';
  247. this.metadata = [];
  248. if (manifest.author) {
  249. this.metadata.push(new mxnet.Argument('author', manifest.author));
  250. }
  251. if (manifest.license) {
  252. this.metadata.push(new mxnet.Argument('license', manifest.license));
  253. }
  254. this.modules = [new mxnet.Graph(metadata, manifest, symbol, params)];
  255. }
  256. };
  257. mxnet.Graph = class {
  258. constructor(metadata, manifest, symbol, params) {
  259. this.nodes = [];
  260. this.inputs = [];
  261. this.outputs = [];
  262. const tensors = new Map();
  263. if (params) {
  264. for (const [name, value] of params) {
  265. const shape = new mxnet.TensorShape(value.shape);
  266. const type = new mxnet.TensorType(value.dtype, shape);
  267. const tensor = new mxnet.Tensor(name, type, value.data);
  268. tensors.set(name, tensor);
  269. }
  270. }
  271. const values = new Map();
  272. values.map = (name, type, tensor) => {
  273. if (!values.has(name)) {
  274. values.set(name, new mxnet.Value(name, type || null, tensor || null));
  275. } else if (type || (tensor && tensor !== values.get(name).initializer)) {
  276. throw new mxnet.Error(`Duplicate value '${name}'.`);
  277. }
  278. return values.get(name);
  279. };
  280. const updateOutput = (nodes, input) => {
  281. const [nodeIndex, outputIndex] = input;
  282. const node = nodes[nodeIndex];
  283. if (node) {
  284. while (outputIndex >= node.outputs.length) {
  285. node.outputs.push([nodeIndex, node.outputs.length]);
  286. }
  287. }
  288. return [nodeIndex, outputIndex];
  289. };
  290. if (symbol) {
  291. const nodes = symbol.nodes;
  292. const inputs = {};
  293. const outputs = {};
  294. if (manifest && manifest.signature && manifest.signature.inputs) {
  295. for (const input of manifest.signature.inputs) {
  296. inputs[input.data_name] = input;
  297. }
  298. }
  299. if (manifest && manifest.signature && manifest.signature.outputs) {
  300. for (const output of manifest.signature.outputs) {
  301. outputs[output.data_name] = output;
  302. }
  303. }
  304. for (const node of nodes) {
  305. node.outputs = [];
  306. }
  307. for (const node of nodes) {
  308. node.inputs = node.inputs || [];
  309. node.inputs = node.inputs.map((input) => updateOutput(nodes, input));
  310. }
  311. const arg_nodes = new Map(symbol.arg_nodes.map((index) => [index, index < nodes.length ? nodes[index] : null]));
  312. for (let i = 0; i < symbol.heads.length; i++) {
  313. const head = symbol.heads[i];
  314. const identifier = updateOutput(nodes, head);
  315. const name = `output${(i === 0) ? '' : (i + 1)}`;
  316. const signature = outputs[name];
  317. const type = signature && signature.data_shape ? new mxnet.TensorType(-1, new mxnet.TensorShape(signature.data_shape)) : null;
  318. const value = values.map(`[${identifier.join(',')}]`, type);
  319. const argument = new mxnet.Argument(name, [value]);
  320. this.outputs.push(argument);
  321. }
  322. const filtered = nodes.filter((node, index) => !arg_nodes.has(index));
  323. const initializers = new Map();
  324. for (const node of filtered) {
  325. if (node.op === 'RNN') {
  326. node.inputs = node.inputs.filter((input) => {
  327. const [index] = input;
  328. const arg_node = arg_nodes.get(index);
  329. if (arg_node && arg_node.op === 'null' && arg_node.name && arg_node.name.endsWith('_parameters') && arg_node.attr && arg_node.attr.__init__) {
  330. let attr = node.attrs || node.attr || node.param;
  331. if (!attr) {
  332. node.attr = {};
  333. attr = node.attr;
  334. }
  335. attr[arg_node.name] = arg_node.attr.__init__;
  336. arg_nodes.delete(index);
  337. return false;
  338. }
  339. return true;
  340. });
  341. }
  342. for (const input of node.inputs) {
  343. const identifier = `[${input.join(',')}]`;
  344. if (!initializers.has(identifier)) {
  345. const [index] = input;
  346. const arg_node = arg_nodes.get(index);
  347. if (arg_node && arg_node.name && (!arg_node.inputs || arg_node.inputs.length === 0) && (arg_node.outputs && arg_node.outputs.length === 1)) {
  348. if (tensors.has(arg_node.name)) {
  349. initializers.set(identifier, tensors.get(arg_node.name));
  350. arg_nodes.delete(index);
  351. } else {
  352. const prefix = node.name.endsWith('_fwd') ? node.name.slice(0, -3) : node.name;
  353. if (arg_node.name && (arg_node.name.startsWith(`${prefix}_`) || arg_node.name.startsWith(`${prefix}.`))) {
  354. let dataType = -1;
  355. let shape = [];
  356. if (arg_node.attrs && arg_node.attrs.__dtype__ && arg_node.attrs.__shape__) {
  357. try {
  358. dataType = parseInt(arg_node.attrs.__dtype__, 10);
  359. shape = JSON.parse(`[${arg_node.attrs.__shape__.replace(/[()]/g, '').split(' ').join('').split(',').map(((v) => v || '"?"')).join(',')}]`);
  360. } catch {
  361. // continue regardless of error
  362. }
  363. }
  364. const type = (dataType !== -1 || shape.length > 0) ?
  365. new mxnet.TensorType(dataType, new mxnet.TensorShape(shape)) :
  366. new mxnet.TensorType(-1, new mxnet.TensorShape(null));
  367. initializers.set(identifier, new mxnet.Tensor(arg_node.name, type, null));
  368. arg_nodes.delete(index);
  369. }
  370. }
  371. }
  372. }
  373. }
  374. if (node.params) {
  375. for (const param of node.params) {
  376. values.map(param.id, null, tensors.get(param.id));
  377. }
  378. }
  379. }
  380. for (const [, arg_node] of arg_nodes) {
  381. if (arg_node && (!arg_node.inputs || arg_node.inputs.length === 0) && (arg_node.outputs && arg_node.outputs.length === 1)) {
  382. const identifier = `[${arg_node.outputs[0].join(',')}]`;
  383. const name = arg_node.name;
  384. const signature = inputs[name];
  385. const type = signature && signature.data_shape ? new mxnet.TensorType(-1, new mxnet.TensorShape(signature.data_shape)) : null;
  386. const value = values.map(identifier, type, tensors.get(identifier));
  387. const argument = new mxnet.Argument(name, [value]);
  388. this.inputs.push(argument);
  389. }
  390. }
  391. for (const node of filtered) {
  392. this.nodes.push(new mxnet.Node(metadata, node, initializers, values));
  393. }
  394. } else if (params) {
  395. const blocks = new Map();
  396. let separator = Array.from(params.keys()).every((key) => key.indexOf('_') !== -1) ? '_' : '';
  397. if (separator.length === 0) {
  398. separator = Array.from(params.keys()).every((key) => key.indexOf('.') !== -1) ? '.' : '';
  399. }
  400. if (separator.length > 0) {
  401. for (const [key] of params) {
  402. const parts = key.split(separator);
  403. let argumentName = parts.pop();
  404. if (key.endsWith('moving_mean') || key.endsWith('moving_var')) {
  405. argumentName = [parts.pop(), argumentName].join(separator);
  406. }
  407. const nodeName = parts.join(separator);
  408. if (!blocks.has(nodeName)) {
  409. blocks.set(nodeName, { name: nodeName, op: 'Weights', params: [] });
  410. }
  411. blocks.get(nodeName).params.push({ name: argumentName, id: key });
  412. values.map(key, null, tensors.get(key));
  413. }
  414. } else {
  415. throw new mxnet.Error("Unsupported key format in params.");
  416. }
  417. for (const block of blocks.values()) {
  418. this.nodes.push(new mxnet.Node(metadata, block, new Map(), values));
  419. }
  420. }
  421. }
  422. };
  423. mxnet.Argument = class {
  424. constructor(name, value, type = null, visible = true) {
  425. this.name = name;
  426. this.value = value;
  427. this.type = type;
  428. this.visible = visible;
  429. }
  430. };
  431. mxnet.Value = class {
  432. constructor(name, type, initializer = null) {
  433. if (typeof name !== 'string') {
  434. throw new mxnet.Error(`Invalid value identifier '${JSON.stringify(name)}'.`);
  435. }
  436. this.name = !name && initializer && initializer.name ? initializer.name : name;
  437. this.type = !type && initializer && initializer.type ? initializer.type : type;
  438. this.initializer = initializer;
  439. }
  440. };
  441. mxnet.Node = class {
  442. constructor(metadata, node, initializers, values) {
  443. let type = node.op;
  444. this.name = node.name;
  445. this.attributes = [];
  446. this.inputs = [];
  447. this.outputs = [];
  448. const attrs = node.attrs || node.attr || node.param;
  449. if (attrs) {
  450. if (type === 'tvm_op' && attrs.func_name) {
  451. type = attrs.func_name;
  452. }
  453. for (const [name, obj] of Object.entries(attrs)) {
  454. if (type !== 'tvm_op' && name !== 'func_name') {
  455. let value = obj;
  456. let visible = true;
  457. const schema = metadata.attribute(type, name);
  458. if (schema && schema.type) {
  459. switch (schema.type) {
  460. case 'boolean':
  461. switch (value) {
  462. case 0:
  463. case '0':
  464. case 'False':
  465. value = false;
  466. break;
  467. case 1:
  468. case '1':
  469. case 'True':
  470. value = true;
  471. break;
  472. default:
  473. throw new mxnet.Error(`Unsupported attribute boolean value '${value}'.`);
  474. }
  475. break;
  476. case 'int32': {
  477. const number = Number.parseInt(value, 10);
  478. value = Number.isNaN(value - number) ? value : number;
  479. break;
  480. }
  481. case 'float32':
  482. case 'float64': {
  483. const number = Number.parseFloat(value);
  484. value = Number.isNaN(value - number) ? value : number;
  485. break;
  486. }
  487. case 'int32[]':
  488. if (value.length > 2 && value.startsWith('(') && value.endsWith(')')) {
  489. let array = [];
  490. const items = value.substring(1, value.length - 1).split(',')
  491. .map((item) => item.trim())
  492. .map((item) => item.endsWith('L') ? item.substring(0, item.length - 1) : item);
  493. for (const item of items) {
  494. const value = Number.parseInt(item, 10);
  495. if (Number.isNaN(item - value)) {
  496. array = null;
  497. } else if (array !== null) {
  498. array.push(value);
  499. }
  500. }
  501. if (array !== null) {
  502. value = array;
  503. }
  504. }
  505. break;
  506. default:
  507. throw new mxnet.Error(`Unsupported attribute type '${metadata.type}'.`);
  508. }
  509. }
  510. if (metadata) {
  511. if (metadata.visible === false) {
  512. visible = false;
  513. } else if (metadata.default !== undefined) {
  514. const defaultValue = metadata.default;
  515. if (value === defaultValue) {
  516. visible = false;
  517. } else if (Array.isArray(value) && Array.isArray(defaultValue)) {
  518. const repeat = defaultValue.length > 1 && defaultValue[defaultValue.length - 1] === null;
  519. if (value.every((item, index) => item === (repeat && index >= defaultValue.length - 1 ? defaultValue[defaultValue.length - 2] : defaultValue[index]))) {
  520. visible = false;
  521. }
  522. }
  523. }
  524. }
  525. const attribute = new mxnet.Argument(name, value, type, visible);
  526. this.attributes.push(attribute);
  527. }
  528. }
  529. }
  530. this.type = metadata.type(type) || { name: type };
  531. if (node.inputs) {
  532. const inputs = node.inputs;
  533. let inputIndex = 0;
  534. if (this.type && this.type.inputs) {
  535. for (const inputDef of this.type.inputs) {
  536. if (inputIndex < inputs.length || inputDef.optional !== true) {
  537. const count = (inputDef.type === 'Tensor[]') ? (inputs.length - inputIndex) : 1;
  538. const list = [];
  539. for (const input of inputs.slice(inputIndex, inputIndex + count)) {
  540. const identifier = `[${input.join(',')}]`;
  541. if (identifier !== '' || (inputDef.optional !== true || inputDef.type === 'Tensor[]')) {
  542. const value = values.map(identifier, null, initializers.get(identifier));
  543. list.push(value);
  544. }
  545. }
  546. const argument = new mxnet.Argument(inputDef.name, list);
  547. this.inputs.push(argument);
  548. inputIndex += count;
  549. }
  550. }
  551. }
  552. if (inputIndex < inputs.length) {
  553. this.inputs.push(...inputs.slice(inputIndex).map((input, index) => {
  554. const name = (inputIndex + index).toString();
  555. const identifier = `[${input.join(',')}]`;
  556. const value = values.map(identifier, null, initializers.get(identifier));
  557. return new mxnet.Argument(name, [value]);
  558. }));
  559. }
  560. }
  561. if (node.outputs) {
  562. const outputs = node.outputs;
  563. let outputIndex = 0;
  564. if (this.type && this.type.outputs) {
  565. for (const outputDef of this.type.outputs) {
  566. if (outputIndex < outputs.length || outputDef.optional !== true) {
  567. const list = [];
  568. const count = (outputDef.type === 'Tensor[]') ? (outputs.length - outputIndex) : 1;
  569. for (const output of outputs.slice(outputIndex, outputIndex + count)) {
  570. const value = values.map(`[${output.join(',')}]`);
  571. list.push(value);
  572. }
  573. const argument = new mxnet.Argument(outputDef.name, list);
  574. this.outputs.push(argument);
  575. outputIndex += count;
  576. }
  577. }
  578. }
  579. if (outputIndex < outputs.length) {
  580. this.outputs.push(...outputs.slice(outputIndex).map((output, index) => {
  581. const name = (outputIndex + index).toString();
  582. const value = values.map(`[${output.join(',')}]`);
  583. return new mxnet.Argument(name, [value]);
  584. }));
  585. }
  586. }
  587. if (node.params) {
  588. for (const param of node.params) {
  589. const value = values.map(param.id);
  590. const argument = new mxnet.Argument(param.name, [value]);
  591. this.inputs.push(argument);
  592. }
  593. }
  594. }
  595. };
  596. mxnet.Tensor = class {
  597. constructor(name, type, data) {
  598. this.name = name;
  599. this.type = type;
  600. this.values = data;
  601. this.encoding = '<';
  602. }
  603. };
  604. mxnet.TensorType = class {
  605. constructor(dataType, shape) {
  606. switch (dataType) {
  607. case 0: this.dataType = 'float32'; break;
  608. case 1: this.dataType = 'float64'; break;
  609. case 2: this.dataType = 'float16'; break;
  610. case 3: this.dataType = 'uint8'; break;
  611. case 4: this.dataType = 'int32'; break;
  612. case 5: this.dataType = 'int8'; break;
  613. case 6: this.dataType = 'int64'; break;
  614. case -1: this.dataType = '?'; break;
  615. default: throw new mxnet.Error(`Unsupported type '${dataType}'.`);
  616. }
  617. this.shape = shape;
  618. }
  619. toString() {
  620. return this.dataType + this.shape.toString();
  621. }
  622. };
  623. mxnet.TensorShape = class {
  624. constructor(dimensions) {
  625. this.dimensions = dimensions;
  626. }
  627. toString() {
  628. if (this.dimensions) {
  629. if (this.dimensions.length === 0) {
  630. return '';
  631. }
  632. return `[${this.dimensions.map((dimension) => dimension.toString()).join(',')}]`;
  633. }
  634. return '';
  635. }
  636. };
  637. mxnet.ndarray = class {
  638. static load(reader) {
  639. // NDArray::Load(dmlc::Stream* fi, std::vector<NDArray>* data, std::vector<std::string>* keys)
  640. const params = new Map();
  641. reader = new mxnet.BinaryReader(reader);
  642. if (reader.uint64().toNumber() !== 0x112) { // kMXAPINDArrayListMagic
  643. throw new mxnet.Error('Invalid signature.');
  644. }
  645. if (reader.uint64().toNumber() !== 0) {
  646. throw new mxnet.Error('Invalid reserved block.');
  647. }
  648. const values = new Array(reader.uint64().toNumber());
  649. for (let i = 0; i < values.length; i++) {
  650. values[i] = new mxnet.ndarray.NDArray(reader);
  651. }
  652. const decoder = new TextDecoder('ascii');
  653. const names = new Array(reader.uint64().toNumber());
  654. for (let i = 0; i < names.length; i++) {
  655. const size = reader.uint64().toNumber();
  656. const buffer = reader.read(size);
  657. names[i] = decoder.decode(buffer);
  658. }
  659. if (names.length !== values.length) {
  660. throw new mxnet.Error('Invalid parameters.');
  661. }
  662. for (let i = 0; i < names.length; i++) {
  663. params.set(names[i], values[i]);
  664. }
  665. return params;
  666. }
  667. };
  668. mxnet.ndarray.NDArray = class {
  669. constructor(reader) {
  670. mxnet.ndarray.NDArray._dataTypeSizeTable = [4, 8, 2, 1, 4, 1, 8];
  671. switch (reader.uint32()) {
  672. case 0xf993faca: { // NDARRAY_V3_MAGIC
  673. throw new mxnet.Array('mxnet.ndarray.NDArray v3 not supported.');
  674. }
  675. case 0xf993fac9: { // NDARRAY_V2_MAGIC
  676. const stype = reader.uint32();
  677. let num_aux_data = 0;
  678. switch (stype) {
  679. case 0: num_aux_data = 0; break; // kDefaultStorage
  680. case 1: num_aux_data = 1; break; // kRowSparseStorage
  681. case 2: num_aux_data = 2; break; // kCSRStorage
  682. default: throw mxnet.Error(`Unsupported NDArray type '${stype}'.`);
  683. }
  684. this.sshape = null;
  685. if (num_aux_data > 0) {
  686. this.sshape = reader.uint64s();
  687. }
  688. this.shape = reader.uint64s();
  689. if (this.shape.length !== 0) {
  690. this.context = {
  691. deviceType: reader.uint32(),
  692. deviceId: reader.uint32()
  693. };
  694. this.dtype = reader.uint32();
  695. if (num_aux_data > 0) {
  696. throw new mxnet.Error('Not implemented.');
  697. }
  698. const dataTypeSize = (this.dtype < mxnet.ndarray.NDArray._dataTypeSizeTable.length) ? mxnet.ndarray.NDArray._dataTypeSizeTable[this.dtype] : 0;
  699. const size = dataTypeSize * this.size;
  700. this.data = reader.read(size);
  701. }
  702. break;
  703. }
  704. case 0xf993fac8: { // NDARRAY_V1_MAGIC
  705. this.shape = reader.uint64s();
  706. if (this.shape.length !== 0) {
  707. this.context = {
  708. deviceType: reader.uint32(),
  709. deviceId: reader.uint32()
  710. };
  711. this.dtype = reader.uint32();
  712. const itemsize = (this.dtype < mxnet.ndarray.NDArray._dataTypeSizeTable.length) ? mxnet.ndarray.NDArray._dataTypeSizeTable[this.dtype] : 0;
  713. const size = itemsize * this.size;
  714. this.data = reader.read(size);
  715. }
  716. break;
  717. }
  718. default: {
  719. reader.skip(-4);
  720. this.shape = reader.uint32s();
  721. this.context = {
  722. deviceType: reader.uint32(),
  723. deviceId: reader.uint32()
  724. };
  725. this.dtype = reader.uint32();
  726. const itemsize = (this.dtype < mxnet.ndarray.NDArray._dataTypeSizeTable.length) ? mxnet.ndarray.NDArray._dataTypeSizeTable[this.dtype] : 0;
  727. const size = itemsize * this.size;
  728. this.data = reader.read(size);
  729. break;
  730. }
  731. }
  732. }
  733. get size() {
  734. return this.shape.reduce((a, b) => a * b, 1);
  735. }
  736. };
  737. mxnet.BinaryReader = class {
  738. constructor(reader) {
  739. this._reader = reader;
  740. }
  741. skip(offset) {
  742. this._reader.skip(offset);
  743. }
  744. read(length) {
  745. return this._reader.read(length);
  746. }
  747. uint32() {
  748. return this._reader.uint32();
  749. }
  750. uint32s() {
  751. const size = this.uint32();
  752. const array = new Array(size);
  753. for (let i = 0; i < size; i++) {
  754. array[i] = this.uint32();
  755. }
  756. return array;
  757. }
  758. uint64() {
  759. return this._reader.uint64();
  760. }
  761. uint64s() {
  762. const size = this.uint32();
  763. const array = new Array(size);
  764. for (let i = 0; i < size; i++) {
  765. array[i] = this.uint64().toNumber();
  766. }
  767. return array;
  768. }
  769. };
/**
 * Error type raised for every failure while loading an MXNet model.
 * Instances carry the fixed `name` assigned below instead of 'Error'.
 */
mxnet.Error = class extends Error {
    constructor(message) {
        super(message);
        this.name = 'Error loading MXNet model.';
    }
};
  776. export const ModelFactory = mxnet.ModelFactory;