mxnet.js 42 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104
  1. var mxnet = mxnet || {};
  2. var json = json || require('./json');
  3. var zip = zip || require('./zip');
  4. var ndarray = ndarray || {};
  5. var base = base || require('./base');
  6. mxnet.ModelFactory = class {
  7. match(context) {
  8. const identifier = context.identifier;
  9. const extension = identifier.split('.').pop().toLowerCase();
  10. if (extension === 'json') {
  11. const obj = context.open('json');
  12. if (obj && obj.nodes && obj.arg_nodes && obj.heads) {
  13. return 'mxnet.json';
  14. }
  15. }
  16. if (extension === 'params') {
  17. const stream = context.stream;
  18. const signature = [ 0x12, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 ];
  19. if (stream && stream.length > signature.length && stream.peek(signature.length).every((value, index) => value == signature[index])) {
  20. return 'mxnet.params';
  21. }
  22. }
  23. return undefined;
  24. }
  25. open(context, match) {
  26. return context.metadata('mxnet-metadata.json').then((metadata) => {
  27. const basename = (base, identifier, extension, suffix, append) => {
  28. if (!base) {
  29. if (identifier.toLowerCase().endsWith(extension)) {
  30. const items = identifier.substring(0, identifier.length - extension.length).split('-');
  31. if (items.length >= 2) {
  32. const token = items.pop();
  33. if ((suffix && token === suffix) || /[a-zA-Z0-9]*/.exec(token)) {
  34. return items.join('-') + append;
  35. }
  36. }
  37. }
  38. }
  39. return base;
  40. };
  41. const convertVersion = (value) => {
  42. if (Array.isArray(value)) {
  43. if (value.length === 2 && value[0] === 'int') {
  44. const major = Math.floor(value[1] / 10000) % 100;
  45. const minor = Math.floor(value[1] / 100) % 100;
  46. const patch = Math.floor(value[1]) % 100;
  47. return [ major.toString(), minor.toString(), patch.toString() ].join('.');
  48. }
  49. }
  50. return null;
  51. };
  52. const requestManifest = () => {
  53. const parse = (stream) => {
  54. try {
  55. const manifest = {};
  56. if (stream) {
  57. const reader = json.TextReader.open(stream);
  58. const obj = reader.read();
  59. if (obj.Model) {
  60. const modelFormat = obj.Model['Model-Format'];
  61. if (modelFormat && modelFormat !== 'MXNet-Symbolic') {
  62. throw new mxnet.Error('Model format \'' + modelFormat + '\' not supported.');
  63. }
  64. manifest.format = 'MXNet Model Server';
  65. if (obj['Model-Archive-Version']) {
  66. manifest.format += ' v' + obj['Model-Archive-Version'].toString();
  67. }
  68. if (!obj.Model.Symbol) {
  69. throw new mxnet.Error('Manifest does not contain symbol entry.');
  70. }
  71. manifest.symbol = obj.Model.Symbol;
  72. if (obj.Model.Signature) {
  73. manifest.signature = obj.Model.Signature;
  74. }
  75. if (obj.Model.Parameters) {
  76. manifest.params = obj.Model.Parameters;
  77. }
  78. if (obj.Model['Model-Name']) {
  79. manifest.name = obj.Model['Model-Name'];
  80. }
  81. if (obj.Model.Description && manifest.name !== obj.Model.Description) {
  82. manifest.description = obj.Model.Description;
  83. }
  84. }
  85. else if (obj.model) {
  86. manifest.format = 'MXNet Model Archive';
  87. if (obj.specificationVersion) {
  88. manifest.format += ' v' + obj.specificationVersion.toString();
  89. }
  90. if (obj.model.modelName) {
  91. manifest.symbol = obj.model.modelName + '-symbol.json';
  92. }
  93. if (obj.model.modelName) {
  94. manifest.name = obj.model.modelName;
  95. }
  96. if (manifest.model && obj.model.modelVersion) {
  97. manifest.version = obj.model.modelVersion;
  98. }
  99. if (manifest.model && manifest.model.modelName && manifest.name != obj.model.description) {
  100. manifest.description = obj.model.description;
  101. }
  102. }
  103. else {
  104. throw new mxnet.Error('Manifest does not contain model.');
  105. }
  106. if (obj.Engine && obj.Engine.MXNet) {
  107. const version = convertVersion(obj.Engine.MXNet);
  108. manifest.runtime = 'MXNet v' + (version ? version : obj.Engine.MXNet.toString());
  109. }
  110. if (obj.License) {
  111. manifest.license = obj.License;
  112. }
  113. if (obj.runtime) {
  114. manifest.runtime = obj.runtime;
  115. }
  116. if (obj.engine && obj.engine.engineName) {
  117. const engine = obj.engine.engineVersion ? obj.engine.engineName + ' ' + obj.engine.engineVersion : obj.engine.engineName;
  118. manifest.runtime = manifest.runtime ? (manifest.runtime + ' (' + engine + ')') : engine;
  119. }
  120. if (obj.publisher && obj.publisher.author) {
  121. manifest.author = obj.publisher.author;
  122. if (obj.publisher.email) {
  123. manifest.author = manifest.author + ' <' + obj.publisher.email + '>';
  124. }
  125. }
  126. if (obj.license) {
  127. manifest.license = obj.license;
  128. }
  129. if (obj.Model && obj.Model.Signature) {
  130. return context.request(obj.Model.Signature).then((stream) => {
  131. const reader = json.TextReader.open(stream);
  132. manifest.signature = reader.read();
  133. return manifest;
  134. }).catch (() => {
  135. return manifest;
  136. });
  137. }
  138. }
  139. return manifest;
  140. }
  141. catch (err) {
  142. throw new mxnet.Error('Failed to read manifest. ' + err.message);
  143. }
  144. };
  145. return context.request('MANIFEST.json').then((stream) => {
  146. return parse(stream);
  147. }).catch (() => {
  148. return context.request('MAR-INF/MANIFEST.json').then((stream) => {
  149. return parse(stream);
  150. }).catch(() => {
  151. return parse(null);
  152. });
  153. });
  154. };
  155. const createModel = (metadata, manifest, symbol, params) => {
  156. const parameters = new Map();
  157. if (params) {
  158. try {
  159. for (const entry of mxnet.ndarray.load(params)) {
  160. const key = entry[0];
  161. const array = entry[1];
  162. const name = (key.startsWith('arg:') || key.startsWith('aux:')) ? key.substring(4) : key;
  163. parameters.set(name, array);
  164. }
  165. }
  166. catch (error) {
  167. // continue regardless of error
  168. }
  169. }
  170. if (symbol) {
  171. if (!manifest.format) {
  172. const version = convertVersion(symbol && symbol.attrs && symbol.attrs.mxnet_version ? symbol.attrs.mxnet_version : null);
  173. manifest.format = 'MXNet' + (version ? ' v' + version : '');
  174. }
  175. if (symbol.nodes && symbol.nodes.some((node) => node && node.op == 'tvm_op')) {
  176. manifest.producer = 'TVM';
  177. }
  178. }
  179. return new mxnet.Model(metadata, manifest, symbol, parameters);
  180. };
  181. const identifier = context.identifier;
  182. switch (match) {
  183. case 'mxnet.json': {
  184. let symbol = null;
  185. try {
  186. symbol = context.open('json');
  187. }
  188. catch (error) {
  189. const message = error && error.message ? error.message : error.toString();
  190. throw new mxnet.Error("Failed to load symbol entry (" + message.replace(/\.$/, '') + ').');
  191. }
  192. const requestParams = (manifest) => {
  193. const file = basename(manifest.params, identifier, '.json', 'symbol', '-0000.params');
  194. if (file) {
  195. return context.request(file, null).then((stream) => {
  196. const buffer = stream.peek();
  197. return createModel(metadata, manifest, symbol, buffer);
  198. }).catch(() => {
  199. return createModel(metadata, manifest, symbol, null);
  200. });
  201. }
  202. return createModel(metadata, manifest, symbol, null);
  203. };
  204. return requestManifest().then((manifest) => {
  205. return requestParams(manifest);
  206. });
  207. }
  208. case 'mxnet.params': {
  209. const params = context.stream.peek();
  210. const requestSymbol = (manifest) => {
  211. const file = basename(manifest.symbol, identifier, '.params', null, '-symbol.json');
  212. if (file) {
  213. return context.request(file, 'utf-8').then((text) => {
  214. const symbol = JSON.parse(text);
  215. return createModel(metadata, manifest, symbol, params);
  216. }).catch(() => {
  217. return createModel(metadata, manifest, null, params);
  218. });
  219. }
  220. return createModel(metadata, manifest, null, params);
  221. };
  222. return requestManifest().then((manifest) => {
  223. return requestSymbol(manifest);
  224. });
  225. }
  226. default: {
  227. throw new mxnet.Error("Unsupported MXNet format '" + match + "'.");
  228. }
  229. }
  230. });
  231. }
  232. };
  233. mxnet.Model = class {
  234. constructor(metadata, manifest, symbol, params) {
  235. if (!symbol && !params) {
  236. throw new mxnet.Error('JSON symbol data not available.');
  237. }
  238. if (symbol) {
  239. if (!Object.prototype.hasOwnProperty.call(symbol, 'nodes')) {
  240. throw new mxnet.Error('JSON file does not contain an MXNet \'nodes\' property.');
  241. }
  242. if (!Object.prototype.hasOwnProperty.call(symbol, 'arg_nodes')) {
  243. throw new mxnet.Error('JSON file does not contain an MXNet \'arg_nodes\' property.');
  244. }
  245. if (!Object.prototype.hasOwnProperty.call(symbol, 'heads')) {
  246. throw new mxnet.Error('JSON file does not contain an MXNet \'heads\' property.');
  247. }
  248. }
  249. this._format = manifest.format || 'MXNet';
  250. this._producer = manifest.producer || '';
  251. this._name = manifest.name || '';
  252. this._version = manifest.version;
  253. this._description = manifest.description || '';
  254. this._runtime = manifest.runtime || '';
  255. this._metadata = [];
  256. if (manifest.author) {
  257. this._metadata.push({ name: 'author', value: manifest.author });
  258. }
  259. if (manifest.license) {
  260. this._metadata.push({ name: 'license', value: manifest.license });
  261. }
  262. this._graphs = [ new mxnet.Graph(metadata, manifest, symbol, params) ];
  263. }
  264. get format() {
  265. return this._format;
  266. }
  267. get producer() {
  268. return this._producer;
  269. }
  270. get runtime() {
  271. return this._runtime;
  272. }
  273. get name() {
  274. return this._name;
  275. }
  276. get version() {
  277. return this._version;
  278. }
  279. get description() {
  280. return this._description;
  281. }
  282. get metadata() {
  283. return this._metadata;
  284. }
  285. get graphs() {
  286. return this._graphs;
  287. }
  288. };
  289. mxnet.Graph = class {
  290. constructor(metadata, manifest, symbol, params) {
  291. this._metadata = metadata;
  292. this._nodes = [];
  293. this._inputs = [];
  294. this._outputs = [];
  295. const tensors = new Map();
  296. if (params) {
  297. for (const pair of params) {
  298. const key = pair[0];
  299. const value = pair[1];
  300. tensors.set(key, new mxnet.Tensor('Initializer', key, new mxnet.TensorType(value.dtype, new mxnet.TensorShape(value.shape)), value.data));
  301. }
  302. }
  303. if (symbol) {
  304. const nodes = symbol.nodes;
  305. const inputs = {};
  306. if (manifest && manifest.signature && manifest.signature.inputs) {
  307. for (const input of manifest.signature.inputs) {
  308. inputs[input.data_name] = input;
  309. }
  310. }
  311. const outputs = {};
  312. if (manifest && manifest.signature && manifest.signature.outputs) {
  313. for (const output of manifest.signature.outputs) {
  314. outputs[output.data_name] = output;
  315. }
  316. }
  317. for (const node of nodes) {
  318. node.outputs = [];
  319. }
  320. for (const node of nodes) {
  321. node.inputs = (node.inputs || []).map((input) => {
  322. return mxnet.Graph._updateOutput(nodes, input);
  323. });
  324. }
  325. const outputCountMap = {};
  326. for (const node of nodes) {
  327. for (const output of node.outputs || []) {
  328. outputCountMap[output] = (outputCountMap[output] || 0) + 1;
  329. }
  330. }
  331. const argumentMap = {};
  332. for (const index of symbol.arg_nodes) {
  333. argumentMap[index] = (index < nodes.length) ? nodes[index] : null;
  334. }
  335. for (let i = 0; i < symbol.heads.length; i++) {
  336. const head = symbol.heads[i];
  337. const outputId = mxnet.Graph._updateOutput(nodes, head);
  338. const outputName = nodes[outputId[0]] ? nodes[outputId[0]].name : ('output' + ((i == 0) ? '' : (i + 1).toString()));
  339. let outputType = null;
  340. const outputSignature = outputs[outputName];
  341. if (outputSignature && outputSignature.data_shape) {
  342. outputType = new mxnet.TensorType(-1, new mxnet.TensorShape(outputSignature.data_shape));
  343. }
  344. this._outputs.push(new mxnet.Parameter(outputName, [ new mxnet.Argument('[' + outputId.join(',') + ']', outputType, null) ]));
  345. }
  346. const initializerMap = {};
  347. for (const node of nodes.filter((node, index) => !argumentMap[index])) {
  348. this._nodes.push(new mxnet.Node(this._metadata, node, argumentMap, initializerMap, tensors));
  349. }
  350. for (const argumentKey of Object.keys(argumentMap)) {
  351. const argument = argumentMap[argumentKey];
  352. if (argument && (!argument.inputs || argument.inputs.length == 0) && (argument.outputs && argument.outputs.length == 1)) {
  353. const inputId = argument.outputs[0];
  354. const inputName = argument.name;
  355. let inputType = null;
  356. const inputSignature = inputs[inputName];
  357. if (inputSignature && inputSignature.data_shape) {
  358. inputType = new mxnet.TensorType(-1, new mxnet.TensorShape(inputSignature.data_shape));
  359. }
  360. this._inputs.push(new mxnet.Parameter(inputName, [ new mxnet.Argument('[' + inputId.join(',') + ']', inputType) ]));
  361. }
  362. }
  363. }
  364. else if (params) {
  365. const blocks = new Map();
  366. let separator = Array.from(params.keys()).every((key) => key.indexOf('_') != -1) ? '_' : '';
  367. if (separator.length == 0) {
  368. separator = Array.from(params.keys()).every((key) => key.indexOf('.') != -1) ? '.' : '';
  369. }
  370. if (separator.length > 0) {
  371. for (const param of params) {
  372. const key = param[0];
  373. const parts = key.split(separator);
  374. let argumentName = parts.pop();
  375. if (key.endsWith('moving_mean') || key.endsWith('moving_var')) {
  376. argumentName = [ parts.pop(), argumentName ].join(separator);
  377. }
  378. const nodeName = parts.join(separator);
  379. if (!blocks.has(nodeName)) {
  380. blocks.set(nodeName, { name: nodeName, op: 'Weights', params: [] });
  381. }
  382. blocks.get(nodeName).params.push({ name: argumentName, id: key });
  383. }
  384. }
  385. else {
  386. throw new mxnet.Error("Unsupported key format in params.");
  387. }
  388. for (const block of blocks.values()) {
  389. this._nodes.push(new mxnet.Node(metadata, block, {}, {}, tensors));
  390. }
  391. }
  392. }
  393. get name() {
  394. return '';
  395. }
  396. get inputs() {
  397. return this._inputs;
  398. }
  399. get outputs() {
  400. return this._outputs;
  401. }
  402. get nodes() {
  403. return this._nodes;
  404. }
  405. static _updateOutput(nodes, input) {
  406. const nodeIndex = input[0];
  407. const node = nodes[nodeIndex];
  408. const outputIndex = input[1];
  409. if (node) {
  410. while (outputIndex >= node.outputs.length) {
  411. node.outputs.push([ nodeIndex, node.outputs.length ]);
  412. }
  413. }
  414. return [ nodeIndex, outputIndex ];
  415. }
  416. };
  417. mxnet.Parameter = class {
  418. constructor(name, args) {
  419. this._name = name;
  420. this._arguments = args;
  421. }
  422. get name() {
  423. return this._name;
  424. }
  425. get visible() {
  426. return true;
  427. }
  428. get arguments() {
  429. return this._arguments;
  430. }
  431. };
  432. mxnet.Argument = class {
  433. constructor(name, type, initializer) {
  434. if (typeof name !== 'string') {
  435. throw new mxnet.Error("Invalid argument identifier '" + JSON.stringify(name) + "'.");
  436. }
  437. this._name = name;
  438. this._type = type || null;
  439. this._initializer = initializer || null;
  440. }
  441. get name() {
  442. if (this._initializer) {
  443. return this._initializer.name;
  444. }
  445. return this._name;
  446. }
  447. get type() {
  448. if (this._initializer) {
  449. return this._initializer.type;
  450. }
  451. return this._type;
  452. }
  453. get initializer() {
  454. return this._initializer;
  455. }
  456. };
  457. mxnet.Node = class {
  458. constructor(metadata, node, argumentMap, initializerMap, tensors) {
  459. let type = node.op;
  460. this._name = node.name;
  461. this._attributes = [];
  462. this._inputs = [];
  463. this._outputs = [];
  464. const attrs = node.attrs || node.attr || node.param;
  465. if (attrs) {
  466. if (type == 'tvm_op' && attrs.func_name) {
  467. type = attrs.func_name;
  468. }
  469. for (const attributeName of Object.keys(attrs)) {
  470. if (type != 'tvm_op' && attributeName != 'func_name') {
  471. this._attributes.push(new mxnet.Attribute(metadata, type, attributeName, attrs[attributeName]));
  472. }
  473. }
  474. }
  475. let initializer = null;
  476. this._type = metadata.type(type) || { name: type };
  477. if (node.inputs) {
  478. let inputs = node.inputs;
  479. if (type == 'RNN') {
  480. inputs = inputs.map((input) => {
  481. const argumentNodeIndex = input[0];
  482. const argument = argumentMap[argumentNodeIndex];
  483. if (argument && argument.op == 'null' && argument.name &&
  484. argument.name.endsWith('_parameters') && argument.attr && argument.attr.__init__) {
  485. this._attributes.push(new mxnet.Attribute(metadata, type, argument.name, argument.attr.__init__));
  486. delete argumentMap[argumentNodeIndex];
  487. return null;
  488. }
  489. return input;
  490. });
  491. inputs = inputs.filter((item) => item != null);
  492. }
  493. const initializers = {};
  494. for (const input of inputs) {
  495. const id = '[' + input.join(',') + ']';
  496. initializer = initializerMap[id];
  497. if (!initializer) {
  498. const argumentNodeIndex = input[0];
  499. const argument = argumentMap[argumentNodeIndex];
  500. if (argument && argument.name &&
  501. (!argument.inputs || argument.inputs.length == 0) &&
  502. (argument.outputs && argument.outputs.length == 1)) {
  503. initializer = tensors.get(argument.name) || null;
  504. if (initializer) {
  505. delete argumentMap[argumentNodeIndex];
  506. }
  507. else {
  508. let prefix = this._name;
  509. if (prefix.endsWith('_fwd')) {
  510. prefix = prefix.slice(0, -3);
  511. }
  512. if (argument.name && (argument.name.startsWith(prefix + '_') || argument.name.startsWith(prefix + '.'))) {
  513. let dataType = -1;
  514. let shape = [];
  515. if (argument.attrs && argument.attrs.__dtype__ && argument.attrs.__shape__) {
  516. try {
  517. dataType = parseInt(argument.attrs.__dtype__);
  518. shape = JSON.parse('[' + argument.attrs.__shape__.replace('(', '').replace(')', '').split(' ').join('').split(',').map((dimension => dimension || '"?"' )).join(',') + ']');
  519. }
  520. catch (err) {
  521. // continue regardless of error
  522. }
  523. }
  524. let argumentType = null;
  525. if (dataType !== -1 || shape.length > 0) {
  526. argumentType = new mxnet.TensorType(dataType, new mxnet.TensorShape(shape));
  527. }
  528. else {
  529. argumentType = new mxnet.TensorType(-1, new mxnet.TensorShape(null));
  530. }
  531. initializer = new mxnet.Tensor('Initializer', argument.name, argumentType, null);
  532. delete argumentMap[argumentNodeIndex];
  533. }
  534. }
  535. }
  536. }
  537. if (initializer) {
  538. initializers[id] = initializer;
  539. initializerMap[id] = initializer;
  540. }
  541. }
  542. let inputIndex = 0;
  543. if (this._type && this._type.inputs) {
  544. for (const inputDef of this._type.inputs) {
  545. if (inputIndex < inputs.length || inputDef.option != 'optional') {
  546. const inputCount = (inputDef.option == 'variadic') ? (inputs.length - inputIndex) : 1;
  547. const inputArguments = [];
  548. for (const input of inputs.slice(inputIndex, inputIndex + inputCount)) {
  549. const inputId = '[' + input.join(',') + ']';
  550. if (inputId != '' || inputDef.option != 'optional') {
  551. inputArguments.push(new mxnet.Argument(inputId, inputDef.type, initializers[inputId]));
  552. }
  553. }
  554. this._inputs.push(new mxnet.Parameter(inputDef.name, inputArguments));
  555. inputIndex += inputCount;
  556. }
  557. }
  558. }
  559. if (inputIndex < inputs.length) {
  560. this._inputs.push(...inputs.slice(inputIndex).map((input, index) => {
  561. const inputId = '[' + input.join(',') + ']';
  562. return new mxnet.Parameter((inputIndex + index).toString(), [
  563. new mxnet.Argument(inputId, null, initializers[inputId])
  564. ]);
  565. }));
  566. }
  567. }
  568. if (node.outputs) {
  569. const outputs = node.outputs;
  570. let outputIndex = 0;
  571. if (this._type && this._type.outputs) {
  572. for (const outputDef of this._type.outputs) {
  573. if (outputIndex < outputs.length || outputDef.option != 'optional') {
  574. const outputArguments = [];
  575. const outputCount = (outputDef.option == 'variadic') ? (outputs.length - outputIndex) : 1;
  576. for (const output of outputs.slice(outputIndex, outputIndex + outputCount)) {
  577. outputArguments.push(new mxnet.Argument('[' + output.join(',') + ']', null, null));
  578. }
  579. this._outputs.push(new mxnet.Parameter(outputDef.name, outputArguments));
  580. outputIndex += outputCount;
  581. }
  582. }
  583. }
  584. if (outputIndex < outputs.length) {
  585. this._outputs.push(...outputs.slice(outputIndex).map((output, index) => {
  586. return new mxnet.Parameter((outputIndex + index).toString(), [
  587. new mxnet.Argument('[' + output.join(',') + ']', null, null)
  588. ]);
  589. }));
  590. }
  591. }
  592. if (node.params) {
  593. for (const param of node.params) {
  594. this._inputs.push(new mxnet.Parameter(param.name, [
  595. new mxnet.Argument(param.id, null, tensors.get(param.id) || null)
  596. ]));
  597. }
  598. }
  599. }
  600. get type() {
  601. return this._type;
  602. }
  603. get name() {
  604. return this._name;
  605. }
  606. get inputs() {
  607. return this._inputs;
  608. }
  609. get outputs() {
  610. return this._outputs;
  611. }
  612. get attributes() {
  613. return this._attributes;
  614. }
  615. };
  616. mxnet.Attribute = class {
  617. constructor(metadata, type, name, value) {
  618. this._name = name;
  619. this._value = value;
  620. let number;
  621. metadata = metadata.attribute(type, name);
  622. if (metadata && metadata.type) {
  623. switch (metadata.type) {
  624. case 'boolean':
  625. switch (value) {
  626. case 0:
  627. case '0':
  628. case 'False':
  629. this._value = false;
  630. break;
  631. case 1:
  632. case '1':
  633. case 'True':
  634. this._value = true;
  635. break;
  636. default:
  637. throw new mxnet.Error("Unsupported attribute boolean value '" + value + "'.");
  638. }
  639. break;
  640. case 'int32':
  641. number = Number.parseInt(this._value, 10);
  642. this._value = Number.isNaN(this._value - number) ? value : number;
  643. break;
  644. case 'float32':
  645. case 'float64':
  646. number = Number.parseFloat(this._value);
  647. this._value = Number.isNaN(this._value - number) ? value : number;
  648. break;
  649. case 'int32[]':
  650. if (this._value.length > 2 && this._value.startsWith('(') && this._value.endsWith(')')) {
  651. let array = [];
  652. const items = this._value.substring(1, this._value.length - 1).split(',')
  653. .map((item) => item.trim())
  654. .map((item) => item.endsWith('L') ? item.substring(0, item.length - 1) : item);
  655. for (const item of items) {
  656. number = Number.parseInt(item, 10);
  657. if (Number.isNaN(item - number)) {
  658. array = null;
  659. }
  660. else if (array != null) {
  661. array.push(number);
  662. }
  663. }
  664. if (array != null) {
  665. this._value = array;
  666. }
  667. }
  668. break;
  669. default:
  670. throw new mxnet.Error("Unsupported attribute type '" + metadata.type + "'.");
  671. }
  672. }
  673. if (metadata) {
  674. if (metadata.visible === false) {
  675. this._visible = false;
  676. }
  677. else if (metadata.default !== undefined) {
  678. let defaultValue = metadata.default;
  679. if (this._value == defaultValue) {
  680. this._visible = false;
  681. }
  682. else if (Array.isArray(this._value) && Array.isArray(defaultValue)) {
  683. defaultValue = defaultValue.slice(0, defaultValue.length);
  684. if (defaultValue.length > 1 && defaultValue[defaultValue.length - 1] == null) {
  685. defaultValue.pop();
  686. while (defaultValue.length < this._value.length) {
  687. defaultValue.push(defaultValue[defaultValue.length - 1]);
  688. }
  689. }
  690. if (this._value.every((item, index) => item == defaultValue[index])) {
  691. this._visible = false;
  692. }
  693. }
  694. }
  695. }
  696. }
  697. get name() {
  698. return this._name;
  699. }
  700. get type() {
  701. return this._type;
  702. }
  703. get value() {
  704. return this._value;
  705. }
  706. get visible() {
  707. return this._visible == false ? false : true;
  708. }
  709. };
  710. mxnet.Tensor = class {
  711. constructor(kind, name, type, data) {
  712. this._kind = kind;
  713. this._name = name;
  714. this._type = type;
  715. this._data = data;
  716. }
  717. get kind() {
  718. return 'Initializer';
  719. }
  720. get name() {
  721. return this._name;
  722. }
  723. get type() {
  724. return this._type;
  725. }
  726. get state() {
  727. return this._context().state;
  728. }
  729. get value() {
  730. const context = this._context();
  731. if (context.state) {
  732. return null;
  733. }
  734. context.limit = Number.MAX_SAFE_INTEGER;
  735. return this._decode(context, 0);
  736. }
  737. toString() {
  738. const context = this._context();
  739. if (context.state) {
  740. return '';
  741. }
  742. context.limit = 10000;
  743. const value = this._decode(context, 0);
  744. return JSON.stringify(value, null, 4);
  745. }
  746. _context() {
  747. const context = {};
  748. context.state = null;
  749. context.index = 0;
  750. context.count = 0;
  751. if (!this._data) {
  752. context.state = 'Tensor data is empty.';
  753. return context;
  754. }
  755. if (!this._type && this._type.dataType === '?') {
  756. context.state = 'Tensor has no data type.';
  757. return context;
  758. }
  759. if (this._type.shape.length < 1) {
  760. context.state = 'Tensor has unknown shape.';
  761. return context;
  762. }
  763. context.dataType = this._type.dataType;
  764. context.dimensions = this._type.shape.dimensions;
  765. context.data = new DataView(this._data.buffer, this._data.byteOffset, this._data.byteLength);
  766. return context;
  767. }
  768. _decode(context, dimension) {
  769. const results = [];
  770. const size = context.dimensions[dimension];
  771. if (dimension == context.dimensions.length - 1) {
  772. for (let i = 0; i < size; i++) {
  773. if (context.count > context.limit) {
  774. results.push('...');
  775. return results;
  776. }
  777. switch (context.dataType) {
  778. case 'float32':
  779. results.push(context.data.getFloat32(context.index, true));
  780. context.index += 4;
  781. context.count++;
  782. break;
  783. case 'float64':
  784. results.push(context.data.getFloat64(context.index, true));
  785. context.index += 8;
  786. context.count++;
  787. break;
  788. case 'float16':
  789. results.push(mxnet.Tensor._decodeNumberFromFloat16(context.data.getUint16(context.index, true)));
  790. context.index += 2;
  791. context.count++;
  792. break;
  793. case 'uint8':
  794. results.push(context.data.getUint8(context.index, true));
  795. context.index += 1;
  796. context.count++;
  797. break;
  798. case 'int32':
  799. results.push(context.data.getInt32(context.index, true));
  800. context.index += 4;
  801. context.count++;
  802. break;
  803. case 'int8':
  804. results.push(context.data.getInt8(context.index, true));
  805. context.index += 1;
  806. context.count++;
  807. break;
  808. case 'int64':
  809. results.push(context.data.getInt64(context.index, true));
  810. context.index += 8;
  811. context.count++;
  812. break;
  813. default:
  814. throw new mxnet.Error("Unsupported tensor data type '" + context.dataType + "'.");
  815. }
  816. }
  817. }
  818. else {
  819. for (let j = 0; j < size; j++) {
  820. if (context.count > context.limit) {
  821. results.push('...');
  822. return results;
  823. }
  824. results.push(this._decode(context, dimension + 1));
  825. }
  826. }
  827. return results;
  828. }
  829. static _decodeNumberFromFloat16(value) {
  830. const s = (value & 0x8000) >> 15;
  831. const e = (value & 0x7C00) >> 10;
  832. const f = value & 0x03FF;
  833. if(e == 0) {
  834. return (s ? -1 : 1) * Math.pow(2, -14) * (f / Math.pow(2, 10));
  835. }
  836. else if (e == 0x1F) {
  837. return f ? NaN : ((s ? -1 : 1) * Infinity);
  838. }
  839. return (s ? -1 : 1) * Math.pow(2, e-15) * (1 + (f / Math.pow(2, 10)));
  840. }
  841. };
  842. mxnet.TensorType = class {
  843. constructor(dataType, shape) {
  844. switch (dataType) {
  845. case 0: this._dataType = 'float32'; break;
  846. case 1: this._dataType = 'float64'; break;
  847. case 2: this._dataType = 'float16'; break;
  848. case 3: this._dataType = 'uint8'; break;
  849. case 4: this._dataType = 'int32'; break;
  850. case 5: this._dataType = 'int8'; break;
  851. case 6: this._dataType = 'int64'; break;
  852. case -1: this._dataType = '?'; break;
  853. default: throw new mxnet.Error("Unsupported type '" + dataType + "'.");
  854. }
  855. this._shape = shape;
  856. }
  857. get dataType() {
  858. return this._dataType;
  859. }
  860. get shape() {
  861. return this._shape;
  862. }
  863. toString() {
  864. return this._dataType + this._shape.toString();
  865. }
  866. };
  867. mxnet.TensorShape = class {
  868. constructor(dimensions) {
  869. this._dimensions = dimensions;
  870. }
  871. get dimensions() {
  872. return this._dimensions;
  873. }
  874. toString() {
  875. if (this._dimensions) {
  876. if (this._dimensions.length == 0) {
  877. return '';
  878. }
  879. return '[' + this._dimensions.map((dimension) => dimension.toString()).join(',') + ']';
  880. }
  881. return '';
  882. }
  883. };
  884. mxnet.ndarray = class {
  885. static load(buffer) {
  886. // NDArray::Load(dmlc::Stream* fi, std::vector<NDArray>* data, std::vector<std::string>* keys)
  887. const map = new Map();
  888. const reader = new mxnet.BinaryReader(buffer);
  889. if (reader.uint64() !== 0x112) { // kMXAPINDArrayListMagic
  890. throw new mxnet.Error('Invalid signature.');
  891. }
  892. if (reader.uint64() !== 0) {
  893. throw new mxnet.Error('Invalid reserved block.');
  894. }
  895. const data = new Array(reader.uint64());
  896. for (let i = 0; i < data.length; i++) {
  897. data[i] = new mxnet.ndarray.NDArray(reader);
  898. }
  899. const decoder = new TextDecoder('ascii');
  900. const names = new Array(reader.uint64());
  901. for (let i = 0; i < names.length; i++) {
  902. names[i] = decoder.decode(reader.read(reader.uint64()));
  903. }
  904. if (names.length != data.length) {
  905. throw new mxnet.Error('Label count mismatch.');
  906. }
  907. for (let i = 0; i < names.length; i++) {
  908. map.set(names[i], data[i]);
  909. }
  910. return map;
  911. }
  912. };
  913. mxnet.ndarray.NDArray = class {
  914. constructor(reader) {
  915. mxnet.ndarray.NDArray._dataTypeSizeTable = [ 4, 8, 2, 1, 4, 1, 8 ];
  916. switch (reader.uint32()) {
  917. case 0xf993faca: { // NDARRAY_V3_MAGIC
  918. throw new mxnet.Array('mxnet.ndarray.NDArray v3 not supported.');
  919. }
  920. case 0xf993fac9: { // NDARRAY_V2_MAGIC
  921. const stype = reader.uint32();
  922. let num_aux_data = 0;
  923. switch (stype) {
  924. case 0: num_aux_data = 0; break; // kDefaultStorage
  925. case 1: num_aux_data = 1; break; // kRowSparseStorage
  926. case 2: num_aux_data = 2; break; // kCSRStorage
  927. default: throw mxnet.Error("Unsupported NDArray type '" + stype + "'.");
  928. }
  929. this.sshape = null;
  930. if (num_aux_data > 0) {
  931. this.sshape = reader.uint64s();
  932. }
  933. this.shape = reader.uint64s();
  934. if (this.shape.length !== 0) {
  935. this.context = new mxnet.context.Context(reader);
  936. this.dtype = reader.uint32();
  937. if (num_aux_data > 0) {
  938. throw new mxnet.Error('Not implemented.');
  939. }
  940. const dataTypeSize = (this.dtype < mxnet.ndarray.NDArray._dataTypeSizeTable.length) ? mxnet.ndarray.NDArray._dataTypeSizeTable[this.dtype] : 0;
  941. const size = dataTypeSize * this.size;
  942. this.data = reader.read(size);
  943. }
  944. break;
  945. }
  946. case 0xf993fac8: { // NDARRAY_V1_MAGIC
  947. this.shape = reader.uint64s();
  948. if (this.shape.length !== 0) {
  949. this.context = new mxnet.context.Context(reader);
  950. this.dtype = reader.uint32();
  951. const itemsize = (this.dtype < mxnet.ndarray.NDArray._dataTypeSizeTable.length) ? mxnet.ndarray.NDArray._dataTypeSizeTable[this.dtype] : 0;
  952. const size = itemsize * this.size;
  953. this.data = reader.read(size);
  954. }
  955. break;
  956. }
  957. default: {
  958. reader.skip(-4);
  959. this.shape = reader.uint32s();
  960. this.context = new mxnet.context.Context(reader);
  961. this.dtype = reader.uint32();
  962. const itemsize = (this.dtype < mxnet.ndarray.NDArray._dataTypeSizeTable.length) ? mxnet.ndarray.NDArray._dataTypeSizeTable[this.dtype] : 0;
  963. const size = itemsize * this.size;
  964. this.data = reader.read(size);
  965. break;
  966. }
  967. }
  968. }
  969. get size() {
  970. return this.shape.reduce((a, b) => a * b, 1);
  971. }
  972. };
  973. mxnet.BinaryReader = class extends base.BinaryReader {
  974. uint32s() {
  975. const count = this.uint32();
  976. const array = new Array(count);
  977. for (let i = 0; i < array.length; i++) {
  978. array[i] = this.uint32();
  979. }
  980. return array;
  981. }
  982. uint64s() {
  983. const count = this.uint32();
  984. const array = new Array(count);
  985. for (let i = 0; i < array.length; i++) {
  986. array[i] = this.uint64();
  987. }
  988. return array;
  989. }
  990. };
  991. mxnet.context = {};
  992. mxnet.context.Context = class {
  993. constructor(reader) {
  994. this._deviceType = reader.uint32();
  995. this._deviceId = reader.uint32();
  996. }
  997. };
  998. mxnet.Error = class extends Error {
  999. constructor(message) {
  1000. super(message);
  1001. this.name = 'Error loading MXNet model.';
  1002. }
  1003. };
// CommonJS export: only when a CommonJS `module` object with an object `exports` is present.
// Otherwise (e.g. a browser script or ES module) nothing is exported here.
if (typeof module !== 'undefined' && typeof module.exports === 'object') {
    module.exports.ModelFactory = mxnet.ModelFactory;
}