mxnet.js 46 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7
  1. /* jshint esversion: 6 */
  2. var mxnet = mxnet || {};
  3. var json = json || require('./json');
  4. var zip = zip || require('./zip');
  5. var ndarray = ndarray || {};
  6. mxnet.ModelFactory = class {
  7. match(context) {
  8. const identifier = context.identifier;
  9. const extension = identifier.split('.').pop().toLowerCase();
  10. if (extension === 'model' || extension === 'mar') {
  11. if (context.entries('zip').length > 0) {
  12. return true;
  13. }
  14. }
  15. else if (extension == 'json') {
  16. const obj = context.open('json');
  17. if (obj && obj.nodes && obj.arg_nodes && obj.heads) {
  18. return true;
  19. }
  20. }
  21. else if (extension == 'params') {
  22. const stream = context.stream;
  23. const signature = [ 0x12, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 ];
  24. if (stream.length > signature.length && stream.peek(signature.length).every((value, index) => value == signature[index])) {
  25. return true;
  26. }
  27. }
  28. return false;
  29. }
  30. open(context) {
  31. return mxnet.Metadata.open(context).then((metadata) => {
  32. const basename = (identifier, extension, suffix) => {
  33. const dots = identifier.split('.');
  34. if (dots.length >= 2 && dots.pop().toLowerCase() === extension) {
  35. const dashes = dots.join('.').split('-');
  36. if (dashes.length >= 2) {
  37. const token = dashes.pop();
  38. if (suffix) {
  39. if (token != suffix) {
  40. return null;
  41. }
  42. }
  43. else {
  44. for (let i = 0; i < token.length; i++) {
  45. const c = token.charAt(i);
  46. if ((c >= '0' && c <= '9') || (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z')) {
  47. continue;
  48. }
  49. return null;
  50. }
  51. }
  52. return dashes.join('-');
  53. }
  54. }
  55. return null;
  56. };
  57. const open_model = (metadata, format, manifest, symbol, signature, params) => {
  58. const parameters = new Map();
  59. if (params) {
  60. try {
  61. const stream = new ndarray.Stream(params);
  62. for (const key of Object.keys(stream.arrays)) {
  63. const name = (key.startsWith('arg:') || key.startsWith('aux:')) ? key.substring(4) : key;
  64. parameters.set(name, stream.arrays[key]);
  65. }
  66. }
  67. catch (error) {
  68. // continue regardless of error
  69. }
  70. }
  71. return new mxnet.Model(metadata, format, manifest, symbol, signature, parameters);
  72. };
  73. const identifier = context.identifier;
  74. const extension = context.identifier.split('.').pop().toLowerCase();
  75. let symbol = null;
  76. let params = null;
  77. let format = null;
  78. let base = null;
  79. switch (extension) {
  80. case 'json':
  81. try {
  82. symbol = context.open('json');
  83. if (symbol && symbol.nodes && symbol.nodes.some((node) => node && node.op == 'tvm_op')) {
  84. format = 'TVM';
  85. }
  86. }
  87. catch (error) {
  88. const message = error && error.message ? error.message : error.toString();
  89. throw new mxnet.Error("Failed to load symbol entry (" + message.replace(/\.$/, '') + ').');
  90. }
  91. base = basename(identifier, 'json', 'symbol');
  92. if (base) {
  93. return context.request(base + '-0000.params', null).then((stream) => {
  94. const buffer = stream.peek();
  95. return open_model(metadata, format, null, symbol, null, buffer);
  96. }).catch(() => {
  97. return open_model(metadata, format, null, symbol, null, params);
  98. });
  99. }
  100. return open_model(metadata, format, null, symbol, null, null);
  101. case 'params':
  102. params = context.stream.peek();
  103. base = basename(context.identifier, 'params');
  104. if (base) {
  105. return context.request(base + '-symbol.json', 'utf-8').then((text) => {
  106. symbol = JSON.parse(text);
  107. if (symbol && symbol.nodes && symbol.nodes.some((node) => node && node.op == 'tvm_op')) {
  108. format = 'TVM';
  109. }
  110. return open_model(metadata, format, null, symbol, null, params);
  111. }).catch(() => {
  112. return open_model(metadata, format, null, null, null, params);
  113. });
  114. }
  115. return open_model(metadata, format, null, null, null, params);
  116. case 'mar':
  117. case 'model': {
  118. const entries = new Map();
  119. try {
  120. for (const entry of context.entries('zip')) {
  121. entries.set(entry.name, entry);
  122. }
  123. }
  124. catch (err) {
  125. throw new mxnet.Error('Failed to decompress Zip archive. ' + err.message);
  126. }
  127. let manifestEntry = entries.get(entries.has('MANIFEST.json') ? 'MANIFEST.json' : 'MAR-INF/MANIFEST.json');
  128. let rootFolder = '';
  129. if (!manifestEntry) {
  130. const folders = Array.from(entries.keys()).filter((name) => name.endsWith('/')).filter((name) => entries.get(name + 'MANIFEST.json'));
  131. if (folders.length != 1) {
  132. throw new mxnet.Error("Manifest not found.");
  133. }
  134. rootFolder = folders[0];
  135. manifestEntry = entries.get(rootFolder + 'MANIFEST.json');
  136. }
  137. const decoder = new TextDecoder('utf-8');
  138. let manifest = null;
  139. try {
  140. manifest = JSON.parse(decoder.decode(manifestEntry.data));
  141. }
  142. catch (err) {
  143. throw new mxnet.Error('Failed to read manifest. ' + err.message);
  144. }
  145. let modelFormat = null;
  146. let symbolEntry = null;
  147. let signatureEntry = null;
  148. let paramsEntry = null;
  149. if (manifest.Model) {
  150. modelFormat = manifest.Model['Model-Format'];
  151. if (modelFormat && modelFormat != 'MXNet-Symbolic') {
  152. throw new mxnet.Error('Model format \'' + modelFormat + '\' not supported.');
  153. }
  154. format = 'MXNet Model Server';
  155. if (manifest['Model-Archive-Version']) {
  156. format += ' v' + manifest['Model-Archive-Version'].toString();
  157. }
  158. if (!manifest.Model.Symbol) {
  159. throw new mxnet.Error('Manifest does not contain symbol entry.');
  160. }
  161. symbolEntry = entries.get(rootFolder + manifest.Model.Symbol);
  162. if (manifest.Model.Signature) {
  163. signatureEntry = entries.get(rootFolder + manifest.Model.Signature);
  164. }
  165. if (manifest.Model.Parameters) {
  166. paramsEntry = entries.get(rootFolder + manifest.Model.Parameters);
  167. }
  168. }
  169. else if (manifest.model) {
  170. format = 'MXNet Model Archive';
  171. if (manifest.specificationVersion) {
  172. format += ' v' + manifest.specificationVersion.toString();
  173. }
  174. if (manifest.model.modelName) {
  175. symbolEntry = entries.get(rootFolder + manifest.model.modelName + '-symbol.json');
  176. let key = null;
  177. for (key of Array.from(entries.keys())) {
  178. key = key.substring(rootFolder.length);
  179. if (key.endsWith('.params') && key.startsWith(manifest.model.modelName)) {
  180. paramsEntry = entries.get(key);
  181. break;
  182. }
  183. }
  184. if (!symbolEntry && !paramsEntry) {
  185. for (key of Object.keys(entries)) {
  186. key = key.substring(rootFolder.length);
  187. if (key.endsWith('.params')) {
  188. paramsEntry = entries.get(key);
  189. break;
  190. }
  191. }
  192. }
  193. }
  194. }
  195. else {
  196. throw new mxnet.Error('Manifest does not contain model.');
  197. }
  198. if (!symbolEntry && !paramsEntry) {
  199. throw new mxnet.Error("Model does not contain symbol entry.");
  200. }
  201. try {
  202. if (symbolEntry) {
  203. symbol = JSON.parse(decoder.decode(symbolEntry.data));
  204. }
  205. }
  206. catch (err) {
  207. throw new mxnet.Error('Failed to load symbol entry.' + err.message);
  208. }
  209. if (paramsEntry) {
  210. params = paramsEntry.data;
  211. }
  212. let signature = null;
  213. try {
  214. if (signatureEntry) {
  215. signature = JSON.parse(decoder.decode(signatureEntry.data));
  216. }
  217. }
  218. catch (err) {
  219. // continue regardless of error
  220. }
  221. return open_model(metadata, format, manifest, symbol, signature, params);
  222. }
  223. default:
  224. throw new mxnet.Error('Unsupported file extension.');
  225. }
  226. });
  227. }
  228. };
  229. mxnet.Model = class {
  230. constructor(metadata, format, manifest, symbol, signature, params) {
  231. if (!symbol && !params) {
  232. throw new mxnet.Error('JSON symbol data not available.');
  233. }
  234. if (symbol) {
  235. if (!Object.prototype.hasOwnProperty.call(symbol, 'nodes')) {
  236. throw new mxnet.Error('JSON file does not contain an MXNet \'nodes\' property.');
  237. }
  238. if (!Object.prototype.hasOwnProperty.call(symbol, 'arg_nodes')) {
  239. throw new mxnet.Error('JSON file does not contain an MXNet \'arg_nodes\' property.');
  240. }
  241. if (!Object.prototype.hasOwnProperty.call(symbol, 'heads')) {
  242. throw new mxnet.Error('JSON file does not contain an MXNet \'heads\' property.');
  243. }
  244. }
  245. if (manifest) {
  246. if (manifest.Model && manifest.Model['Model-Name']) {
  247. this._name = manifest.Model['Model-Name'];
  248. }
  249. if (manifest.Model && manifest.Model.Description && this._name != manifest.Model.Description) {
  250. this._description = manifest.Model.Description;
  251. }
  252. if (manifest.Engine && manifest.Engine.MXNet) {
  253. const engineVersion = mxnet.Model._convert_version(manifest.Engine.MXNet);
  254. this._runtime = 'MXNet v' + (engineVersion ? engineVersion : manifest.Engine.MXNet.toString());
  255. }
  256. if (manifest.License) {
  257. this._license = manifest.License;
  258. }
  259. if (manifest.model && manifest.model.modelName) {
  260. this._name = manifest.model.modelName;
  261. }
  262. if (manifest.model && manifest.model.modelVersion) {
  263. this._version = manifest.model.modelVersion;
  264. }
  265. if (manifest.model && manifest.model.modelName && this._name != manifest.model.description) {
  266. this._description = manifest.model.description;
  267. }
  268. if (manifest.runtime) {
  269. this._runtime = manifest.runtime;
  270. }
  271. if (manifest.engine && manifest.engine.engineName) {
  272. const engine = manifest.engine.engineVersion ? manifest.engine.engineName + ' ' + manifest.engine.engineVersion : manifest.engine.engineName;
  273. this._runtime = this._runtime ? (this._runtime + ' (' + engine + ')') : engine;
  274. }
  275. if (manifest.publisher && manifest.publisher.author) {
  276. this._author = manifest.publisher.author;
  277. if (manifest.publisher.email) {
  278. this._author = this._author + ' <' + manifest.publisher.email + '>';
  279. }
  280. }
  281. if (manifest.license) {
  282. this._license = manifest.license;
  283. }
  284. }
  285. this._format = format;
  286. if (!this._format && symbol && symbol.attrs && symbol.attrs.mxnet_version) {
  287. const version = mxnet.Model._convert_version(symbol.attrs.mxnet_version);
  288. if (version) {
  289. this._format = 'MXNet v' + version;
  290. }
  291. }
  292. if (!this._format) {
  293. this._format = 'MXNet';
  294. }
  295. this._graphs = [];
  296. this._graphs.push(new mxnet.Graph(metadata, manifest, symbol, signature, params));
  297. }
  298. get format() {
  299. return this._format;
  300. }
  301. get name() {
  302. return this._name;
  303. }
  304. get version() {
  305. return this._version;
  306. }
  307. get description() {
  308. return this._description;
  309. }
  310. get author() {
  311. return this._author;
  312. }
  313. get license() {
  314. return this._license;
  315. }
  316. get runtime() {
  317. return this._runtime;
  318. }
  319. get graphs() {
  320. return this._graphs;
  321. }
  322. static _convert_version(value) {
  323. if (Array.isArray(value)) {
  324. if (value.length == 2 && value[0] == 'int') {
  325. const major = Math.floor(value[1] / 10000) % 100;
  326. const minor = Math.floor(value[1] / 100) % 100;
  327. const patch = Math.floor(value[1]) % 100;
  328. return [ major.toString(), minor.toString(), patch.toString() ].join('.');
  329. }
  330. }
  331. return null;
  332. }
  333. };
  334. mxnet.Graph = class {
  335. constructor(metadata, manifest, symbol, signature, params) {
  336. this._metadata = metadata;
  337. this._nodes = [];
  338. this._inputs = [];
  339. this._outputs = [];
  340. const tensors = new Map();
  341. if (params) {
  342. for (const pair of params) {
  343. const key = pair[0];
  344. const value = pair[1];
  345. tensors.set(key, new mxnet.Tensor('Initializer', key, new mxnet.TensorType(value.dataType, new mxnet.TensorShape(value.shape.dimensions)), value.data));
  346. }
  347. }
  348. if (symbol) {
  349. const nodes = symbol.nodes;
  350. const inputs = {};
  351. if (signature && signature.inputs) {
  352. for (const input of signature.inputs) {
  353. inputs[input.data_name] = input;
  354. }
  355. }
  356. const outputs = {};
  357. if (signature && signature.outputs) {
  358. for (const output of signature.outputs) {
  359. outputs[output.data_name] = output;
  360. }
  361. }
  362. for (const node of nodes) {
  363. node.outputs = [];
  364. }
  365. for (const node of nodes) {
  366. node.inputs = node.inputs.map((input) => {
  367. return mxnet.Graph._updateOutput(nodes, input);
  368. });
  369. }
  370. const outputCountMap = {};
  371. for (const node of nodes) {
  372. for (const output of node.outputs) {
  373. outputCountMap[output] = (outputCountMap[output] || 0) + 1;
  374. }
  375. }
  376. const argumentMap = {};
  377. for (const index of symbol.arg_nodes) {
  378. argumentMap[index] = (index < nodes.length) ? nodes[index] : null;
  379. }
  380. for (let i = 0; i < symbol.heads.length; i++) {
  381. const head = symbol.heads[i];
  382. const outputId = mxnet.Graph._updateOutput(nodes, head);
  383. const outputName = nodes[outputId[0]] ? nodes[outputId[0]].name : ('output' + ((i == 0) ? '' : (i + 1).toString()));
  384. let outputType = null;
  385. const outputSignature = outputs[outputName];
  386. if (outputSignature && outputSignature.data_shape) {
  387. outputType = new mxnet.TensorType(-1, new mxnet.TensorShape(outputSignature.data_shape));
  388. }
  389. this._outputs.push(new mxnet.Parameter(outputName, [ new mxnet.Argument('[' + outputId.join(',') + ']', outputType, null) ]));
  390. }
  391. const initializerMap = {};
  392. for (const node of nodes.filter((node, index) => !argumentMap[index])) {
  393. this._nodes.push(new mxnet.Node(this._metadata, node, argumentMap, initializerMap, tensors));
  394. }
  395. for (const argumentKey of Object.keys(argumentMap)) {
  396. const argument = argumentMap[argumentKey];
  397. if (argument && (!argument.inputs || argument.inputs.length == 0) && (argument.outputs && argument.outputs.length == 1)) {
  398. const inputId = argument.outputs[0];
  399. const inputName = argument.name;
  400. let inputType = null;
  401. const inputSignature = inputs[inputName];
  402. if (inputSignature && inputSignature.data_shape) {
  403. inputType = new mxnet.TensorType(-1, new mxnet.TensorShape(inputSignature.data_shape));
  404. }
  405. this._inputs.push(new mxnet.Parameter(inputName, [ new mxnet.Argument('[' + inputId.join(',') + ']', inputType) ]));
  406. }
  407. }
  408. }
  409. else if (params) {
  410. const blocks = new Map();
  411. let separator = Array.from(params.keys()).every((key) => key.indexOf('_') != -1) ? '_' : '';
  412. if (separator.length == 0) {
  413. separator = Array.from(params.keys()).every((key) => key.indexOf('.') != -1) ? '.' : '';
  414. }
  415. if (separator.length > 0) {
  416. for (const param of params) {
  417. const key = param[0];
  418. const parts = key.split(separator);
  419. let argumentName = parts.pop();
  420. if (key.endsWith('moving_mean') || key.endsWith('moving_var')) {
  421. argumentName = [ parts.pop(), argumentName ].join(separator);
  422. }
  423. const nodeName = parts.join(separator);
  424. if (!blocks.has(nodeName)) {
  425. blocks.set(nodeName, { name: nodeName, op: 'Weights', params: [] });
  426. }
  427. blocks.get(nodeName).params.push({ name: argumentName, id: key });
  428. }
  429. }
  430. else {
  431. throw new mxnet.Error("Unsupported key format in params.");
  432. }
  433. for (const block of blocks.values()) {
  434. this._nodes.push(new mxnet.Node(metadata, block, {}, {}, tensors));
  435. }
  436. }
  437. }
  438. get name() {
  439. return '';
  440. }
  441. get inputs() {
  442. return this._inputs;
  443. }
  444. get outputs() {
  445. return this._outputs;
  446. }
  447. get nodes() {
  448. return this._nodes;
  449. }
  450. static _updateOutput(nodes, input) {
  451. const nodeIndex = input[0];
  452. const node = nodes[nodeIndex];
  453. const outputIndex = input[1];
  454. if (node) {
  455. while (outputIndex >= node.outputs.length) {
  456. node.outputs.push([ nodeIndex, node.outputs.length ]);
  457. }
  458. }
  459. return [ nodeIndex, outputIndex ];
  460. }
  461. };
  462. mxnet.Parameter = class {
  463. constructor(name, args) {
  464. this._name = name;
  465. this._arguments = args;
  466. }
  467. get name() {
  468. return this._name;
  469. }
  470. get visible() {
  471. return true;
  472. }
  473. get arguments() {
  474. return this._arguments;
  475. }
  476. };
  477. mxnet.Argument = class {
  478. constructor(name, type, initializer) {
  479. if (typeof name !== 'string') {
  480. throw new mxnet.Error("Invalid argument identifier '" + JSON.stringify(name) + "'.");
  481. }
  482. this._name = name;
  483. this._type = type || null;
  484. this._initializer = initializer || null;
  485. }
  486. get name() {
  487. if (this._initializer) {
  488. return this._initializer.name;
  489. }
  490. return this._name;
  491. }
  492. get type() {
  493. if (this._initializer) {
  494. return this._initializer.type;
  495. }
  496. return this._type;
  497. }
  498. get initializer() {
  499. return this._initializer;
  500. }
  501. };
  502. mxnet.Node = class {
  503. constructor(metadata, node, argumentMap, initializerMap, tensors) {
  504. this._metadata = metadata;
  505. this._type = node.op;
  506. this._name = node.name;
  507. this._attributes = [];
  508. this._inputs = [];
  509. this._outputs = [];
  510. const attrs = node.attrs || node.attr || node.param;
  511. if (attrs) {
  512. if (this._type == 'tvm_op' && attrs.func_name) {
  513. this._type = attrs.func_name;
  514. }
  515. for (const attributeName of Object.keys(attrs)) {
  516. if (this._type != 'tvm_op' && attributeName != 'func_name') {
  517. this._attributes.push(new mxnet.Attribute(this._metadata, this.type, attributeName, attrs[attributeName]));
  518. }
  519. }
  520. }
  521. let initializer = null;
  522. const schema = metadata.type(this.type);
  523. if (node.inputs) {
  524. let inputs = node.inputs;
  525. if (this._type == 'RNN') {
  526. inputs = inputs.map((input) => {
  527. const argumentNodeIndex = input[0];
  528. const argument = argumentMap[argumentNodeIndex];
  529. if (argument && argument.op == 'null' && argument.name &&
  530. argument.name.endsWith('_parameters') && argument.attr && argument.attr.__init__) {
  531. this._attributes.push(new mxnet.Attribute(this._metadata, this.type, argument.name, argument.attr.__init__));
  532. delete argumentMap[argumentNodeIndex];
  533. return null;
  534. }
  535. return input;
  536. });
  537. inputs = inputs.filter((item) => item != null);
  538. }
  539. const initializers = {};
  540. for (const input of inputs) {
  541. const id = '[' + input.join(',') + ']';
  542. initializer = initializerMap[id];
  543. if (!initializer) {
  544. const argumentNodeIndex = input[0];
  545. const argument = argumentMap[argumentNodeIndex];
  546. if (argument && argument.name &&
  547. (!argument.inputs || argument.inputs.length == 0) &&
  548. (argument.outputs && argument.outputs.length == 1)) {
  549. initializer = tensors.get(argument.name) || null;
  550. if (initializer) {
  551. delete argumentMap[argumentNodeIndex];
  552. }
  553. else {
  554. let prefix = this._name;
  555. if (prefix.endsWith('_fwd')) {
  556. prefix = prefix.slice(0, -3);
  557. }
  558. if (argument.name && (argument.name.startsWith(prefix + '_') || argument.name.startsWith(prefix + '.'))) {
  559. let dataType = -1;
  560. let shape = [];
  561. if (argument.attrs && argument.attrs.__dtype__ && argument.attrs.__shape__) {
  562. try {
  563. dataType = parseInt(argument.attrs.__dtype__);
  564. shape = JSON.parse('[' + argument.attrs.__shape__.replace('(', '').replace(')', '').split(' ').join('').split(',').map((dimension => dimension || '"?"' )).join(',') + ']');
  565. }
  566. catch (err) {
  567. // continue regardless of error
  568. }
  569. }
  570. let argumentType = null;
  571. if (dataType !== -1 || shape.length > 0) {
  572. argumentType = new mxnet.TensorType(dataType, new mxnet.TensorShape(shape));
  573. }
  574. else {
  575. argumentType = new mxnet.TensorType(-1, new mxnet.TensorShape(null));
  576. }
  577. initializer = new mxnet.Tensor('Initializer', argument.name, argumentType, null);
  578. delete argumentMap[argumentNodeIndex];
  579. }
  580. }
  581. }
  582. }
  583. if (initializer) {
  584. initializers[id] = initializer;
  585. initializerMap[id] = initializer;
  586. }
  587. }
  588. let inputIndex = 0;
  589. if (schema && schema.inputs) {
  590. for (const inputDef of schema.inputs) {
  591. if (inputIndex < inputs.length || inputDef.option != 'optional') {
  592. const inputCount = (inputDef.option == 'variadic') ? (inputs.length - inputIndex) : 1;
  593. const inputArguments = [];
  594. for (const input of inputs.slice(inputIndex, inputIndex + inputCount)) {
  595. const inputId = '[' + input.join(',') + ']';
  596. if (inputId != '' || inputDef.option != 'optional') {
  597. inputArguments.push(new mxnet.Argument(inputId, inputDef.type, initializers[inputId]));
  598. }
  599. }
  600. this._inputs.push(new mxnet.Parameter(inputDef.name, inputArguments));
  601. inputIndex += inputCount;
  602. }
  603. }
  604. }
  605. if (inputIndex < inputs.length) {
  606. this._inputs.push(...inputs.slice(inputIndex).map((input, index) => {
  607. const inputId = '[' + input.join(',') + ']';
  608. return new mxnet.Parameter((inputIndex + index).toString(), [
  609. new mxnet.Argument(inputId, null, initializers[inputId])
  610. ]);
  611. }));
  612. }
  613. }
  614. if (node.outputs) {
  615. const outputs = node.outputs;
  616. let outputIndex = 0;
  617. if (schema && schema.outputs) {
  618. for (const outputDef of schema.outputs) {
  619. if (outputIndex < outputs.length || outputDef.option != 'optional') {
  620. const outputArguments = [];
  621. const outputCount = (outputDef.option == 'variadic') ? (outputs.length - outputIndex) : 1;
  622. for (const output of outputs.slice(outputIndex, outputIndex + outputCount)) {
  623. outputArguments.push(new mxnet.Argument('[' + output.join(',') + ']', null, null));
  624. }
  625. this._outputs.push(new mxnet.Parameter(outputDef.name, outputArguments));
  626. outputIndex += outputCount;
  627. }
  628. }
  629. }
  630. if (outputIndex < outputs.length) {
  631. this._outputs.push(...outputs.slice(outputIndex).map((output, index) => {
  632. return new mxnet.Parameter((outputIndex + index).toString(), [
  633. new mxnet.Argument('[' + output.join(',') + ']', null, null)
  634. ]);
  635. }));
  636. }
  637. }
  638. if (node.params) {
  639. for (const param of node.params) {
  640. this._inputs.push(new mxnet.Parameter(param.name, [
  641. new mxnet.Argument(param.id, null, tensors.get(param.id) || null)
  642. ]));
  643. }
  644. }
  645. }
  646. get type() {
  647. return this._type;
  648. }
  649. get metadata() {
  650. return this._metadata.type(this._type);
  651. }
  652. get name() {
  653. return this._name;
  654. }
  655. get inputs() {
  656. return this._inputs;
  657. }
  658. get outputs() {
  659. return this._outputs;
  660. }
  661. get attributes() {
  662. return this._attributes;
  663. }
  664. };
  665. mxnet.Attribute = class {
  666. constructor(metadata, type, name, value) {
  667. this._name = name;
  668. this._value = value;
  669. let number;
  670. const schema = metadata.attribute(type, name);
  671. if (schema && schema.type) {
  672. switch (schema.type) {
  673. case 'boolean':
  674. switch (value) {
  675. case 'True':
  676. this._value = true;
  677. break;
  678. case 'False':
  679. this._value = false;
  680. break;
  681. }
  682. break;
  683. case 'int32':
  684. number = Number.parseInt(this._value, 10);
  685. this._value = Number.isNaN(this._value - number) ? value : number;
  686. break;
  687. case 'float32':
  688. case 'float64':
  689. number = Number.parseFloat(this._value);
  690. this._value = Number.isNaN(this._value - number) ? value : number;
  691. break;
  692. case 'int32[]':
  693. if (this._value.length > 2 && this._value.startsWith('(') && this._value.endsWith(')')) {
  694. let array = [];
  695. const items = this._value.substring(1, this._value.length - 1).split(',')
  696. .map((item) => item.trim())
  697. .map((item) => item.endsWith('L') ? item.substring(0, item.length - 1) : item);
  698. for (const item of items) {
  699. number = Number.parseInt(item, 10);
  700. if (Number.isNaN(item - number)) {
  701. array = null;
  702. }
  703. else if (array != null) {
  704. array.push(number);
  705. }
  706. }
  707. if (array != null) {
  708. this._value = array;
  709. }
  710. }
  711. break;
  712. }
  713. }
  714. if (schema) {
  715. if (Object.prototype.hasOwnProperty.call(schema, 'visible') && !schema.visible) {
  716. this._visible = false;
  717. }
  718. else if (Object.prototype.hasOwnProperty.call(schema, 'default')) {
  719. let defaultValue = schema.default;
  720. if (this._value == defaultValue) {
  721. this._visible = false;
  722. }
  723. else if (Array.isArray(this._value) && Array.isArray(defaultValue)) {
  724. defaultValue = defaultValue.slice(0, defaultValue.length);
  725. if (defaultValue.length > 1 && defaultValue[defaultValue.length - 1] == null) {
  726. defaultValue.pop();
  727. while (defaultValue.length < this._value.length) {
  728. defaultValue.push(defaultValue[defaultValue.length - 1]);
  729. }
  730. }
  731. if (this._value.every((item, index) => { return item == defaultValue[index]; })) {
  732. this._visible = false;
  733. }
  734. }
  735. }
  736. }
  737. }
  738. get name() {
  739. return this._name;
  740. }
  741. get type() {
  742. return this._type;
  743. }
  744. get value() {
  745. return this._value;
  746. }
  747. get visible() {
  748. return this._visible == false ? false : true;
  749. }
  750. };
  mxnet.Tensor = class {

      /**
       * @param {*} kind - stored as `_kind`; note the `kind` getter does not read it.
       * @param {string} name - tensor name.
       * @param {{dataType: string, shape: {dimensions: number[]}}} type - element type and shape
       *     (presumably an mxnet.TensorType); may be missing, in which case `state` reports it.
       * @param {ArrayBufferView} data - raw tensor bytes, decoded little-endian.
       */
      constructor(kind, name, type, data) {
          this._kind = kind;
          this._name = name;
          this._type = type;
          this._data = data;
      }

      get kind() {
          // NOTE(review): always 'Initializer' regardless of the constructor's `kind`; kept as-is
          // because callers outside this chunk may depend on it — confirm before changing.
          return 'Initializer';
      }

      get name() {
          return this._name;
      }

      get type() {
          return this._type;
      }

      /** Why the tensor cannot be decoded, or null when it can. */
      get state() {
          return this._context().state;
      }

      /** All values as nested arrays (one level per dimension), or null when `state` is set. */
      get value() {
          const context = this._context();
          if (context.state) {
              return null;
          }
          context.limit = Number.MAX_SAFE_INTEGER;
          return this._decode(context, 0);
      }

      /** Pretty-printed JSON of the values, truncated with '...' after roughly 10000 elements. */
      toString() {
          const context = this._context();
          if (context.state) {
              return '';
          }
          context.limit = 10000;
          const value = this._decode(context, 0);
          return JSON.stringify(value, null, 4);
      }

      /**
       * Builds the decoding cursor. `state` is non-null when data, type or shape is unusable;
       * otherwise `data` is a DataView over the tensor bytes and `dimensions` the shape.
       */
      _context() {
          const context = {};
          context.state = null;
          context.index = 0;
          context.count = 0;
          if (!this._data) {
              context.state = 'Tensor data is empty.';
              return context;
          }
          // `||` so a missing type is reported instead of dereferencing undefined.
          if (!this._type || this._type.dataType === '?') {
              context.state = 'Tensor has no data type.';
              return context;
          }
          // The shape object exposes its extent through `dimensions`, which may be null.
          const shape = this._type.shape;
          const dimensions = shape ? shape.dimensions : null;
          if (!dimensions || dimensions.length < 1) {
              context.state = 'Tensor has unknown shape.';
              return context;
          }
          context.dataType = this._type.dataType;
          context.dimensions = dimensions;
          context.data = new DataView(this._data.buffer, this._data.byteOffset, this._data.byteLength);
          return context;
      }

      /**
       * Recursively decodes dimension `dimension` onward, advancing `context.index` (bytes) and
       * `context.count` (elements). Once `count` exceeds `limit`, '...' is appended and decoding stops.
       */
      _decode(context, dimension) {
          const results = [];
          const size = context.dimensions[dimension];
          const innermost = dimension == context.dimensions.length - 1;
          for (let i = 0; i < size; i++) {
              if (context.count > context.limit) {
                  results.push('...');
                  return results;
              }
              if (!innermost) {
                  results.push(this._decode(context, dimension + 1));
                  continue;
              }
              const data = context.data;
              switch (context.dataType) {
                  case 'float32':
                      results.push(data.getFloat32(context.index, true));
                      context.index += 4;
                      context.count++;
                      break;
                  case 'float64':
                      results.push(data.getFloat64(context.index, true));
                      context.index += 8;
                      context.count++;
                      break;
                  case 'float16':
                      results.push(mxnet.Tensor._decodeNumberFromFloat16(data.getUint16(context.index, true)));
                      context.index += 2;
                      context.count++;
                      break;
                  case 'uint8':
                      results.push(data.getUint8(context.index));
                      context.index += 1;
                      context.count++;
                      break;
                  case 'int32':
                      results.push(data.getInt32(context.index, true));
                      context.index += 4;
                      context.count++;
                      break;
                  case 'int8':
                      results.push(data.getInt8(context.index));
                      context.index += 1;
                      context.count++;
                      break;
                  case 'int64': {
                      // DataView has no getInt64: combine the little-endian halves. Exact for
                      // |value| <= 2^53; larger magnitudes round to the nearest double.
                      const low = data.getUint32(context.index, true);
                      const high = data.getInt32(context.index + 4, true);
                      results.push(high * 4294967296 + low);
                      context.index += 8;
                      context.count++;
                      break;
                  }
              }
          }
          return results;
      }

      /** Converts an IEEE 754 half-precision bit pattern (sign 0x8000, exponent 0x7C00, fraction 0x03FF) to a Number. */
      static _decodeNumberFromFloat16(value) {
          const s = (value & 0x8000) >> 15;
          const e = (value & 0x7C00) >> 10;
          const f = value & 0x03FF;
          if (e == 0) {
              // Subnormal: no implicit leading 1.
              return (s ? -1 : 1) * Math.pow(2, -14) * (f / Math.pow(2, 10));
          }
          else if (e == 0x1F) {
              return f ? NaN : ((s ? -1 : 1) * Infinity);
          }
          return (s ? -1 : 1) * Math.pow(2, e - 15) * (1 + (f / Math.pow(2, 10)));
      }
  };
  mxnet.TensorType = class {

      /**
       * @param {number} dataType - numeric type code: 0..6 map to float32, float64, float16,
       *     uint8, int32, int8, int64; -1 maps to '?' (unknown).
       * @param {*} shape - stored as-is; must provide toString() for this.toString().
       * @throws {mxnet.Error} when the type code is not one of the above.
       */
      constructor(dataType, shape) {
          const typeNames = new Map([
              [ 0, 'float32' ],
              [ 1, 'float64' ],
              [ 2, 'float16' ],
              [ 3, 'uint8' ],
              [ 4, 'int32' ],
              [ 5, 'int8' ],
              [ 6, 'int64' ],
              [ -1, '?' ]
          ]);
          if (!typeNames.has(dataType)) {
              throw new mxnet.Error("Unknown type '" + dataType + "'.");
          }
          this._dataType = typeNames.get(dataType);
          this._shape = shape;
      }

      get dataType() {
          return this._dataType;
      }

      get shape() {
          return this._shape;
      }

      /** Type name immediately followed by the shape's string form, e.g. 'float32[2,3]'. */
      toString() {
          const shapeText = this._shape.toString();
          return this._dataType + shapeText;
      }
  };
  mxnet.TensorShape = class {

      /** @param {number[]|null|undefined} dimensions - extent of each axis; may be absent. */
      constructor(dimensions) {
          this._dimensions = dimensions;
      }

      get dimensions() {
          return this._dimensions;
      }

      /** '[d0,d1,...]', or '' when the dimensions are missing or empty. */
      toString() {
          const dimensions = this._dimensions;
          if (!dimensions || dimensions.length == 0) {
              return '';
          }
          const parts = dimensions.map((dimension) => dimension.toString());
          return '[' + parts.join(',') + ']';
      }
  };
  mxnet.Metadata = class {

      /**
       * Returns the shared metadata instance. On first use it requests 'mxnet-metadata.json'
       * through `context.request`; the result is cached on mxnet.Metadata._metadata.
       * Loading is deliberately best-effort: if the request fails or the JSON cannot be parsed,
       * an empty instance is cached and returned, so schema lookups simply find nothing.
       * @returns {Promise<mxnet.Metadata>}
       */
      static open(context) {
          if (mxnet.Metadata._metadata) {
              return Promise.resolve(mxnet.Metadata._metadata);
          }
          return context.request('mxnet-metadata.json', 'utf-8', null).then((data) => {
              mxnet.Metadata._metadata = new mxnet.Metadata(data);
              return mxnet.Metadata._metadata;
          }).catch(() => {
              mxnet.Metadata._metadata = new mxnet.Metadata(null);
              return mxnet.Metadata._metadata;
          });
      }

      /**
       * @param {string|null} data - JSON text of an array of schemas, each with a `name`
       *     and optionally an `attributes` array; null/empty yields an empty registry.
       * @throws {SyntaxError} when `data` is not valid JSON.
       */
      constructor(data) {
          this._map = new Map();
          // Map instead of a plain object: keys come from schema/attribute names, and names such
          // as 'constructor' or 'toString' must not resolve to Object.prototype members.
          this._attributeCache = new Map();
          if (data) {
              const metadata = JSON.parse(data);
              this._map = new Map(metadata.map((item) => [ item.name, item ]));
          }
      }

      /** Schema registered under `name`, or undefined. */
      type(name) {
          return this._map.get(name);
      }

      /**
       * Attribute schema `name` of operator `type`, or null when either is unknown.
       * Per-type attribute tables are built lazily and cached.
       */
      attribute(type, name) {
          let map = this._attributeCache.get(type);
          if (!map) {
              map = new Map();
              const schema = this.type(type);
              if (schema && schema.attributes) {
                  for (const attribute of schema.attributes) {
                      map.set(attribute.name, attribute);
                  }
              }
              this._attributeCache.set(type, map);
          }
          return map.get(name) || null;
      }
  };
  /** Error raised while loading an MXNet model; `name` carries a fixed description. */
  mxnet.Error = class extends Error {
      constructor(text) {
          super(text);
          this.name = 'Error loading MXNet model.';
      }
  };
  ndarray.Stream = class {

      /**
       * Parses a list of named arrays from `buffer`: an 8-byte signature, an 8-byte zero
       * reserved block, a uint64 array count followed by the arrays, then a uint64 name count
       * followed by length-prefixed names. Arrays are paired with names by position.
       * @throws {ndarray.Error} on a bad signature/reserved block or a name/array count mismatch.
       */
      constructor(buffer) {
          this._arrays = {};
          const reader = new ndarray.Reader(buffer);
          const signatureOk = reader.checkSignature([ 0x12, 1, 0, 0, 0, 0, 0, 0 ]);
          if (!signatureOk) {
              throw new ndarray.Error('Invalid signature.');
          }
          const reservedOk = reader.checkSignature([ 0, 0, 0, 0, 0, 0, 0, 0 ]);
          if (!reservedOk) {
              throw new ndarray.Error('Invalid reserved block.');
          }
          const arrayCount = reader.uint64();
          const data = [];
          while (data.length < arrayCount) {
              data.push(new ndarray.Array(reader));
          }
          const decoder = new TextDecoder('ascii');
          const nameCount = reader.uint64();
          const names = [];
          while (names.length < nameCount) {
              const nameLength = reader.uint64();
              names.push(decoder.decode(reader.read(nameLength)));
          }
          if (names.length != data.length) {
              throw new ndarray.Error('Label count mismatch.');
          }
          names.forEach((name, position) => {
              this._arrays[name] = data[position];
          });
      }

      /** Object mapping each array name to its ndarray.Array. */
      get arrays() {
          return this._arrays;
      }
  };
  ndarray.Array = class {

      /**
       * Reads one serialized array from `reader`, choosing the layout by the leading magic:
       * c9 fa 93 f9 selects V2 and c8 fa 93 f9 selects V1 (the magic is consumed). When neither
       * matches, the bytes are parsed as the V0 layout, which begins directly with the shape.
       */
      constructor(reader) {
          if (reader.checkSignature([ 0xc9, 0xfa, 0x93, 0xF9 ])) {
              this._loadV2(reader);
          }
          else if (reader.checkSignature([ 0xc8, 0xfa, 0x93, 0xF9 ])) {
              this._loadV1(reader);
          }
          else {
              this._loadV0(reader);
          }
      }

      /**
       * V2: storage type, optional storage shape (`sshape`), shape, context, type code, data.
       * Arrays with an empty shape carry nothing further. Sparse storage is not supported.
       * @throws {ndarray.Error} 'Not implemented.' for row-sparse or CSR storage.
       */
      _loadV2(reader) {
          const stype = reader.uint32();
          let num_aux_data = 0;
          switch (stype) {
              case 0: num_aux_data = 0; break; // kDefaultStorage
              case 1: num_aux_data = 1; break; // kRowSparseStorage
              case 2: num_aux_data = 2; break; // kCSRStorage
          }
          this.sshape = null;
          if (num_aux_data > 0) {
              this.sshape = new ndarray.Shape(reader, true);
          }
          this._shape = new ndarray.Shape(reader, true);
          if (this._shape.dimensions.length == 0) {
              return;
          }
          this._context = new ndarray.Context(reader);
          this._dataType = reader.uint32();
          if (num_aux_data > 0) {
              throw new ndarray.Error('Not implemented.');
          }
          this._data = this._readData(reader);
      }

      /** V1: 64-bit shape, then (unless the shape is empty) context, type code, data. */
      _loadV1(reader) {
          this._shape = new ndarray.Shape(reader, true);
          if (this._shape.dimensions.length == 0) {
              return;
          }
          this._context = new ndarray.Context(reader);
          this._dataType = reader.uint32();
          this._data = this._readData(reader);
      }

      /** V0: 32-bit shape, context, type code, data. */
      _loadV0(reader) {
          this._shape = new ndarray.Shape(reader, false);
          this._context = new ndarray.Context(reader);
          this._dataType = reader.uint32();
          this._data = this._readData(reader);
      }

      /** Reads element size of `_dataType` (0 when the code is unknown) times element count bytes. */
      _readData(reader) {
          const table = ndarray.Array._dataTypeSizeTable;
          const known = this._dataType >= 0 && this._dataType < table.length;
          const dataTypeSize = known ? table[this._dataType] : 0;
          return reader.read(dataTypeSize * this._shape.size());
      }

      get dataType() {
          return this._dataType;
      }

      get shape() {
          return this._shape;
      }

      get data() {
          return this._data;
      }
  };

  // Bytes per element indexed by type code: float32, float64, float16, uint8, int32, int8, int64.
  // Assigned once here instead of on every construction.
  ndarray.Array._dataTypeSizeTable = [ 4, 8, 2, 1, 4, 1, 8 ];
  ndarray.Shape = class {

      /**
       * Reads a uint32 rank followed by that many dimensions.
       * @param {ndarray.Reader} reader
       * @param {boolean} uint64 - true when each dimension is stored as 64 bits, false for 32 bits.
       */
      constructor(reader, uint64) {
          const ndim = reader.uint32();
          this._dimensions = [];
          for (let i = 0; i < ndim; i++) {
              this._dimensions.push(uint64 ? reader.uint64() : reader.uint32());
          }
      }

      get dimensions() {
          return this._dimensions;
      }

      /** Number of elements: product of the dimensions, 1 for an empty (scalar) shape. */
      size() {
          // Initial value avoids the TypeError reduce throws on an empty array.
          return this._dimensions.reduce((a, b) => a * b, 1);
      }
  };
  ndarray.Context = class {

      /** Reads two consecutive uint32 values: the device type, then the device id. */
      constructor(reader) {
          const [ deviceType, deviceId ] = [ reader.uint32(), reader.uint32() ];
          this._deviceType = deviceType;
          this._deviceId = deviceId;
      }
  };
  ndarray.Reader = class {

      /** @param {Uint8Array} buffer - bytes to read; multi-byte integers are little-endian. */
      constructor(buffer) {
          this._buffer = buffer;
          this._position = 0;
          this._end = buffer.length;
      }

      /**
       * Compares the upcoming bytes with `signature`. On a match the bytes are consumed and
       * true is returned; on a mismatch, or when fewer bytes remain than the signature length,
       * the position is left unchanged and false is returned.
       */
      checkSignature(signature) {
          if (this._position + signature.length > this._end) {
              return false;
          }
          for (let i = 0; i < signature.length; i++) {
              if (this._buffer[this._position + i] != signature[i]) {
                  return false;
              }
          }
          this._position += signature.length;
          return true;
      }

      /**
       * Returns the next `size` bytes as a subarray view (no copy) and advances past them.
       * @throws {ndarray.Error} when fewer than `size` bytes remain.
       */
      read(size) {
          if (this._position + size > this._end) {
              throw new ndarray.Error('Data not available.');
          }
          const data = this._buffer.subarray(this._position, this._position + size);
          this._position += size;
          return data;
      }

      /** @throws {ndarray.Error} when fewer than 2 bytes remain. */
      uint16() {
          if (this._position + 2 > this._end) {
              throw new ndarray.Error('Data not available.');
          }
          const value = this._buffer[this._position] | (this._buffer[this._position + 1] << 8);
          this._position += 2;
          return value;
      }

      /** Unsigned 32-bit value in [0, 2^32). */
      uint32() {
          const low = this.uint16();
          const high = this.uint16();
          // `|` and `<<` produce signed int32; `>>> 0` keeps values >= 2^31 positive.
          return (low | (high << 16)) >>> 0;
      }

      /**
       * 64-bit value whose high word must be zero (values beyond 2^32 - 1 are rejected).
       * @throws {ndarray.Error} 'Large int64 value.' when the high word is non-zero.
       */
      uint64() {
          const low = this.uint32();
          const high = this.uint32();
          if (high !== 0) {
              throw new ndarray.Error('Large int64 value.');
          }
          return low;
      }
  };
  /** Error raised while parsing NDArray data; `name` is always 'NDArray Error'. */
  ndarray.Error = class extends Error {
      constructor(text) {
          super(text);
          this.name = 'NDArray Error';
      }
  };
  // Export the model factory when a CommonJS `module` object is available; otherwise this is a no-op.
  if (typeof module !== 'undefined' && typeof module.exports === 'object') {
      Object.assign(module.exports, { ModelFactory: mxnet.ModelFactory });
  }