mxnet.js 42 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102
  1. var mxnet = mxnet || {};
  2. var json = json || require('./json');
  3. var zip = zip || require('./zip');
  4. var ndarray = ndarray || {};
  5. var base = base || require('./base');
  6. mxnet.ModelFactory = class {
  7. match(context) {
  8. const identifier = context.identifier;
  9. const extension = identifier.split('.').pop().toLowerCase();
  10. if (extension === 'json') {
  11. const obj = context.open('json');
  12. if (obj && obj.nodes && obj.arg_nodes && obj.heads) {
  13. return 'mxnet.json';
  14. }
  15. }
  16. if (extension === 'params') {
  17. const stream = context.stream;
  18. const signature = [ 0x12, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 ];
  19. if (stream.length > signature.length && stream.peek(signature.length).every((value, index) => value == signature[index])) {
  20. return 'mxnet.params';
  21. }
  22. }
  23. return undefined;
  24. }
  25. open(context, match) {
  26. return context.metadata('mxnet-metadata.json').then((metadata) => {
  27. const basename = (base, identifier, extension, suffix, append) => {
  28. if (!base) {
  29. if (identifier.toLowerCase().endsWith(extension)) {
  30. const items = identifier.substring(0, identifier.length - extension.length).split('-');
  31. if (items.length >= 2) {
  32. const token = items.pop();
  33. if ((suffix && token === suffix) || /[a-zA-Z0-9]*/.exec(token)) {
  34. return items.join('-') + append;
  35. }
  36. }
  37. }
  38. }
  39. return base;
  40. };
  41. const convertVersion = (value) => {
  42. if (Array.isArray(value)) {
  43. if (value.length === 2 && value[0] === 'int') {
  44. const major = Math.floor(value[1] / 10000) % 100;
  45. const minor = Math.floor(value[1] / 100) % 100;
  46. const patch = Math.floor(value[1]) % 100;
  47. return [ major.toString(), minor.toString(), patch.toString() ].join('.');
  48. }
  49. }
  50. return null;
  51. };
  52. const requestManifest = () => {
  53. const parse = (stream) => {
  54. try {
  55. const manifest = {};
  56. if (stream) {
  57. const reader = json.TextReader.open(stream);
  58. const obj = reader.read();
  59. if (obj.Model) {
  60. const modelFormat = obj.Model['Model-Format'];
  61. if (modelFormat && modelFormat !== 'MXNet-Symbolic') {
  62. throw new mxnet.Error('Model format \'' + modelFormat + '\' not supported.');
  63. }
  64. manifest.format = 'MXNet Model Server';
  65. if (obj['Model-Archive-Version']) {
  66. manifest.format += ' v' + obj['Model-Archive-Version'].toString();
  67. }
  68. if (!obj.Model.Symbol) {
  69. throw new mxnet.Error('Manifest does not contain symbol entry.');
  70. }
  71. manifest.symbol = obj.Model.Symbol;
  72. if (obj.Model.Signature) {
  73. manifest.signature = obj.Model.Signature;
  74. }
  75. if (obj.Model.Parameters) {
  76. manifest.params = obj.Model.Parameters;
  77. }
  78. if (obj.Model['Model-Name']) {
  79. manifest.name = obj.Model['Model-Name'];
  80. }
  81. if (obj.Model.Description && manifest.name !== obj.Model.Description) {
  82. manifest.description = obj.Model.Description;
  83. }
  84. }
  85. else if (obj.model) {
  86. manifest.format = 'MXNet Model Archive';
  87. if (obj.specificationVersion) {
  88. manifest.format += ' v' + obj.specificationVersion.toString();
  89. }
  90. if (obj.model.modelName) {
  91. manifest.symbol = obj.model.modelName + '-symbol.json';
  92. }
  93. if (obj.model.modelName) {
  94. manifest.name = obj.model.modelName;
  95. }
  96. if (manifest.model && obj.model.modelVersion) {
  97. manifest.version = obj.model.modelVersion;
  98. }
  99. if (manifest.model && manifest.model.modelName && manifest.name != obj.model.description) {
  100. manifest.description = obj.model.description;
  101. }
  102. }
  103. else {
  104. throw new mxnet.Error('Manifest does not contain model.');
  105. }
  106. if (obj.Engine && obj.Engine.MXNet) {
  107. const version = convertVersion(obj.Engine.MXNet);
  108. manifest.runtime = 'MXNet v' + (version ? version : obj.Engine.MXNet.toString());
  109. }
  110. if (obj.License) {
  111. manifest.license = obj.License;
  112. }
  113. if (obj.runtime) {
  114. manifest.runtime = obj.runtime;
  115. }
  116. if (obj.engine && obj.engine.engineName) {
  117. const engine = obj.engine.engineVersion ? obj.engine.engineName + ' ' + obj.engine.engineVersion : obj.engine.engineName;
  118. manifest.runtime = manifest.runtime ? (manifest.runtime + ' (' + engine + ')') : engine;
  119. }
  120. if (obj.publisher && obj.publisher.author) {
  121. manifest.author = obj.publisher.author;
  122. if (obj.publisher.email) {
  123. manifest.author = manifest.author + ' <' + obj.publisher.email + '>';
  124. }
  125. }
  126. if (obj.license) {
  127. manifest.license = obj.license;
  128. }
  129. if (obj.Model && obj.Model.Signature) {
  130. return context.request(obj.Model.Signature).then((stream) => {
  131. const reader = json.TextReader.open(stream);
  132. manifest.signature = reader.read();
  133. return manifest;
  134. }).catch (() => {
  135. return manifest;
  136. });
  137. }
  138. }
  139. return manifest;
  140. }
  141. catch (err) {
  142. throw new mxnet.Error('Failed to read manifest. ' + err.message);
  143. }
  144. };
  145. return context.request('MANIFEST.json').then((stream) => {
  146. return parse(stream);
  147. }).catch (() => {
  148. return context.request('MAR-INF/MANIFEST.json').then((stream) => {
  149. return parse(stream);
  150. }).catch(() => {
  151. return parse(null);
  152. });
  153. });
  154. };
  155. const createModel = (metadata, manifest, symbol, params) => {
  156. const parameters = new Map();
  157. if (params) {
  158. try {
  159. for (const entry of mxnet.ndarray.load(params)) {
  160. const key = entry[0];
  161. const array = entry[1];
  162. const name = (key.startsWith('arg:') || key.startsWith('aux:')) ? key.substring(4) : key;
  163. parameters.set(name, array);
  164. }
  165. }
  166. catch (error) {
  167. // continue regardless of error
  168. }
  169. }
  170. if (symbol) {
  171. if (!manifest.format) {
  172. const version = convertVersion(symbol && symbol.attrs && symbol.attrs.mxnet_version ? symbol.attrs.mxnet_version : null);
  173. manifest.format = 'MXNet' + (version ? ' v' + version : '');
  174. }
  175. if (symbol.nodes && symbol.nodes.some((node) => node && node.op == 'tvm_op')) {
  176. manifest.producer = 'TVM';
  177. }
  178. }
  179. return new mxnet.Model(metadata, manifest, symbol, parameters);
  180. };
  181. const identifier = context.identifier;
  182. switch (match) {
  183. case 'mxnet.json': {
  184. let symbol = null;
  185. try {
  186. symbol = context.open('json');
  187. }
  188. catch (error) {
  189. const message = error && error.message ? error.message : error.toString();
  190. throw new mxnet.Error("Failed to load symbol entry (" + message.replace(/\.$/, '') + ').');
  191. }
  192. const requestParams = (manifest) => {
  193. const file = basename(manifest.params, identifier, '.json', 'symbol', '-0000.params');
  194. if (file) {
  195. return context.request(file, null).then((stream) => {
  196. const buffer = stream.peek();
  197. return createModel(metadata, manifest, symbol, buffer);
  198. }).catch(() => {
  199. return createModel(metadata, manifest, symbol, null);
  200. });
  201. }
  202. return createModel(metadata, manifest, symbol, null);
  203. };
  204. return requestManifest().then((manifest) => {
  205. return requestParams(manifest);
  206. });
  207. }
  208. case 'mxnet.params': {
  209. const params = context.stream.peek();
  210. const requestSymbol = (manifest) => {
  211. const file = basename(manifest.symbol, identifier, '.params', null, '-symbol.json');
  212. if (file) {
  213. return context.request(file, 'utf-8').then((text) => {
  214. const symbol = JSON.parse(text);
  215. return createModel(metadata, manifest, symbol, params);
  216. }).catch(() => {
  217. return createModel(metadata, manifest, null, params);
  218. });
  219. }
  220. return createModel(metadata, manifest, null, params);
  221. };
  222. return requestManifest().then((manifest) => {
  223. return requestSymbol(manifest);
  224. });
  225. }
  226. default: {
  227. throw new mxnet.Error("Unsupported MXNet format '" + match + "'.");
  228. }
  229. }
  230. });
  231. }
  232. };
  233. mxnet.Model = class {
  234. constructor(metadata, manifest, symbol, params) {
  235. if (!symbol && !params) {
  236. throw new mxnet.Error('JSON symbol data not available.');
  237. }
  238. if (symbol) {
  239. if (!Object.prototype.hasOwnProperty.call(symbol, 'nodes')) {
  240. throw new mxnet.Error('JSON file does not contain an MXNet \'nodes\' property.');
  241. }
  242. if (!Object.prototype.hasOwnProperty.call(symbol, 'arg_nodes')) {
  243. throw new mxnet.Error('JSON file does not contain an MXNet \'arg_nodes\' property.');
  244. }
  245. if (!Object.prototype.hasOwnProperty.call(symbol, 'heads')) {
  246. throw new mxnet.Error('JSON file does not contain an MXNet \'heads\' property.');
  247. }
  248. }
  249. this._format = manifest.format || 'MXNet';
  250. this._producer = manifest.producer || '';
  251. this._name = manifest.name || '';
  252. this._version = manifest.version;
  253. this._description = manifest.description || '';
  254. this._runtime = manifest.runtime || '';
  255. this._metadata = [];
  256. if (manifest.author) {
  257. this._metadata.push({ name: 'author', value: manifest.author });
  258. }
  259. if (manifest.license) {
  260. this._metadata.push({ name: 'license', value: manifest.license });
  261. }
  262. this._graphs = [ new mxnet.Graph(metadata, manifest, symbol, params) ];
  263. }
  264. get format() {
  265. return this._format;
  266. }
  267. get producer() {
  268. return this._producer;
  269. }
  270. get runtime() {
  271. return this._runtime;
  272. }
  273. get name() {
  274. return this._name;
  275. }
  276. get version() {
  277. return this._version;
  278. }
  279. get description() {
  280. return this._description;
  281. }
  282. get metadata() {
  283. return this._metadata;
  284. }
  285. get graphs() {
  286. return this._graphs;
  287. }
  288. };
  289. mxnet.Graph = class {
  290. constructor(metadata, manifest, symbol, params) {
  291. this._metadata = metadata;
  292. this._nodes = [];
  293. this._inputs = [];
  294. this._outputs = [];
  295. const tensors = new Map();
  296. if (params) {
  297. for (const pair of params) {
  298. const key = pair[0];
  299. const value = pair[1];
  300. tensors.set(key, new mxnet.Tensor('Initializer', key, new mxnet.TensorType(value.dtype, new mxnet.TensorShape(value.shape)), value.data));
  301. }
  302. }
  303. if (symbol) {
  304. const nodes = symbol.nodes;
  305. const inputs = {};
  306. if (manifest && manifest.signature && manifest.signature.inputs) {
  307. for (const input of manifest.signature.inputs) {
  308. inputs[input.data_name] = input;
  309. }
  310. }
  311. const outputs = {};
  312. if (manifest && manifest.signature && manifest.signature.outputs) {
  313. for (const output of manifest.signature.outputs) {
  314. outputs[output.data_name] = output;
  315. }
  316. }
  317. for (const node of nodes) {
  318. node.outputs = [];
  319. }
  320. for (const node of nodes) {
  321. node.inputs = (node.inputs || []).map((input) => {
  322. return mxnet.Graph._updateOutput(nodes, input);
  323. });
  324. }
  325. const outputCountMap = {};
  326. for (const node of nodes) {
  327. for (const output of node.outputs || []) {
  328. outputCountMap[output] = (outputCountMap[output] || 0) + 1;
  329. }
  330. }
  331. const argumentMap = {};
  332. for (const index of symbol.arg_nodes) {
  333. argumentMap[index] = (index < nodes.length) ? nodes[index] : null;
  334. }
  335. for (let i = 0; i < symbol.heads.length; i++) {
  336. const head = symbol.heads[i];
  337. const outputId = mxnet.Graph._updateOutput(nodes, head);
  338. const outputName = nodes[outputId[0]] ? nodes[outputId[0]].name : ('output' + ((i == 0) ? '' : (i + 1).toString()));
  339. let outputType = null;
  340. const outputSignature = outputs[outputName];
  341. if (outputSignature && outputSignature.data_shape) {
  342. outputType = new mxnet.TensorType(-1, new mxnet.TensorShape(outputSignature.data_shape));
  343. }
  344. this._outputs.push(new mxnet.Parameter(outputName, [ new mxnet.Argument('[' + outputId.join(',') + ']', outputType, null) ]));
  345. }
  346. const initializerMap = {};
  347. for (const node of nodes.filter((node, index) => !argumentMap[index])) {
  348. this._nodes.push(new mxnet.Node(this._metadata, node, argumentMap, initializerMap, tensors));
  349. }
  350. for (const argumentKey of Object.keys(argumentMap)) {
  351. const argument = argumentMap[argumentKey];
  352. if (argument && (!argument.inputs || argument.inputs.length == 0) && (argument.outputs && argument.outputs.length == 1)) {
  353. const inputId = argument.outputs[0];
  354. const inputName = argument.name;
  355. let inputType = null;
  356. const inputSignature = inputs[inputName];
  357. if (inputSignature && inputSignature.data_shape) {
  358. inputType = new mxnet.TensorType(-1, new mxnet.TensorShape(inputSignature.data_shape));
  359. }
  360. this._inputs.push(new mxnet.Parameter(inputName, [ new mxnet.Argument('[' + inputId.join(',') + ']', inputType) ]));
  361. }
  362. }
  363. }
  364. else if (params) {
  365. const blocks = new Map();
  366. let separator = Array.from(params.keys()).every((key) => key.indexOf('_') != -1) ? '_' : '';
  367. if (separator.length == 0) {
  368. separator = Array.from(params.keys()).every((key) => key.indexOf('.') != -1) ? '.' : '';
  369. }
  370. if (separator.length > 0) {
  371. for (const param of params) {
  372. const key = param[0];
  373. const parts = key.split(separator);
  374. let argumentName = parts.pop();
  375. if (key.endsWith('moving_mean') || key.endsWith('moving_var')) {
  376. argumentName = [ parts.pop(), argumentName ].join(separator);
  377. }
  378. const nodeName = parts.join(separator);
  379. if (!blocks.has(nodeName)) {
  380. blocks.set(nodeName, { name: nodeName, op: 'Weights', params: [] });
  381. }
  382. blocks.get(nodeName).params.push({ name: argumentName, id: key });
  383. }
  384. }
  385. else {
  386. throw new mxnet.Error("Unsupported key format in params.");
  387. }
  388. for (const block of blocks.values()) {
  389. this._nodes.push(new mxnet.Node(metadata, block, {}, {}, tensors));
  390. }
  391. }
  392. }
  393. get name() {
  394. return '';
  395. }
  396. get inputs() {
  397. return this._inputs;
  398. }
  399. get outputs() {
  400. return this._outputs;
  401. }
  402. get nodes() {
  403. return this._nodes;
  404. }
  405. static _updateOutput(nodes, input) {
  406. const nodeIndex = input[0];
  407. const node = nodes[nodeIndex];
  408. const outputIndex = input[1];
  409. if (node) {
  410. while (outputIndex >= node.outputs.length) {
  411. node.outputs.push([ nodeIndex, node.outputs.length ]);
  412. }
  413. }
  414. return [ nodeIndex, outputIndex ];
  415. }
  416. };
  417. mxnet.Parameter = class {
  418. constructor(name, args) {
  419. this._name = name;
  420. this._arguments = args;
  421. }
  422. get name() {
  423. return this._name;
  424. }
  425. get visible() {
  426. return true;
  427. }
  428. get arguments() {
  429. return this._arguments;
  430. }
  431. };
  432. mxnet.Argument = class {
  433. constructor(name, type, initializer) {
  434. if (typeof name !== 'string') {
  435. throw new mxnet.Error("Invalid argument identifier '" + JSON.stringify(name) + "'.");
  436. }
  437. this._name = name;
  438. this._type = type || null;
  439. this._initializer = initializer || null;
  440. }
  441. get name() {
  442. if (this._initializer) {
  443. return this._initializer.name;
  444. }
  445. return this._name;
  446. }
  447. get type() {
  448. if (this._initializer) {
  449. return this._initializer.type;
  450. }
  451. return this._type;
  452. }
  453. get initializer() {
  454. return this._initializer;
  455. }
  456. };
  457. mxnet.Node = class {
  458. constructor(metadata, node, argumentMap, initializerMap, tensors) {
  459. let type = node.op;
  460. this._name = node.name;
  461. this._attributes = [];
  462. this._inputs = [];
  463. this._outputs = [];
  464. const attrs = node.attrs || node.attr || node.param;
  465. if (attrs) {
  466. if (type == 'tvm_op' && attrs.func_name) {
  467. type = attrs.func_name;
  468. }
  469. for (const attributeName of Object.keys(attrs)) {
  470. if (type != 'tvm_op' && attributeName != 'func_name') {
  471. this._attributes.push(new mxnet.Attribute(metadata, type, attributeName, attrs[attributeName]));
  472. }
  473. }
  474. }
  475. let initializer = null;
  476. this._type = metadata.type(type) || { name: type };
  477. if (node.inputs) {
  478. let inputs = node.inputs;
  479. if (type == 'RNN') {
  480. inputs = inputs.map((input) => {
  481. const argumentNodeIndex = input[0];
  482. const argument = argumentMap[argumentNodeIndex];
  483. if (argument && argument.op == 'null' && argument.name &&
  484. argument.name.endsWith('_parameters') && argument.attr && argument.attr.__init__) {
  485. this._attributes.push(new mxnet.Attribute(metadata, type, argument.name, argument.attr.__init__));
  486. delete argumentMap[argumentNodeIndex];
  487. return null;
  488. }
  489. return input;
  490. });
  491. inputs = inputs.filter((item) => item != null);
  492. }
  493. const initializers = {};
  494. for (const input of inputs) {
  495. const id = '[' + input.join(',') + ']';
  496. initializer = initializerMap[id];
  497. if (!initializer) {
  498. const argumentNodeIndex = input[0];
  499. const argument = argumentMap[argumentNodeIndex];
  500. if (argument && argument.name &&
  501. (!argument.inputs || argument.inputs.length == 0) &&
  502. (argument.outputs && argument.outputs.length == 1)) {
  503. initializer = tensors.get(argument.name) || null;
  504. if (initializer) {
  505. delete argumentMap[argumentNodeIndex];
  506. }
  507. else {
  508. let prefix = this._name;
  509. if (prefix.endsWith('_fwd')) {
  510. prefix = prefix.slice(0, -3);
  511. }
  512. if (argument.name && (argument.name.startsWith(prefix + '_') || argument.name.startsWith(prefix + '.'))) {
  513. let dataType = -1;
  514. let shape = [];
  515. if (argument.attrs && argument.attrs.__dtype__ && argument.attrs.__shape__) {
  516. try {
  517. dataType = parseInt(argument.attrs.__dtype__);
  518. shape = JSON.parse('[' + argument.attrs.__shape__.replace('(', '').replace(')', '').split(' ').join('').split(',').map((dimension => dimension || '"?"' )).join(',') + ']');
  519. }
  520. catch (err) {
  521. // continue regardless of error
  522. }
  523. }
  524. let argumentType = null;
  525. if (dataType !== -1 || shape.length > 0) {
  526. argumentType = new mxnet.TensorType(dataType, new mxnet.TensorShape(shape));
  527. }
  528. else {
  529. argumentType = new mxnet.TensorType(-1, new mxnet.TensorShape(null));
  530. }
  531. initializer = new mxnet.Tensor('Initializer', argument.name, argumentType, null);
  532. delete argumentMap[argumentNodeIndex];
  533. }
  534. }
  535. }
  536. }
  537. if (initializer) {
  538. initializers[id] = initializer;
  539. initializerMap[id] = initializer;
  540. }
  541. }
  542. let inputIndex = 0;
  543. if (this._type && this._type.inputs) {
  544. for (const inputDef of this._type.inputs) {
  545. if (inputIndex < inputs.length || inputDef.option != 'optional') {
  546. const inputCount = (inputDef.option == 'variadic') ? (inputs.length - inputIndex) : 1;
  547. const inputArguments = [];
  548. for (const input of inputs.slice(inputIndex, inputIndex + inputCount)) {
  549. const inputId = '[' + input.join(',') + ']';
  550. if (inputId != '' || inputDef.option != 'optional') {
  551. inputArguments.push(new mxnet.Argument(inputId, inputDef.type, initializers[inputId]));
  552. }
  553. }
  554. this._inputs.push(new mxnet.Parameter(inputDef.name, inputArguments));
  555. inputIndex += inputCount;
  556. }
  557. }
  558. }
  559. if (inputIndex < inputs.length) {
  560. this._inputs.push(...inputs.slice(inputIndex).map((input, index) => {
  561. const inputId = '[' + input.join(',') + ']';
  562. return new mxnet.Parameter((inputIndex + index).toString(), [
  563. new mxnet.Argument(inputId, null, initializers[inputId])
  564. ]);
  565. }));
  566. }
  567. }
  568. if (node.outputs) {
  569. const outputs = node.outputs;
  570. let outputIndex = 0;
  571. if (this._type && this._type.outputs) {
  572. for (const outputDef of this._type.outputs) {
  573. if (outputIndex < outputs.length || outputDef.option != 'optional') {
  574. const outputArguments = [];
  575. const outputCount = (outputDef.option == 'variadic') ? (outputs.length - outputIndex) : 1;
  576. for (const output of outputs.slice(outputIndex, outputIndex + outputCount)) {
  577. outputArguments.push(new mxnet.Argument('[' + output.join(',') + ']', null, null));
  578. }
  579. this._outputs.push(new mxnet.Parameter(outputDef.name, outputArguments));
  580. outputIndex += outputCount;
  581. }
  582. }
  583. }
  584. if (outputIndex < outputs.length) {
  585. this._outputs.push(...outputs.slice(outputIndex).map((output, index) => {
  586. return new mxnet.Parameter((outputIndex + index).toString(), [
  587. new mxnet.Argument('[' + output.join(',') + ']', null, null)
  588. ]);
  589. }));
  590. }
  591. }
  592. if (node.params) {
  593. for (const param of node.params) {
  594. this._inputs.push(new mxnet.Parameter(param.name, [
  595. new mxnet.Argument(param.id, null, tensors.get(param.id) || null)
  596. ]));
  597. }
  598. }
  599. }
  600. get type() {
  601. return this._type;
  602. }
  603. get name() {
  604. return this._name;
  605. }
  606. get inputs() {
  607. return this._inputs;
  608. }
  609. get outputs() {
  610. return this._outputs;
  611. }
  612. get attributes() {
  613. return this._attributes;
  614. }
  615. };
  616. mxnet.Attribute = class {
  617. constructor(metadata, type, name, value) {
  618. this._name = name;
  619. this._value = value;
  620. let number;
  621. metadata = metadata.attribute(type, name);
  622. if (metadata && metadata.type) {
  623. switch (metadata.type) {
  624. case 'boolean':
  625. switch (value) {
  626. case 0:
  627. case 'False':
  628. this._value = false;
  629. break;
  630. case 1:
  631. case 'True':
  632. this._value = true;
  633. break;
  634. default:
  635. throw new mxnet.Error("Unsupported attribute boolean value '" + value + "'.");
  636. }
  637. break;
  638. case 'int32':
  639. number = Number.parseInt(this._value, 10);
  640. this._value = Number.isNaN(this._value - number) ? value : number;
  641. break;
  642. case 'float32':
  643. case 'float64':
  644. number = Number.parseFloat(this._value);
  645. this._value = Number.isNaN(this._value - number) ? value : number;
  646. break;
  647. case 'int32[]':
  648. if (this._value.length > 2 && this._value.startsWith('(') && this._value.endsWith(')')) {
  649. let array = [];
  650. const items = this._value.substring(1, this._value.length - 1).split(',')
  651. .map((item) => item.trim())
  652. .map((item) => item.endsWith('L') ? item.substring(0, item.length - 1) : item);
  653. for (const item of items) {
  654. number = Number.parseInt(item, 10);
  655. if (Number.isNaN(item - number)) {
  656. array = null;
  657. }
  658. else if (array != null) {
  659. array.push(number);
  660. }
  661. }
  662. if (array != null) {
  663. this._value = array;
  664. }
  665. }
  666. break;
  667. default:
  668. throw new mxnet.Error("Unsupported attribute type '" + metadata.type + "'.");
  669. }
  670. }
  671. if (metadata) {
  672. if (metadata.visible === false) {
  673. this._visible = false;
  674. }
  675. else if (metadata.default !== undefined) {
  676. let defaultValue = metadata.default;
  677. if (this._value == defaultValue) {
  678. this._visible = false;
  679. }
  680. else if (Array.isArray(this._value) && Array.isArray(defaultValue)) {
  681. defaultValue = defaultValue.slice(0, defaultValue.length);
  682. if (defaultValue.length > 1 && defaultValue[defaultValue.length - 1] == null) {
  683. defaultValue.pop();
  684. while (defaultValue.length < this._value.length) {
  685. defaultValue.push(defaultValue[defaultValue.length - 1]);
  686. }
  687. }
  688. if (this._value.every((item, index) => { return item == defaultValue[index]; })) {
  689. this._visible = false;
  690. }
  691. }
  692. }
  693. }
  694. }
  695. get name() {
  696. return this._name;
  697. }
  698. get type() {
  699. return this._type;
  700. }
  701. get value() {
  702. return this._value;
  703. }
  704. get visible() {
  705. return this._visible == false ? false : true;
  706. }
  707. };
  708. mxnet.Tensor = class {
  709. constructor(kind, name, type, data) {
  710. this._kind = kind;
  711. this._name = name;
  712. this._type = type;
  713. this._data = data;
  714. }
  715. get kind() {
  716. return 'Initializer';
  717. }
  718. get name() {
  719. return this._name;
  720. }
  721. get type() {
  722. return this._type;
  723. }
  724. get state() {
  725. return this._context().state;
  726. }
  727. get value() {
  728. const context = this._context();
  729. if (context.state) {
  730. return null;
  731. }
  732. context.limit = Number.MAX_SAFE_INTEGER;
  733. return this._decode(context, 0);
  734. }
  735. toString() {
  736. const context = this._context();
  737. if (context.state) {
  738. return '';
  739. }
  740. context.limit = 10000;
  741. const value = this._decode(context, 0);
  742. return JSON.stringify(value, null, 4);
  743. }
  744. _context() {
  745. const context = {};
  746. context.state = null;
  747. context.index = 0;
  748. context.count = 0;
  749. if (!this._data) {
  750. context.state = 'Tensor data is empty.';
  751. return context;
  752. }
  753. if (!this._type && this._type.dataType === '?') {
  754. context.state = 'Tensor has no data type.';
  755. return context;
  756. }
  757. if (this._type.shape.length < 1) {
  758. context.state = 'Tensor has unknown shape.';
  759. return context;
  760. }
  761. context.dataType = this._type.dataType;
  762. context.dimensions = this._type.shape.dimensions;
  763. context.data = new DataView(this._data.buffer, this._data.byteOffset, this._data.byteLength);
  764. return context;
  765. }
  766. _decode(context, dimension) {
  767. const results = [];
  768. const size = context.dimensions[dimension];
  769. if (dimension == context.dimensions.length - 1) {
  770. for (let i = 0; i < size; i++) {
  771. if (context.count > context.limit) {
  772. results.push('...');
  773. return results;
  774. }
  775. switch (context.dataType) {
  776. case 'float32':
  777. results.push(context.data.getFloat32(context.index, true));
  778. context.index += 4;
  779. context.count++;
  780. break;
  781. case 'float64':
  782. results.push(context.data.getFloat64(context.index, true));
  783. context.index += 8;
  784. context.count++;
  785. break;
  786. case 'float16':
  787. results.push(mxnet.Tensor._decodeNumberFromFloat16(context.data.getUint16(context.index, true)));
  788. context.index += 2;
  789. context.count++;
  790. break;
  791. case 'uint8':
  792. results.push(context.data.getUint8(context.index, true));
  793. context.index += 1;
  794. context.count++;
  795. break;
  796. case 'int32':
  797. results.push(context.data.getInt32(context.index, true));
  798. context.index += 4;
  799. context.count++;
  800. break;
  801. case 'int8':
  802. results.push(context.data.getInt8(context.index, true));
  803. context.index += 1;
  804. context.count++;
  805. break;
  806. case 'int64':
  807. results.push(context.data.getInt64(context.index, true));
  808. context.index += 8;
  809. context.count++;
  810. break;
  811. default:
  812. throw new mxnet.Error("Unsupported tensor data type '" + context.dataType + "'.");
  813. }
  814. }
  815. }
  816. else {
  817. for (let j = 0; j < size; j++) {
  818. if (context.count > context.limit) {
  819. results.push('...');
  820. return results;
  821. }
  822. results.push(this._decode(context, dimension + 1));
  823. }
  824. }
  825. return results;
  826. }
  827. static _decodeNumberFromFloat16(value) {
  828. const s = (value & 0x8000) >> 15;
  829. const e = (value & 0x7C00) >> 10;
  830. const f = value & 0x03FF;
  831. if(e == 0) {
  832. return (s ? -1 : 1) * Math.pow(2, -14) * (f / Math.pow(2, 10));
  833. }
  834. else if (e == 0x1F) {
  835. return f ? NaN : ((s ? -1 : 1) * Infinity);
  836. }
  837. return (s ? -1 : 1) * Math.pow(2, e-15) * (1 + (f / Math.pow(2, 10)));
  838. }
  839. };
  mxnet.TensorType = class {

      /**
       * Describes a tensor's element type and shape.
       *
       * @param {number} dataType - numeric element type code: 0 float32, 1 float64,
       *     2 float16, 3 uint8, 4 int32, 5 int8, 6 int64, or -1 for an unknown type
       *     (rendered as '?').
       * @param {*} shape - presumably an mxnet.TensorShape; only its toString() is used here.
       * @throws {mxnet.Error} when dataType is not one of the codes above.
       */
      constructor(dataType, shape) {
          const dataTypeNames = new Map([
              [ -1, '?' ],
              [ 0, 'float32' ],
              [ 1, 'float64' ],
              [ 2, 'float16' ],
              [ 3, 'uint8' ],
              [ 4, 'int32' ],
              [ 5, 'int8' ],
              [ 6, 'int64' ]
          ]);
          if (!dataTypeNames.has(dataType)) {
              throw new mxnet.Error("Unsupported type '" + dataType + "'.");
          }
          this._dataType = dataTypeNames.get(dataType);
          this._shape = shape;
      }

      /** @returns {string} the element type name, e.g. 'float32'. */
      get dataType() {
          return this._dataType;
      }

      /** @returns {*} the shape object passed to the constructor. */
      get shape() {
          return this._shape;
      }

      /** @returns {string} the type name followed by the shape's string form. */
      toString() {
          const shapeText = this._shape.toString();
          return this._dataType + shapeText;
      }
  };
  mxnet.TensorShape = class {

      /**
       * @param {Array|null|undefined} dimensions - the tensor's dimension sizes;
       *     presumably an array of numbers (each entry must support toString()).
       */
      constructor(dimensions) {
          this._dimensions = dimensions;
      }

      /** @returns {Array|null|undefined} the dimensions passed to the constructor. */
      get dimensions() {
          return this._dimensions;
      }

      /**
       * Formats the shape as '[d0,d1,...]'.
       * @returns {string} '' for a missing or empty dimension list.
       */
      toString() {
          const dimensions = this._dimensions;
          if (!dimensions || dimensions.length === 0) {
              return '';
          }
          const text = dimensions.map((dimension) => dimension.toString()).join(',');
          return '[' + text + ']';
      }
  };
  mxnet.ndarray = class {

      /**
       * Parses an MXNet NDArray list file, mirroring the C++ reader
       * NDArray::Load(dmlc::Stream* fi, std::vector<NDArray>* data, std::vector<std::string>* keys).
       *
       * Layout:
       *   - uint64 magic 0x112 (kMXAPINDArrayListMagic)
       *   - uint64 reserved (must be 0)
       *   - uint64 array count, then that many serialized NDArrays
       *   - uint64 key count, then each key as a uint64 byte length followed by its bytes
       *
       * @param {*} buffer - raw file contents, handed to mxnet.BinaryReader.
       * @returns {Map<string, mxnet.ndarray.NDArray>} arrays keyed by name. A file
       *     saved from a plain list carries no keys; its arrays are keyed by their
       *     position instead ('0', '1', ...).
       * @throws {mxnet.Error} on a bad signature, a non-zero reserved field, or a key
       *     count that is neither zero nor equal to the array count.
       */
      static load(buffer) {
          const reader = new mxnet.BinaryReader(buffer);
          if (reader.uint64() !== 0x112) { // kMXAPINDArrayListMagic
              throw new mxnet.Error('Invalid signature.');
          }
          if (reader.uint64() !== 0) {
              throw new mxnet.Error('Invalid reserved block.');
          }
          const data = new Array(reader.uint64());
          for (let i = 0; i < data.length; i++) {
              data[i] = new mxnet.ndarray.NDArray(reader);
          }
          // UTF-8 decodes pure-ASCII keys exactly as before. The 'ascii' label would
          // actually select windows-1252, which garbles any non-ASCII name.
          const decoder = new TextDecoder('utf-8');
          const names = new Array(reader.uint64());
          for (let i = 0; i < names.length; i++) {
              names[i] = decoder.decode(reader.read(reader.uint64()));
          }
          // MXNet accepts either no keys at all (a list saved without names) or
          // exactly one key per array.
          if (names.length !== 0 && names.length !== data.length) {
              throw new mxnet.Error('Label count mismatch.');
          }
          const map = new Map();
          for (let i = 0; i < data.length; i++) {
              map.set(names.length === 0 ? i.toString() : names[i], data[i]);
          }
          return map;
      }
  };
  mxnet.ndarray.NDArray = class {

      /**
       * Reads one serialized NDArray record from `reader`.
       *
       * The leading uint32 selects the layout:
       *   - 0xf993faca (NDARRAY_V3_MAGIC): rejected.
       *   - 0xf993fac9 (NDARRAY_V2_MAGIC): uint32 storage type, a uint64 storage
       *     shape when that storage type has auxiliary data, then a uint64 shape.
       *   - 0xf993fac8 (NDARRAY_V1_MAGIC): uint64 shape.
       *   - anything else: legacy layout. The uint32 just read is re-read as the
       *     length prefix of a uint32 shape.
       * A non-empty shape is followed by a device context, a uint32 dtype and the
       * raw element bytes. An empty shape ends the record in every layout.
       *
       * Sets `shape`, plus `context`, `dtype` and `data` for non-empty arrays.
       * V2 records additionally set `sshape` (null for default storage).
       *
       * @param {mxnet.BinaryReader} reader - positioned at the start of a record.
       * @throws {mxnet.Error} for V3 records, unknown storage types, and
       *     row-sparse / CSR storage (not implemented).
       */
      constructor(reader) {
          switch (reader.uint32()) {
              case 0xf993faca: { // NDARRAY_V3_MAGIC
                  throw new mxnet.Error('mxnet.ndarray.NDArray v3 not supported.');
              }
              case 0xf993fac9: { // NDARRAY_V2_MAGIC
                  const stype = reader.uint32();
                  let num_aux_data = 0;
                  switch (stype) {
                      case 0: num_aux_data = 0; break; // kDefaultStorage
                      case 1: num_aux_data = 1; break; // kRowSparseStorage
                      case 2: num_aux_data = 2; break; // kCSRStorage
                      default: throw new mxnet.Error("Unsupported NDArray type '" + stype + "'.");
                  }
                  this.sshape = null;
                  if (num_aux_data > 0) {
                      this.sshape = reader.uint64s();
                  }
                  this.shape = reader.uint64s();
                  this._readContent(reader, num_aux_data);
                  break;
              }
              case 0xf993fac8: { // NDARRAY_V1_MAGIC
                  this.shape = reader.uint64s();
                  this._readContent(reader, 0);
                  break;
              }
              default: {
                  // Legacy layout: the value just read was the shape's length prefix.
                  reader.skip(-4);
                  this.shape = reader.uint32s();
                  this._readContent(reader, 0);
                  break;
              }
          }
      }

      // Reads the payload that follows the shape: device context, dtype and raw
      // element bytes. Nothing follows an empty shape.
      _readContent(reader, auxDataCount) {
          if (this.shape.length === 0) {
              return;
          }
          this.context = new mxnet.context.Context(reader);
          this.dtype = reader.uint32();
          if (auxDataCount > 0) {
              // Row-sparse / CSR storage is not supported.
              throw new mxnet.Error('Not implemented.');
          }
          const table = mxnet.ndarray.NDArray._dataTypeSizeTable;
          // Unknown dtypes get an item size of 0, so no element bytes are consumed.
          const itemSize = (this.dtype < table.length) ? table[this.dtype] : 0;
          this.data = reader.read(itemSize * this.size);
      }

      /** @returns {number} total element count (product of the shape; 1 for an empty shape). */
      get size() {
          return this.shape.reduce((a, b) => a * b, 1);
      }
  };
  // Byte size per dtype code: float32, float64, float16, uint8, int32, int8, int64.
  mxnet.ndarray.NDArray._dataTypeSizeTable = [ 4, 8, 2, 1, 4, 1, 8 ];
  mxnet.BinaryReader = class extends base.BinaryReader {

      /**
       * Reads a uint32 element count followed by that many uint32 values.
       * @returns {number[]}
       */
      uint32s() {
          const length = this.uint32();
          return Array.from({ length }, () => this.uint32());
      }

      /**
       * Reads a uint32 element count followed by that many uint64 values.
       * The values come from base.BinaryReader.uint64(), whose return type is
       * defined elsewhere (presumably number).
       * @returns {Array}
       */
      uint64s() {
          const length = this.uint32();
          return Array.from({ length }, () => this.uint64());
      }
  };
  mxnet.context = {};

  mxnet.context.Context = class {

      /**
       * Reads a serialized device context: a uint32 device type followed by a
       * uint32 device id, stored in that order.
       * @param {mxnet.BinaryReader} reader
       */
      constructor(reader) {
          const deviceType = reader.uint32();
          const deviceId = reader.uint32();
          this._deviceType = deviceType;
          this._deviceId = deviceId;
      }
  };
  mxnet.Error = class extends Error {

      /**
       * Error raised by this loader for malformed or unsupported MXNet files.
       * Its `name` is fixed to a user-facing headline.
       * @param {string} message - details of what failed.
       */
      constructor(message) {
          super(message);
          const headline = 'Error loading MXNet model.';
          this.name = headline;
      }
  };
  {
      // Publish the factory only when loaded as a CommonJS module (e.g. under Node.js).
      const isCommonJsModule = typeof module !== 'undefined' && typeof module.exports === 'object';
      if (isCommonJsModule) {
          module.exports.ModelFactory = mxnet.ModelFactory;
      }
  }