// mslite.js (19 KB) — MindSpore Lite model loader
  /* jshint esversion: 6 */
  // Namespace object for the MindSpore Lite loader. An already-defined (truthy) `mslite`
  // is reused; otherwise a fresh empty object is created.
  var mslite = mslite || {};
  // FlatBuffers runtime: reuse an existing (truthy) `flatbuffers` binding, otherwise load
  // './flatbuffers' via `require` (assumes a CommonJS environment in that case — confirm).
  var flatbuffers = flatbuffers || require('./flatbuffers');
  mslite.ModelFactory = class {

      /**
       * Cheap format sniff: reads the first 8 bytes of the stream and accepts the
       * FlatBuffers identifiers '' (MSL0), 'MSL1' and 'MSL2'.
       * @param {object} context - Supplies `stream` with `length` and `peek(count)`.
       * @returns {boolean} true when the identifier is one of the recognised values.
       */
      match(context) {
          const stream = context.stream;
          if (stream.length < 8) {
              return false;
          }
          const identifier = flatbuffers.BinaryReader.open(stream.peek(8)).identifier;
          return identifier === '' || identifier === 'MSL1' || identifier === 'MSL2';
      }

      /**
       * Loads the schema, decodes the MetaGraph and wraps it in an `mslite.Model`.
       * MSL0 ('') and MSL1 files are rejected as deprecated; decode failures are
       * rethrown as `mslite.Error` with the underlying message appended.
       * @param {object} context - Supplies `stream`, `require(id)` and whatever
       *     `mslite.Metadata.open` needs.
       * @returns {Promise<mslite.Model>}
       */
      open(context) {
          return context.require('./mslite-schema').then(() => {
              const reader = flatbuffers.BinaryReader.open(context.stream);
              const identifier = reader.identifier;
              if (identifier === '') {
                  throw new mslite.Error('MSL0 format is deprecated.', false);
              }
              if (identifier === 'MSL1') {
                  throw new mslite.Error('MSL1 format is deprecated.', false);
              }
              let metaGraph = null;
              try {
                  // Publish the generated schema namespace; mslite.Utility.enum reads it later.
                  mslite.schema = flatbuffers.get('mslite').mindspore.schema;
                  metaGraph = mslite.schema.MetaGraph.create(reader);
              }
              catch (error) {
                  const detail = (error && error.message) ? error.message : error.toString();
                  throw new mslite.Error('File format is not mslite.MetaGraph (' + detail.replace(/\.$/, '') + ').');
              }
              return mslite.Metadata.open(context).then((metadata) => new mslite.Model(metadata, metaGraph));
          });
      }
  };
  mslite.Model = class {

      /**
       * Top-level model view.
       * @param {mslite.Metadata} metadata - Operator metadata forwarded to each graph.
       * @param {object} model - Decoded MetaGraph (`name`, `version`, optional `subGraph` array).
       */
      constructor(metadata, model) {
          const prefix = 'MindSpore Lite ';
          const version = model.version || '';
          this._name = model.name || '';
          // Normalise 'MindSpore Lite 1.x' and 'MindSpore Lite v1.x' to the 'v' form;
          // any other version string is kept verbatim.
          this._format = version.startsWith(prefix) ?
              prefix + 'v' + version.substring(prefix.length).replace(/^v/, '') :
              version;
          const subgraphs = model.subGraph;
          // Without a subGraph array the MetaGraph itself is treated as the only graph.
          this._graphs = Array.isArray(subgraphs) ?
              subgraphs.map((subgraph) => new mslite.Graph(metadata, subgraph, model)) :
              [ new mslite.Graph(metadata, model, model) ];
      }

      get name() {
          return this._name;
      }

      get format() {
          return this._format;
      }

      get graphs() {
          return this._graphs;
      }
  };
  mslite.Graph = class {

      /**
       * Builds one graph view.
       * When `subgraph === model` the MetaGraph's own `inputIndex`/`outputIndex`/`nodes`
       * are used; otherwise the subgraph's `inputIndices`/`outputIndices`/`nodeIndices`
       * select entries from the model.
       * @param {mslite.Metadata} metadata - Forwarded to each `mslite.Node`.
       * @param {object} subgraph - SubGraph record, or the model itself.
       * @param {object} model - MetaGraph providing `allTensors` and `nodes`.
       */
      constructor(metadata, subgraph, model) {
          this._name = subgraph.name || '';
          // One Argument per tensor; tensors with non-empty data get an initializer.
          const args = model.allTensors.map((tensor, position) => {
              const name = tensor.name || position.toString();
              const bytes = tensor.data;
              const type = new mslite.TensorType(tensor.dataType, tensor.dims);
              const initializer = (bytes && bytes.length > 0) ? new mslite.Tensor(type, bytes) : null;
              return new mslite.Argument(name, tensor, initializer);
          });
          const root = subgraph === model;
          // Graph inputs/outputs are named by their position: '0', '1', ...
          const bind = (tensorIndex, position) => new mslite.Parameter(position.toString(), true, [ args[tensorIndex] ]);
          this._inputs = Array.from(root ? subgraph.inputIndex : subgraph.inputIndices, bind);
          this._outputs = Array.from(root ? subgraph.outputIndex : subgraph.outputIndices, bind);
          this._nodes = root ?
              Array.from(subgraph.nodes, (node) => new mslite.Node(metadata, node, args)) :
              Array.from(subgraph.nodeIndices, (nodeId) => new mslite.Node(metadata, model.nodes[nodeId], args));
      }

      get name() {
          return this._name;
      }

      get groups() {
          return false;
      }

      get inputs() {
          return this._inputs;
      }

      get outputs() {
          return this._outputs;
      }

      get nodes() {
          return this._nodes;
      }
  };
  mslite.Node = class {

      /**
       * Node view for one operator.
       * The operator type is the constructor name of `op.primitive.value`. Metadata for
       * that type supplies input/output slot names and attribute schemas. When metadata
       * has no entry, the type still reports its name instead of becoming null.
       * @param {mslite.Metadata} metadata - Provides `type(name)` and `attribute(type, name)`.
       * @param {object} op - Node record with `name`, `primitive`, `inputIndex`, `outputIndex`.
       * @param {Array<mslite.Argument>} args - Tensor arguments indexed by tensor id.
       */
      constructor(metadata, op, args) {
          this._name = op.name || '';
          this._type = { name: '?' };
          this._attributes = [];
          // A node without a primitive keeps the '?' type instead of throwing.
          const data = op.primitive ? op.primitive.value : null;
          if (data && data.constructor) {
              const type = data.constructor.name;
              // Fall back to a bare type record so the op name survives missing metadata.
              this._type = metadata.type(type) || { name: type };
              this._attributes = Object.keys(data).map((key) => new mslite.Attribute(metadata.attribute(type, key), key.toString(), data[key]));
          }
          this._inputs = mslite.Node._parameters(this._type.inputs, op.inputIndex, args);
          this._outputs = mslite.Node._parameters(this._type.outputs, op.outputIndex, args);
      }

      /**
       * Binds tensor indices to parameter slots.
       * The first slots take their names from `schemas` (metadata inputs/outputs). Any
       * extra indices are named by position ('2', '3', ...).
       * @param {Array<object>|undefined} schemas - Metadata slot descriptions with `name`.
       * @param {ArrayLike<number>} indices - Tensor indices referenced by the node.
       * @param {Array<mslite.Argument>} args - Tensor arguments indexed by tensor id.
       * @returns {Array<mslite.Parameter>}
       */
      static _parameters(schemas, indices, args) {
          const parameters = [];
          const named = schemas ? Math.min(schemas.length, indices.length) : 0;
          for (let i = 0; i < indices.length; i++) {
              const name = i < named ? schemas[i].name : i.toString();
              parameters.push(new mslite.Parameter(name, true, [ args[indices[i]] ]));
          }
          return parameters;
      }

      get name() {
          return this._name;
      }

      get type() {
          return this._type;
      }

      get inputs() {
          return this._inputs;
      }

      get outputs() {
          return this._outputs;
      }

      get attributes() {
          return this._attributes;
      }
  };
  mslite.Attribute = class {

      /**
       * Attribute view for one primitive field.
       * Typed-array values are copied into plain arrays. When the schema declares a
       * `type`, the value is mapped to its enum key via `mslite.Utility.enum`.
       * @param {object|null} schema - Attribute metadata (may be null).
       * @param {string} attrName - Field name.
       * @param {*} value - Raw field value.
       */
      constructor(schema, attrName, value) {
          this._type = null;
          this._name = attrName;
          // Attributes are created hidden; `visible` reports false for them.
          this._visible = false;
          this._value = ArrayBuffer.isView(value) ? Array.from(value) : value;
          const declaredType = schema && schema.type;
          if (declaredType) {
              this._type = declaredType;
              this._value = mslite.Utility.enum(declaredType, this._value);
          }
      }

      get name() {
          return this._name;
      }

      get type() {
          return this._type;
      }

      get value() {
          return this._value;
      }

      get visible() {
          return this._visible !== false;
      }
  };
  mslite.Parameter = class {

      /**
       * Named slot (graph or node input/output) holding a list of arguments.
       * @param {string} name - Slot name.
       * @param {boolean} visible - Returned unchanged by `visible`.
       * @param {Array<mslite.Argument>} args - Arguments bound to the slot (kept by reference).
       */
      constructor(name, visible, args) {
          this._name = name;
          this._visible = visible;
          this._arguments = args;
      }

      get arguments() {
          return this._arguments;
      }

      get visible() {
          return this._visible;
      }

      get name() {
          return this._name;
      }
  };
  mslite.Argument = class {

      /**
       * Tensor argument.
       * Constant tensors carry an initializer whose type is reported; other tensors get
       * a TensorType built from `dataType`/`dims`. When `quantParams` is present,
       * non-trivial entries are rendered as 'scale * x + zeroPoint' and joined with ' -> '.
       * @param {string} name - Display name.
       * @param {object} tensor - Tensor record (`dataType`, `dims`, optional `quantParams`).
       * @param {mslite.Tensor|null} initializer - Constant data, if any.
       */
      constructor(name, tensor, initializer) {
          this._name = name;
          this._type = initializer ? null : new mslite.TensorType(tensor.dataType, tensor.dims);
          this._initializer = initializer || null;
          const quantParams = tensor.quantParams;
          if (quantParams) {
              // Entries where both scale and zeroPoint are 0 are skipped.
              this._quantization = Array.from(quantParams)
                  .filter((param) => param.scale !== 0 || param.zeroPoint !== 0)
                  .map((param) => param.scale.toString() + ' * x + ' + param.zeroPoint.toString())
                  .join(' -> ');
          }
      }

      get name() {
          return this._name;
      }

      get type() {
          return this._initializer ? this._initializer.type : this._type;
      }

      get initializer() {
          return this._initializer;
      }

      get quantization() {
          return this._quantization;
      }
  };
  mslite.Tensor = class {

      /**
       * Constant tensor data.
       * @param {object} type - Supplies `dataType` and `shape.dimensions`.
       * @param {ArrayBufferView} data - Raw bytes (a typed-array view, presumably
       *     Uint8Array — confirm). Numeric elements are read little-endian.
       */
      constructor(type, data) {
          this._type = type;
          this._data = data || null;
      }

      get type() {
          return this._type;
      }

      /** Reason the data cannot be decoded, or null when decoding is possible. */
      get state() {
          return this._context().state;
      }

      /** Fully decoded nested arrays (scalar for rank 0), or null if `state` is set. */
      get value() {
          const context = this._context();
          if (context.state) {
              return null;
          }
          context.limit = Number.MAX_SAFE_INTEGER;
          return this._decode(context, 0);
      }

      /** JSON rendering truncated after roughly 10000 elements; '' if `state` is set. */
      toString() {
          const context = this._context();
          if (context.state) {
              return '';
          }
          context.limit = 10000;
          const value = this._decode(context, 0);
          return JSON.stringify(value, null, 4);
      }

      _context() {
          const context = {};
          context.state = null;
          context.index = 0;
          context.count = 0;
          if (this._data == null || this._data.length === 0) {
              context.state = 'Tensor data is empty.';
              return context;
          }
          const dataType = this._type.dataType;
          switch (dataType) {
              case 'uint8':
              case 'int8':
              case 'int16':
              case 'uint16':
              case 'int32':
              case 'uint32':
              case 'int64':
              case 'float16':
              case 'float32':
              case 'float64':
              case 'string':
                  break;
              default:
                  // Report instead of silently producing empty nested arrays.
                  context.state = "Tensor data type '" + dataType + "' is not supported.";
                  return context;
          }
          context.dataType = dataType;
          context.shape = this._type.shape.dimensions;
          context.data = new DataView(this._data.buffer, this._data.byteOffset, this._data.byteLength);
          if (dataType === 'string') {
              // Layout: int32 count, then `count` int32 start offsets, then UTF-8 bytes;
              // each string ends where the next begins (the last at the end of data).
              let offset = 0;
              const count = context.data.getInt32(0, true);
              offset += 4;
              const offsetTable = [];
              for (let j = 0; j < count; j++) {
                  offsetTable.push(context.data.getInt32(offset, true));
                  offset += 4;
              }
              offsetTable.push(this._data.length);
              const stringTable = [];
              const utf8Decoder = new TextDecoder('utf-8');
              for (let k = 0; k < count; k++) {
                  const textArray = this._data.subarray(offsetTable[k], offsetTable[k + 1]);
                  stringTable.push(utf8Decoder.decode(textArray));
              }
              context.data = stringTable;
          }
          return context;
      }

      _decode(context, dimension) {
          // Rank-0 tensors are decoded as a single element and unwrapped at the end.
          const shape = (context.shape.length === 0) ? [ 1 ] : context.shape;
          const size = shape[dimension];
          const results = [];
          if (dimension === shape.length - 1) {
              for (let i = 0; i < size; i++) {
                  if (context.count > context.limit) {
                      results.push('...');
                      return results;
                  }
                  switch (context.dataType) {
                      case 'uint8':
                          results.push(context.data.getUint8(context.index));
                          context.index += 1;
                          context.count++;
                          break;
                      case 'int8':
                          results.push(context.data.getInt8(context.index));
                          context.index += 1;
                          context.count++;
                          break;
                      case 'int16':
                          // Little-endian like every other numeric type.
                          results.push(context.data.getInt16(context.index, true));
                          context.index += 2;
                          context.count++;
                          break;
                      case 'uint16':
                          results.push(context.data.getUint16(context.index, true));
                          context.index += 2;
                          context.count++;
                          break;
                      case 'int32':
                          results.push(context.data.getInt32(context.index, true));
                          context.index += 4;
                          context.count++;
                          break;
                      case 'uint32':
                          results.push(context.data.getUint32(context.index, true));
                          context.index += 4;
                          context.count++;
                          break;
                      case 'int64':
                          // NOTE(review): getInt64 is not a standard DataView method; assumed
                          // to be provided by a project polyfill — confirm.
                          results.push(context.data.getInt64(context.index, true));
                          context.index += 8;
                          context.count++;
                          break;
                      case 'float16':
                          // NOTE(review): getFloat16 is missing on older runtimes; assumed to be
                          // provided by a project polyfill — confirm.
                          results.push(context.data.getFloat16(context.index, true));
                          context.index += 2;
                          context.count++;
                          break;
                      case 'float32':
                          results.push(context.data.getFloat32(context.index, true));
                          context.index += 4;
                          context.count++;
                          break;
                      case 'float64':
                          results.push(context.data.getFloat64(context.index, true));
                          context.index += 8;
                          context.count++;
                          break;
                      case 'string':
                          results.push(context.data[context.index++]);
                          context.count++;
                          break;
                      default:
                          break;
                  }
              }
          }
          else {
              for (let j = 0; j < size; j++) {
                  if (context.count > context.limit) {
                      results.push('...');
                      return results;
                  }
                  results.push(this._decode(context, dimension + 1));
              }
          }
          if (context.shape.length === 0) {
              return results[0];
          }
          return results;
      }
  };
  mslite.TensorType = (() => {

      // Numeric data type code -> display name (presumably MindSpore TypeId values —
      // confirm). Codes not listed here (e.g. 10, 29) are rejected by the constructor.
      // Map lookups compare keys without coercion, so a string '43' is not accepted.
      const dataTypes = new Map([
          [ 0, "?" ], [ 1, "type" ], [ 2, "any" ], [ 3, "object" ], [ 4, "typetype" ],
          [ 5, "problem" ], [ 6, "external" ], [ 7, "none" ], [ 8, "null" ], [ 9, "ellipsis" ],
          [ 11, "number" ], [ 12, "string" ], [ 13, "list" ], [ 14, "tuple" ], [ 15, "slice" ],
          [ 16, "keyword" ], [ 17, "tensortype" ], [ 18, "rowtensortype" ], [ 19, "sparsetensortype" ],
          [ 20, "undeterminedtype" ], [ 21, "class" ], [ 22, "dictionary" ], [ 23, "function" ],
          [ 24, "jtagged" ], [ 25, "symbolickeytype" ], [ 26, "envtype" ], [ 27, "refkey" ], [ 28, "ref" ],
          [ 30, "boolean" ], [ 31, "int" ], [ 32, "int8" ], [ 33, "int16" ], [ 34, "int32" ], [ 35, "int64" ],
          [ 36, "uint" ], [ 37, "uint8" ], [ 38, "uint16" ], [ 39, "uint32" ], [ 40, "uint64" ],
          [ 41, "float" ], [ 42, "float16" ], [ 43, "float32" ], [ 44, "float64" ], [ 45, "complex64" ]
      ]);

      return class {

          /**
           * @param {number} dataType - Numeric data type code.
           * @param {ArrayLike<number>} dimensions - Shape; copied into a plain array.
           * @throws {mslite.Error} When the code is not in the table.
           */
          constructor(dataType, dimensions) {
              const name = dataTypes.get(dataType);
              if (name === undefined) {
                  throw new mslite.Error("Unknown data type '" + dataType.toString() + "'.");
              }
              this._dataType = name;
              this._shape = new mslite.TensorShape(Array.from(dimensions));
          }

          get dataType() {
              return this._dataType;
          }

          get shape() {
              return this._shape;
          }

          toString() {
              return this.dataType + this._shape.toString();
          }
      };
  })();
  mslite.TensorShape = class {

      /**
       * @param {Array<number>} dimensions - Dimension sizes (kept by reference).
       */
      constructor(dimensions) {
          this._dimensions = dimensions;
      }

      get dimensions() {
          return this._dimensions;
      }

      /**
       * Formats the shape as '[d0,d1,...]'. Falsy dimensions (0 included) print as '?'.
       * A missing or empty dimension list yields ''.
       */
      toString() {
          const dims = this._dimensions;
          if (!(dims && dims.length > 0)) {
              return '';
          }
          const parts = dims.map((dimension) => (dimension ? dimension.toString() : '?'));
          return '[' + parts.join(',') + ']';
      }
  };
  mslite.Metadata = class {

      /**
       * Loads 'mslite-metadata.json' once and caches the instance on
       * `mslite.Metadata._metadata`. A failed request or parse falls back to an empty
       * metadata instance (deliberate best-effort: lookups then return null).
       * @param {object} context - Supplies `request(...)` returning a Promise of the file text.
       * @returns {Promise<mslite.Metadata>}
       */
      static open(context) {
          if (mslite.Metadata._metadata) {
              return Promise.resolve(mslite.Metadata._metadata);
          }
          return context.request('mslite-metadata.json', 'utf-8', null).then((data) => {
              mslite.Metadata._metadata = new mslite.Metadata(data);
              return mslite.Metadata._metadata;
          }).catch(() => {
              mslite.Metadata._metadata = new mslite.Metadata(null);
              return mslite.Metadata._metadata;
          });
      }

      /**
       * @param {string|null} data - JSON text: an array of type records keyed by `name`.
       */
      constructor(data) {
          this._map = new Map();
          if (data) {
              const metadata = JSON.parse(data);
              this._map = new Map(metadata.map((item) => [ item.name, item ]));
          }
      }

      /** Type record for `name`, or null. */
      type(name) {
          return this._map.has(name) ? this._map.get(name) : null;
      }

      /**
       * Attribute schema `name` of type `type`, or null.
       * The per-type lookup is a Map built once from `schema.attributes` and cached on the
       * schema object. A Map is used (rather than a plain object) so names such as
       * 'constructor' or 'toString' do not resolve to inherited Object.prototype members.
       */
      attribute(type, name) {
          const schema = this.type(type);
          if (!schema) {
              return null;
          }
          if (!(schema.attributeMap instanceof Map)) {
              const attributeMap = new Map();
              if (schema.attributes) {
                  for (const attribute of schema.attributes) {
                      attributeMap.set(attribute.name, attribute);
                  }
              }
              schema.attributeMap = attributeMap;
          }
          return schema.attributeMap.get(name) || null;
      }
  };
  mslite.Utility = class {

      /**
       * Maps an enum value to its key name using the enum object `mslite.schema[name]`.
       * Reverse lookups are cached per enum name in `mslite.Utility._enumKeyMap`.
       * When the enum or the value is unknown, the value is returned unchanged.
       * @param {string} name - Enum name within `mslite.schema`.
       * @param {*} value - Raw enum value.
       * @returns {*} The key name, or `value` itself.
       */
      static enum(name, value) {
          const type = name && mslite.schema ? mslite.schema[name] : undefined;
          if (!type) {
              return value;
          }
          const cache = mslite.Utility._enumKeyMap || (mslite.Utility._enumKeyMap = new Map());
          let reverse = cache.get(name);
          if (!reverse) {
              // When several keys share a value, the last key wins.
              reverse = new Map(Object.keys(type).map((key) => [ type[key], key ]));
              cache.set(name, reverse);
          }
          return reverse.has(value) ? reverse.get(value) : value;
      }
  };
  mslite.Error = class extends Error {

      /**
       * Loader error.
       * @param {string} message - Error message.
       * @param {boolean} [context] - Only an explicit `false` clears `context`;
       *     any other value (including omission) leaves it true.
       */
      constructor(message, context) {
          super(message);
          this.name = 'Error loading MindSpore Lite model.';
          this.context = context !== false;
      }
  };
  // CommonJS export of the model factory. The guard skips the export when `module` is
  // undefined or its `exports` is not an object.
  if (typeof module !== 'undefined' && typeof module.exports === 'object') {
      module.exports.ModelFactory = mslite.ModelFactory;
  }