mxnet-model.js 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543
  /*jshint esversion: 6 */
  // Experimental
  /**
   * Entry point of the MXNet loader: recognises MXNet symbol files and
   * parses them into an MXNetModel.
   */
  class MXNetModelFactory {

    /**
     * Decides whether a file looks like an MXNet symbol graph.
     *
     * @param {*} buffer - raw file bytes (anything TextDecoder.decode accepts).
     * @param {string} identifier - file name.
     * @returns {boolean} true for '*-symbol.json' files, or for '.json' files
     *   whose text contains the '"mxnet_version":' key.
     */
    match(buffer, identifier) {
      if (identifier.endsWith('-symbol.json')) {
        return true;
      }
      const extension = identifier.split('.').pop();
      if (extension === 'json') {
        const json = new TextDecoder('utf-8').decode(buffer);
        if (json.includes('"mxnet_version":')) {
          return true;
        }
      }
      return false;
    }

    /**
     * Parses `buffer` as an MXNet symbol file and reports the result.
     *
     * @param {*} buffer - raw file bytes (anything TextDecoder.decode accepts).
     * @param {string} identifier - file name (unused).
     * @param {Object} host - forwarded to MXNetOperatorMetadata.open.
     * @param {function(?Error, ?MXNetModel)} callback - invoked exactly once:
     *   `(MXNetError, null)` when parsing fails, `(null, model)` otherwise.
     */
    open(buffer, identifier, host, callback) {
      let model = null;
      try {
        const json = new TextDecoder('utf-8').decode(buffer);
        model = new MXNetModel(json);
      }
      catch (err) {
        callback(new MXNetError(err.message), null);
        return;
      }
      // Deliberately outside the try block: when the metadata is cached,
      // MXNetOperatorMetadata.open calls back synchronously, and an exception
      // thrown by `callback` must not be caught here and reported a second time.
      MXNetOperatorMetadata.open(host, () => {
        callback(null, model);
      });
    }
  }
  /**
   * Parsed MXNet symbol file.
   *
   * Checks that the document has the `nodes`, `arg_nodes` and `heads`
   * members needed to build a graph. It also decodes an optional
   * `attrs.mxnet_version` tag of the form ["int", N] into a
   * "major.minor.revision" string.
   */
  class MXNetModel {

    /**
     * @param {string} json - text of the symbol file.
     * @throws {SyntaxError} when `json` is not valid JSON.
     * @throws {MXNetError} when the document is empty or lacks a required member.
     */
    constructor(json) {
      const model = JSON.parse(json);
      if (!model) {
        throw new MXNetError('JSON file does not contain MXNet data.');
      }
      const requiredMembers = [
        [ 'nodes', 'JSON file does not contain an MXNet \'nodes\' property.' ],
        [ 'arg_nodes', 'JSON file does not contain an MXNet \'arg_nodes\' property.' ],
        [ 'heads', 'JSON file does not contain an MXNet \'heads\' property.' ]
      ];
      for (const [ member, message ] of requiredMembers) {
        if (!model.hasOwnProperty(member)) {
          throw new MXNetError(message);
        }
      }
      const versionTag = model.attrs && model.attrs.mxnet_version;
      if (versionTag && versionTag.length == 2 && versionTag[0] == 'int') {
        // The version is packed as major * 10000 + minor * 100 + revision.
        const packed = versionTag[1];
        const major = Math.floor(packed / 10000) % 100;
        const minor = Math.floor(packed / 100) % 100;
        const revision = packed % 100;
        this._version = [ major, minor, revision ].map((part) => part.toString()).join('.');
      }
      this._graphs = [ new MXNetGraph(model) ];
    }

    /** Display properties: the format name, plus the version when known. */
    get properties() {
      const format = this._version ? 'MXNet v' + this._version : 'MXNet';
      return [ { name: 'Format', value: format } ];
    }

    /** The model's graphs (always exactly one). */
    get graphs() {
      return this._graphs;
    }
  }
  /**
   * The single computation graph of an MXNet symbol file.
   *
   * Connections are identified by [nodeIndex, outputIndex] pairs. Each node's
   * `outputs` list is rebuilt from the references that other nodes, and the
   * graph heads, make to it.
   */
  class MXNetGraph {

    /**
     * @param {Object} json - parsed symbol document with `nodes`, `arg_nodes`
     *   and `heads`. Its node records are mutated: `outputs` is (re)created and
     *   `inputs` is normalised to [nodeIndex, outputIndex] pairs.
     * @throws {MXNetError} when an input or head refers to a missing node.
     */
    constructor(json) {
      const nodes = json.nodes;
      this._nodes = [];

      nodes.forEach((node) => {
        node.outputs = [];
      });
      nodes.forEach((node) => {
        node.inputs = node.inputs.map((input) => MXNetGraph.updateOutput(nodes, input));
      });

      // arg_nodes lists indices of argument nodes; indices past the end of
      // `nodes` are recorded as null.
      const argumentMap = {};
      json.arg_nodes.forEach((index) => {
        argumentMap[index] = (index < nodes.length) ? nodes[index] : null;
      });

      this._outputs = [];
      json.heads.forEach((head, index) => {
        const id = MXNetGraph.updateOutput(nodes, head);
        const name = 'output' + ((index === 0) ? '' : (index + 1).toString());
        this._outputs.push({ id: id, name: name });
      });

      // Argument nodes are not listed as operator nodes. MXNetNode removes from
      // argumentMap the arguments it claims as its own initializers.
      nodes.forEach((node, index) => {
        if (!argumentMap[index]) {
          this._nodes.push(new MXNetNode(node, argumentMap));
        }
      });

      // Remaining leaf arguments with a single output become graph inputs.
      this._inputs = [];
      Object.keys(argumentMap).forEach((key) => {
        const argument = argumentMap[key];
        if (!argument) {
          // Out-of-range arg_nodes entry; there is no node to expose.
          return;
        }
        if ((!argument.inputs || argument.inputs.length === 0) &&
            (argument.outputs && argument.outputs.length === 1)) {
          this._inputs.push({ id: argument.outputs[0], name: argument.name });
        }
      });
    }

    /** Graphs in a symbol file are unnamed. */
    get name() {
      return '';
    }

    /** Graph inputs; `type` is the placeholder 'T', and `id` is '[node,output]'. */
    get inputs() {
      return this._inputs.map((input) => {
        return {
          name: input.name,
          type: 'T',
          id: '[' + input.id.join(',') + ']'
        };
      });
    }

    /** Graph outputs ('output', 'output2', ...); `id` is '[node,output]'. */
    get outputs() {
      return this._outputs.map((output) => {
        return {
          name: output.name,
          type: 'T',
          id: '[' + output.id.join(',') + ']'
        };
      });
    }

    /** Operator nodes (argument nodes excluded). */
    get nodes() {
      return this._nodes;
    }

    /**
     * Registers the output referenced by `input` on its source node. The
     * source node's `outputs` list is grown up to that output index.
     *
     * @param {Object[]} nodes - all node records.
     * @param {number[]} input - reference whose first two items are
     *   [nodeIndex, outputIndex]; any further items are ignored.
     * @returns {number[]} the normalised [nodeIndex, outputIndex] pair.
     * @throws {MXNetError} when nodeIndex does not name an existing node.
     */
    static updateOutput(nodes, input) {
      const sourceNodeIndex = input[0];
      const sourceNode = nodes[sourceNodeIndex];
      if (!sourceNode) {
        throw new MXNetError('Invalid node reference \'' + sourceNodeIndex + '\'.');
      }
      const sourceOutputIndex = input[1];
      while (sourceOutputIndex >= sourceNode.outputs.length) {
        sourceNode.outputs.push([ sourceNodeIndex, sourceNode.outputs.length ]);
      }
      return [ sourceNodeIndex, sourceOutputIndex ];
    }
  }
  /**
   * One operator node of an MXNet graph.
   *
   * Argument inputs whose names start with this node's name are treated as the
   * node's initializers (weights).
   */
  class MXNetNode {

    /**
     * @param {Object} json - node record; `inputs` and `outputs` are expected
     *   to already be [nodeIndex, outputIndex] pairs (as prepared by MXNetGraph).
     * @param {Object} argumentMap - argument node index -> argument node record.
     *   Entries claimed as initializers of this node are deleted from it.
     */
    constructor(json, argumentMap) {
      this._operator = json.op;
      this._name = json.name;
      this._inputs = json.inputs;
      this._outputs = json.outputs;
      this._attributes = [];
      // Parameters may be stored under `attrs`, `attr` or `param`; the first
      // truthy one is used.
      const attrs = json.attrs || json.attr || json.param;
      if (attrs) {
        Object.keys(attrs).forEach((key) => {
          this._attributes.push(new MXNetAttribute(this, key, attrs[key]));
        });
      }
      this._initializers = {};
      this._inputs.forEach((input) => {
        const argumentNodeIndex = input[0];
        const argument = argumentMap[argumentNodeIndex];
        if (!argument) {
          return;
        }
        const isLeaf = (!argument.inputs || argument.inputs.length === 0) &&
            (argument.outputs && argument.outputs.length === 1);
        if (!isLeaf) {
          return;
        }
        let prefix = this._name + '_';
        if (prefix.endsWith('_fwd_')) {
          // A node named e.g. 'conv0_fwd' matches arguments named 'conv0_...'.
          prefix = prefix.slice(0, -4);
        }
        if (argument.name && argument.name.startsWith(prefix)) {
          const id = '[' + input.join(',') + ']';
          this._initializers[id] = new MXNetTensor(argument);
          delete argumentMap[argumentNodeIndex];
        }
      });
    }

    get operator() {
      return this._operator;
    }

    get category() {
      return MXNetOperatorMetadata.operatorMetadata.getOperatorCategory(this._operator);
    }

    get documentation() {
      return MXNetOperatorMetadata.operatorMetadata.getOperatorDocumentation(this._operator);
    }

    get name() {
      return this._name;
    }

    /**
     * Inputs grouped by the operator schema. Connections fed by an initializer
     * carry that initializer and its type.
     */
    get inputs() {
      const inputs = this._inputs.map((input) => '[' + input.join(',') + ']');
      const results = MXNetOperatorMetadata.operatorMetadata.getInputs(this._operator, inputs);
      results.forEach((input) => {
        input.connections.forEach((connection) => {
          const initializer = this._initializers[connection.id];
          if (initializer) {
            connection.type = initializer.type;
            connection.initializer = initializer;
          }
        });
      });
      return results;
    }

    /** Outputs grouped by the operator schema (previously looked up with an undefined type). */
    get outputs() {
      const outputs = this._outputs.map((output) => '[' + output.join(',') + ']');
      return MXNetOperatorMetadata.operatorMetadata.getOutputs(this._operator, outputs);
    }

    get attributes() {
      return this._attributes;
    }
  }
  /**
   * A named operator parameter. Its visibility is delegated to the shared
   * operator metadata.
   */
  class MXNetAttribute {

    /**
     * @param {Object} owner - the node the attribute belongs to (its `operator` is read).
     * @param {string} name - parameter name.
     * @param {*} value - parameter value as stored in the symbol file.
     */
    constructor(owner, name, value) {
      Object.assign(this, { _owner: owner, _name: name, _value: value });
    }

    get name() { return this._name; }

    get value() { return this._value; }

    /** Whether the metadata says this attribute should be displayed. */
    get visible() {
      const metadata = MXNetOperatorMetadata.operatorMetadata;
      const operator = this._owner.operator;
      return metadata.getAttributeVisible(operator, this._name, this._value);
    }
  }
  /**
   * Initializer tensor backed by an argument node.
   *
   * The tensor type is built from the node's `__dtype__` / `__shape__`
   * attributes when both are present (e.g. '0' and '(64, 3, 7, 7)' give
   * 'float32[64,3,7,7]'); otherwise the type is ''.
   */
  class MXNetTensor {

    /** @param {Object} json - argument node record (`name`, optional `attrs`). */
    constructor(json) {
      this._json = json;
      this._type = '';
      const attrs = json.attrs;
      if (attrs) {
        const dtype = attrs.__dtype__;
        const shape = attrs.__shape__;
        if (dtype !== undefined && dtype !== null && shape) {
          this._type = MXNetTensor._formatDataType(dtype) + MXNetTensor._formatShape(shape);
        }
      }
    }

    /**
     * Maps an MXNet dtype code (mshadow TypeFlag) to a readable name.
     * Unknown codes are returned unchanged.
     */
    static _formatDataType(dtype) {
      const names = [ 'float32', 'float64', 'float16', 'uint8', 'int32', 'int8', 'int64' ];
      const text = String(dtype).trim();
      if (/^\d+$/.test(text)) {
        const code = Number(text);
        if (code < names.length) {
          return names[code];
        }
      }
      return text;
    }

    /** Turns a tuple string such as '(64, 3)' or '(3,)' into '[64,3]' / '[3]'. */
    static _formatShape(shape) {
      const dimensions = String(shape)
        .replace(/[()\[\]\s]/g, '')
        .split(',')
        .filter((dimension) => dimension.length > 0);
      return '[' + dimensions.join(',') + ']';
    }

    get name() {
      return this._json.name;
    }

    get kind() {
      return 'Initializer';
    }

    get type() {
      return this._type;
    }
  }
  /**
   * Registry of operator schemas read from '/mxnet-metadata.json'.
   *
   * The data is parsed as a JSON array of `{ name, schema }` records; records
   * missing either field are ignored. One shared instance is cached on
   * `MXNetOperatorMetadata.operatorMetadata`.
   */
  class MXNetOperatorMetadata {

    /**
     * Loads the shared metadata instance on first use and returns it.
     *
     * Metadata is best-effort: unparsable data produces an instance with no
     * schemas rather than an exception thrown from the request callback.
     *
     * @param {Object} host - provides `request(path, callback(err, data))`.
     * @param {function(?Error, MXNetOperatorMetadata)} callback - always
     *   called with a null error.
     */
    static open(host, callback) {
      if (MXNetOperatorMetadata.operatorMetadata) {
        callback(null, MXNetOperatorMetadata.operatorMetadata);
        return;
      }
      host.request('/mxnet-metadata.json', (err, data) => {
        let metadata;
        try {
          metadata = new MXNetOperatorMetadata(data);
        }
        catch (parseError) {
          // Malformed metadata must not keep the model from opening.
          metadata = new MXNetOperatorMetadata(null);
        }
        MXNetOperatorMetadata.operatorMetadata = metadata;
        callback(null, metadata);
      });
    }

    /**
     * @param {?string} data - JSON text; falsy data yields an empty registry.
     * @throws {SyntaxError} when `data` is not valid JSON.
     */
    constructor(data) {
      // Prototype-free map so operator names like 'constructor' do not
      // resolve to inherited Object members.
      this._map = Object.create(null);
      if (data) {
        const items = JSON.parse(data);
        if (Array.isArray(items)) {
          items.forEach((item) => {
            if (item && item.name && item.schema) {
              this._map[item.name] = item.schema;
            }
          });
        }
      }
    }

    /** @returns {?string} the schema category of `operator`, or null. */
    getOperatorCategory(operator) {
      const schema = this._map[operator];
      if (schema && schema.category) {
        return schema.category;
      }
      return null;
    }

    /**
     * Maps positional input ids onto the schema's declared inputs.
     *
     * - An 'optional' input is emitted only if ids remain, and empty ids are
     *   dropped from it.
     * - A 'variadic' input takes all remaining ids.
     * - Without a schema, inputs are named 'input', '(1)', '(2)', ...
     *
     * @param {string} type - operator name.
     * @param {string[]} inputs - connection ids in positional order.
     * @returns {Object[]} `{ name, type, connections: [{ id }] }` records.
     */
    getInputs(type, inputs) {
      const results = [];
      let index = 0;
      const schema = this._map[type];
      if (schema && schema.inputs) {
        schema.inputs.forEach((inputDef) => {
          if (index < inputs.length || inputDef.option !== 'optional') {
            const input = {};
            input.name = inputDef.name;
            input.type = inputDef.type;
            const count = (inputDef.option === 'variadic') ? (inputs.length - index) : 1;
            input.connections = [];
            inputs.slice(index, index + count).forEach((id) => {
              if (id !== '' || inputDef.option !== 'optional') {
                input.connections.push({ id: id });
              }
            });
            index += count;
            results.push(input);
          }
        });
      }
      else {
        inputs.slice(index).forEach((input) => {
          const name = (index === 0) ? 'input' : ('(' + index.toString() + ')');
          results.push({
            name: name,
            connections: [ { id: input } ]
          });
          index++;
        });
      }
      return results;
    }

    /**
     * Maps positional output ids onto the schema's declared outputs. The rules
     * match getInputs. The fallback names are 'output', '(1)', ...
     */
    getOutputs(type, outputs) {
      const results = [];
      let index = 0;
      const schema = this._map[type];
      if (schema && schema.outputs) {
        schema.outputs.forEach((outputDef) => {
          if (index < outputs.length || outputDef.option !== 'optional') {
            const output = {};
            output.name = outputDef.name;
            const count = (outputDef.option === 'variadic') ? (outputs.length - index) : 1;
            output.connections = outputs.slice(index, index + count).map((id) => {
              return { id: id };
            });
            index += count;
            results.push(output);
          }
        });
      }
      else {
        outputs.slice(index).forEach((output) => {
          const name = (index === 0) ? 'output' : ('(' + index.toString() + ')');
          results.push({
            name: name,
            connections: [ { id: output } ]
          });
          index++;
        });
      }
      return results;
    }

    /**
     * Decides whether an attribute should be displayed.
     *
     * - An explicit schema `visible` flag wins.
     * - Otherwise, an attribute equal to its schema `default` is hidden; tuple
     *   strings are normalised with formatTuple before comparing.
     * - Attributes unknown to the schema are visible.
     */
    getAttributeVisible(operator, name, value) {
      const schema = this._map[operator];
      if (schema && schema.attributes && schema.attributes.length > 0) {
        if (!schema.attributesMap) {
          // Built lazily and cached on the schema object.
          schema.attributesMap = Object.create(null);
          schema.attributes.forEach((attribute) => {
            schema.attributesMap[attribute.name] = attribute;
          });
        }
        const attribute = schema.attributesMap[name];
        if (attribute) {
          if (attribute.hasOwnProperty('visible')) {
            return attribute.visible;
          }
          if (attribute.hasOwnProperty('default')) {
            const normalized = MXNetOperatorMetadata.formatTuple(value);
            return !MXNetOperatorMetadata.isEquivalent(attribute.default, normalized);
          }
        }
      }
      return true;
    }

    /**
     * Normalises a tuple string: whitespace is removed, and a tuple whose
     * items are all equal collapses to '(x,)'. For example, '(3, 3)' becomes
     * '(3,)' and '(1, 2)' becomes '(1,2)'. Non-string or non-tuple values are
     * returned unchanged.
     */
    static formatTuple(value) {
      if (typeof value === 'string' && value.startsWith('(') && value.endsWith(')')) {
        let list = value.substring(1, value.length - 1).split(',');
        list = list.map((item) => item.trim());
        if (list.length > 1) {
          if (list.every((item) => item === list[0])) {
            list = [ list[0], '' ];
          }
        }
        return '(' + list.join(',') + ')';
      }
      return value;
    }

    /**
     * Deep structural equality.
     *
     * Treats NaN as equal to NaN and 0 as different from -0. Compares arrays
     * element-wise and plain objects by own enumerable keys.
     */
    static isEquivalent(a, b) {
      if (a === b) {
        return a !== 0 || 1 / a === 1 / b;
      }
      if (a == null || b == null) {
        return false;
      }
      if (a !== a) {
        return b !== b;
      }
      const type = typeof a;
      if (type !== 'function' && type !== 'object' && typeof b !== 'object') {
        return false;
      }
      const className = Object.prototype.toString.call(a);
      if (className !== Object.prototype.toString.call(b)) {
        return false;
      }
      switch (className) {
        case '[object RegExp]':
        case '[object String]':
          return '' + a === '' + b;
        case '[object Number]':
          if (+a !== +a) {
            return +b !== +b;
          }
          return +a === 0 ? 1 / +a === 1 / b : +a === +b;
        case '[object Date]':
        case '[object Boolean]':
          return +a === +b;
        case '[object Array]': {
          let length = a.length;
          if (length !== b.length) {
            return false;
          }
          while (length--) {
            if (!MXNetOperatorMetadata.isEquivalent(a[length], b[length])) {
              return false;
            }
          }
          return true;
        }
      }
      const keys = Object.keys(a);
      let size = keys.length;
      if (Object.keys(b).length !== size) {
        return false;
      }
      while (size--) {
        const key = keys[size];
        if (!(Object.prototype.hasOwnProperty.call(b, key) && MXNetOperatorMetadata.isEquivalent(a[key], b[key]))) {
          return false;
        }
      }
      return true;
    }

    /**
     * Renders HTML documentation for `operator`.
     *
     * A copy of the schema is rendered with the global `operatorTemplate`
     * Handlebars template; descriptions are converted with `marked`. Returns
     * '' when the operator has no schema.
     */
    getOperatorDocumentation(operator) {
      let schema = this._map[operator];
      if (schema) {
        schema = JSON.parse(JSON.stringify(schema));
        schema.name = operator;
        if (schema.description) {
          schema.description = marked(schema.description);
        }
        if (schema.attributes) {
          schema.attributes.forEach((attribute) => {
            if (attribute.description) {
              attribute.description = marked(attribute.description);
            }
          });
        }
        if (schema.inputs) {
          schema.inputs.forEach((input) => {
            if (input.description) {
              input.description = marked(input.description);
            }
          });
        }
        if (schema.outputs) {
          schema.outputs.forEach((output) => {
            if (output.description) {
              output.description = marked(output.description);
            }
          });
        }
        const template = Handlebars.compile(operatorTemplate, 'utf-8');
        return template(schema);
      }
      return '';
    }
  }
  /**
   * Error type for MXNet files that cannot be loaded. It carries a fixed,
   * human-readable `name`.
   */
  class MXNetError extends Error {
    /** @param {string} message - description of what went wrong. */
    constructor(message) {
      super(message);
      Object.assign(this, { name: 'Error loading MXNet model.' });
    }
  }