// mslite.js 18 KB
  // Namespace object that every class below is attached to; a value that is
  // already bound to `mslite` is reused instead of being replaced.
  var mslite;
  if (!mslite) {
      mslite = {};
  }
  // FlatBuffers helpers: keep an existing binding when one is set, otherwise
  // load the sibling './flatbuffers' module.
  var flatbuffers;
  if (!flatbuffers) {
      flatbuffers = require('./flatbuffers');
  }
  /**
   * Recognizes MindSpore Lite model files and opens them as mslite.Model objects.
   */
  mslite.ModelFactory = class {

      /**
       * Tests whether the context stream looks like a MindSpore Lite model.
       * Only the first 8 bytes are peeked; the file identifier reported by the
       * FlatBuffers reader has to be '', 'MSL1' or 'MSL2'.
       * @param {object} context - object exposing `stream` (with `length` and `peek`).
       * @returns {string} 'mslite' on a match, '' otherwise.
       */
      match(context) {
          const stream = context.stream;
          if (stream && stream.length >= 8) {
              const buffer = stream.peek(8);
              const reader = flatbuffers.BinaryReader.open(buffer);
              if (reader.identifier === '' || reader.identifier === 'MSL1' || reader.identifier === 'MSL2') {
                  return 'mslite';
              }
          }
          return '';
      }

      /**
       * Loads './mslite-schema', decodes the stream as a MetaGraph and wraps it,
       * together with the 'mslite-metadata.json' metadata, in an mslite.Model.
       * @param {object} context - object exposing `require`, `stream` and `metadata`.
       * @returns {Promise} resolves to the mslite.Model; rejects with mslite.Error
       *   for deprecated ('' and 'MSL1') or unknown identifiers and when the
       *   MetaGraph cannot be created.
       */
      open(context) {
          return context.require('./mslite-schema').then(() => {
              const stream = context.stream;
              const reader = flatbuffers.BinaryReader.open(stream);
              switch (reader.identifier) {
                  case '':
                      throw new mslite.Error('MSL0 format is deprecated.', false);
                  case 'MSL1':
                      throw new mslite.Error('MSL1 format is deprecated.', false);
                  case 'MSL2':
                      break;
                  default:
                      throw new mslite.Error("Unsupported file identifier '" + reader.identifier + "'.");
              }
              let model = null;
              try {
                  mslite.schema = flatbuffers.get('mslite').mindspore.schema;
                  model = mslite.schema.MetaGraph.create(reader);
              }
              catch (error) {
                  // String() also copes with a thrown null/undefined (where calling
                  // error.toString() would raise a TypeError and hide the real failure)
                  // and with a non-string `message` property.
                  const message = String(error && error.message ? error.message : error);
                  throw new mslite.Error('File format is not mslite.MetaGraph (' + message.replace(/\.$/, '') + ').');
              }
              return context.metadata('mslite-metadata.json').then((metadata) => {
                  return new mslite.Model(metadata, model);
              });
          });
      }
  };
  /**
   * Top-level model object: a name, a format string and the list of graphs
   * built from a decoded MetaGraph.
   */
  mslite.Model = class {

      /**
       * @param {object} metadata - operator metadata, forwarded to every mslite.Graph.
       * @param {object} model - decoded MetaGraph; `name`, `version` and `subGraph` are read here.
       */
      constructor(metadata, model) {
          this._name = model.name || '';
          this._format = model.version || '';
          this._graphs = [];
          // Normalize 'MindSpore Lite 1.0' and 'MindSpore Lite v1.0' to the same 'v'-prefixed form.
          const prefix = 'MindSpore Lite ';
          if (this._format.startsWith(prefix)) {
              const version = this._format.substring(prefix.length).replace(/^v/, '');
              this._format = prefix + 'v' + version;
          }
          const subgraphs = model.subGraph;
          if (Array.isArray(subgraphs) && subgraphs.length > 0) {
              for (const subgraph of subgraphs) {
                  this._graphs.push(new mslite.Graph(metadata, subgraph, model));
              }
          }
          else {
              // No sub-graph entries (missing or empty list): present the model itself as
              // the single graph. Previously an empty list produced a model without graphs.
              this._graphs.push(new mslite.Graph(metadata, model, model));
          }
      }

      /** @returns {string} model name, '' when the MetaGraph has none. */
      get name() {
          return this._name;
      }

      /** @returns {string} normalized version string, '' when the MetaGraph has none. */
      get format() {
          return this._format;
      }

      /** @returns {Array} the mslite.Graph instances of this model. */
      get graphs() {
          return this._graphs;
      }
  };
  /**
   * One graph of a model: named inputs, outputs and nodes that all refer to
   * arguments built from the model-wide tensor list.
   */
  mslite.Graph = class {

      /**
       * @param {object} metadata - operator metadata, forwarded to every mslite.Node.
       * @param {object} subgraph - sub-graph table, or `model` itself when the whole
       *   model is presented as a single graph.
       * @param {object} model - decoded MetaGraph owning `allTensors` (and `nodes`).
       */
      constructor(metadata, subgraph, model) {
          this._name = subgraph.name || '';
          // One argument per tensor; tensors that carry data also get an initializer.
          const values = model.allTensors.map((tensor, position) => {
              const label = tensor.name || position.toString();
              const type = new mslite.TensorType(tensor.dataType, tensor.dims);
              const hasData = tensor.data && tensor.data.length > 0;
              const initializer = hasData ? new mslite.Tensor(type, tensor.data) : null;
              return new mslite.Argument(label, tensor, initializer);
          });
          // Graph inputs and outputs are named after their position in the index list.
          const asParameters = (indices) => Array.from(indices, (tensorIndex, position) => {
              return new mslite.Parameter(position.toString(), true, [ values[tensorIndex] ]);
          });
          const wholeModel = subgraph === model;
          // The whole-model form uses inputIndex/outputIndex/nodes, a sub-graph uses
          // inputIndices/outputIndices and refers to model.nodes through nodeIndices.
          this._inputs = asParameters(wholeModel ? subgraph.inputIndex : subgraph.inputIndices);
          this._outputs = asParameters(wholeModel ? subgraph.outputIndex : subgraph.outputIndices);
          const operators = wholeModel ?
              Array.from(subgraph.nodes) :
              Array.from(subgraph.nodeIndices, (nodeIndex) => model.nodes[nodeIndex]);
          this._nodes = operators.map((operator) => new mslite.Node(metadata, operator, values));
      }

      /** @returns {string} graph name, '' when unnamed. */
      get name() {
          return this._name;
      }

      /** @returns {Array} graph input parameters. */
      get inputs() {
          return this._inputs;
      }

      /** @returns {Array} graph output parameters. */
      get outputs() {
          return this._outputs;
      }

      /** @returns {Array} the mslite.Node instances of this graph. */
      get nodes() {
          return this._nodes;
      }
  };
  /**
   * One operator of a graph: type, attributes and input/output parameters.
   */
  mslite.Node = class {

      /**
       * @param {object} metadata - provides `type(name)` and `attribute(type, key)`.
       * @param {object} op - node table; `name`, `primitive`, `inputIndex` and
       *   `outputIndex` are read here.
       * @param {Array} args - arguments addressed by the node's tensor indices.
       */
      constructor(metadata, op, args) {
          this._name = op.name || '';
          this._type = { name: '?' };
          this._attributes = [];
          // A node without a primitive keeps the '?' placeholder type instead of throwing.
          const data = op.primitive ? op.primitive.value : null;
          if (data && data.constructor) {
              const type = data.constructor.name;
              // Keep the operator name when the metadata has no entry for this type.
              this._type = metadata.type(type) || { name: type };
              this._attributes = Object.keys(data).map((key) => new mslite.Attribute(metadata.attribute(type, key), key, data[key]));
          }
          this._inputs = mslite.Node._parameters(this._type.inputs, op.inputIndex, args);
          this._outputs = mslite.Node._parameters(this._type.outputs, op.outputIndex, args);
      }

      /**
       * Builds one visible parameter per tensor index. A parameter takes its name
       * from the definition at the same position when there is one, otherwise
       * from its position.
       * @param {Array|undefined} definitions - `inputs` or `outputs` list of the operator type.
       * @param {ArrayLike} indices - tensor indices of the node.
       * @param {Array} args - arguments addressed by those indices.
       * @returns {Array} the mslite.Parameter list.
       */
      static _parameters(definitions, indices, args) {
          const parameters = [];
          for (let i = 0; i < indices.length; i++) {
              const definition = definitions && i < definitions.length ? definitions[i] : null;
              const name = definition ? definition.name : i.toString();
              parameters.push(new mslite.Parameter(name, true, [ args[indices[i]] ]));
          }
          return parameters;
      }

      /** @returns {string} node name, '' when unnamed. */
      get name() {
          return this._name;
      }

      /** @returns {object} operator type description (at least a `name`). */
      get type() {
          return this._type;
      }

      /** @returns {Array} input parameters. */
      get inputs() {
          return this._inputs;
      }

      /** @returns {Array} output parameters. */
      get outputs() {
          return this._outputs;
      }

      /** @returns {Array} the mslite.Attribute instances of this node. */
      get attributes() {
          return this._attributes;
      }
  };
  /**
   * A named operator attribute with an optional schema type.
   */
  mslite.Attribute = class {

      /**
       * @param {object} schema - attribute metadata; only `type` is read. May be missing.
       * @param {string} attrName - attribute name.
       * @param {*} value - raw value; buffer views are copied into plain arrays.
       */
      constructor(schema, attrName, value) {
          this._type = (schema && schema.type) || null;
          this._name = attrName;
          // Fixed to false, so the `visible` getter below always reports false.
          this._visible = false;
          const plain = ArrayBuffer.isView(value) ? Array.from(value) : value;
          // With a schema type the value goes through the enum lookup, which hands
          // back the matching key name when it finds one.
          this._value = this._type ? mslite.Utility.enum(this._type, plain) : plain;
      }

      /** @returns {string} attribute name. */
      get name() {
          return this._name;
      }

      /** @returns {*} schema type, or null when the schema provides none. */
      get type() {
          return this._type;
      }

      /** @returns {*} attribute value after the conversions done by the constructor. */
      get value() {
          return this._value;
      }

      /** @returns {boolean} false unless `_visible` is something other than false. */
      get visible() {
          return this._visible !== false;
      }
  };
  /**
   * A named input or output slot holding a list of arguments.
   */
  mslite.Parameter = class {

      /**
       * @param {string} name - slot name.
       * @param {boolean} visible - visibility flag, stored as given.
       * @param {Array} args - arguments bound to this slot.
       */
      constructor(name, visible, args) {
          Object.assign(this, { _name: name, _visible: visible, _arguments: args });
      }

      /** @returns {string} slot name. */
      get name() {
          return this._name;
      }

      /** @returns {boolean} visibility flag passed to the constructor. */
      get visible() {
          return this._visible;
      }

      /** @returns {Array} arguments bound to this slot. */
      get arguments() {
          return this._arguments;
      }
  };
  /**
   * A tensor value referenced by parameters: name, type, optional initializer
   * and an optional quantization description.
   */
  mslite.Argument = class {

      /**
       * @param {string} name - argument name.
       * @param {object} tensor - tensor table; `dataType`, `dims` and `quantParams` are read.
       * @param {object|null} initializer - tensor holding the data, when there is one.
       */
      constructor(name, tensor, initializer) {
          this._name = name;
          // With an initializer the type is taken from it (see the `type` getter).
          this._type = initializer ? null : new mslite.TensorType(tensor.dataType, tensor.dims);
          this._initializer = initializer || null;
          if (!tensor.quantParams) {
              return;
          }
          // One '<scale> * q + <zeroPoint>' term per entry; entries with scale 0 and
          // zero point 0 are skipped, a scale of 1 and a zero point of 0 are left out.
          const terms = Array.from(tensor.quantParams)
              .filter((param) => param.scale !== 0 || param.zeroPoint !== 0)
              .map((param) => {
                  const factor = param.scale !== 1 ? param.scale.toString() + ' * ' : '';
                  const offset = param.zeroPoint !== 0 ? ' + ' + param.zeroPoint.toString() : '';
                  return factor + 'q' + offset;
              });
          // Terms that are all the bare 'q' carry no information and are not stored.
          if (terms.some((term) => term !== 'q')) {
              this._quantization = terms.length === 1 ? terms[0] : terms;
          }
      }

      /** @returns {string} argument name. */
      get name() {
          return this._name;
      }

      /** @returns {object} the initializer's type when present, otherwise the own type. */
      get type() {
          return this._initializer ? this._initializer.type : this._type;
      }

      /** @returns {object|null} initializer tensor, or null. */
      get initializer() {
          return this._initializer;
      }

      /** @returns {string|Array|undefined} quantization term(s); undefined when none was stored. */
      get quantization() {
          return this._quantization;
      }
  };
  /**
   * Constant tensor data with lazy decoding into nested JavaScript arrays.
   */
  mslite.Tensor = class {

      /**
       * @param {object} type - tensor type exposing `dataType` and `shape.dimensions`.
       * @param {Uint8Array} data - raw tensor bytes; null/undefined is stored as null.
       */
      constructor(type, data) {
          this._type = type;
          this._data = data || null;
      }

      /** @returns {object} the tensor type passed to the constructor. */
      get type() {
          return this._type;
      }

      /** @returns {string|null} reason the data cannot be decoded, or null. */
      get state() {
          return this._context().state;
      }

      /** @returns {*} decoded values without an element limit, or null when `state` is set. */
      get value() {
          const context = this._context();
          if (context.state) {
              return null;
          }
          context.limit = Number.MAX_SAFE_INTEGER;
          return this._decode(context, 0);
      }

      /** @returns {string} JSON text of the values (truncated with '...' past the limit), '' when `state` is set. */
      toString() {
          const context = this._context();
          if (context.state) {
              return '';
          }
          context.limit = 10000;
          const value = this._decode(context, 0);
          return JSON.stringify(value, null, 4);
      }

      /**
       * Prepares a fresh decoding context: read position, element counter, data
       * type, shape and a DataView over the bytes (or the decoded string table
       * for 'string' tensors).
       * @returns {object} context; `state` is a message when there is no data.
       */
      _context() {
          const context = {};
          context.state = null;
          context.index = 0;
          context.count = 0;
          if (this._data == null || this._data.length === 0) {
              context.state = 'Tensor data is empty.';
              return context;
          }
          context.dataType = this._type.dataType;
          context.shape = this._type.shape.dimensions;
          context.data = new DataView(this._data.buffer, this._data.byteOffset, this._data.byteLength);
          if (this._type.dataType === 'string') {
              // Layout as read here: int32 string count, then one int32 start offset
              // per string, then the UTF-8 bytes; the data length closes the last string.
              const count = context.data.getInt32(0, true);
              const offsets = [];
              for (let i = 0; i < count; i++) {
                  offsets.push(context.data.getInt32(4 + 4 * i, true));
              }
              offsets.push(this._data.length);
              const decoder = new TextDecoder('utf-8');
              const strings = [];
              for (let i = 0; i < count; i++) {
                  strings.push(decoder.decode(this._data.subarray(offsets[i], offsets[i + 1])));
              }
              context.data = strings;
          }
          return context;
      }

      /**
       * Recursively decodes one dimension of the tensor.
       * @param {object} context - decoding context from `_context()` with `limit` set.
       * @param {number} dimension - index of the dimension to decode.
       * @returns {*} nested arrays of values; the bare value for an empty shape.
       */
      _decode(context, dimension) {
          // An empty shape is decoded as a single element and unwrapped at the end.
          const shape = (context.shape.length === 0) ? [ 1 ] : context.shape;
          const size = shape[dimension];
          const results = [];
          if (dimension === shape.length - 1) {
              for (let i = 0; i < size; i++) {
                  if (context.count > context.limit) {
                      results.push('...');
                      return results;
                  }
                  let decoded = true;
                  // Every multi-byte read is little-endian.
                  switch (context.dataType) {
                      case 'uint8':
                          results.push(context.data.getUint8(context.index));
                          context.index += 1;
                          break;
                      case 'int8':
                          results.push(context.data.getInt8(context.index));
                          context.index += 1;
                          break;
                      case 'uint16':
                          results.push(context.data.getUint16(context.index, true));
                          context.index += 2;
                          break;
                      case 'int16':
                          // Was read without the little-endian flag, i.e. as big-endian.
                          results.push(context.data.getInt16(context.index, true));
                          context.index += 2;
                          break;
                      case 'uint32':
                          results.push(context.data.getUint32(context.index, true));
                          context.index += 4;
                          break;
                      case 'int32':
                          results.push(context.data.getInt32(context.index, true));
                          context.index += 4;
                          break;
                      case 'int64':
                          // NOTE(review): getInt64 is not a standard DataView method; presumably
                          // the host application adds it to DataView.prototype - confirm.
                          results.push(context.data.getInt64(context.index, true));
                          context.index += 8;
                          break;
                      case 'float16':
                          // NOTE(review): getFloat16 is missing on older engines; presumably
                          // supplied by the host application there - confirm.
                          results.push(context.data.getFloat16(context.index, true));
                          context.index += 2;
                          break;
                      case 'float32':
                          results.push(context.data.getFloat32(context.index, true));
                          context.index += 4;
                          break;
                      case 'float64':
                          results.push(context.data.getFloat64(context.index, true));
                          context.index += 8;
                          break;
                      case 'string':
                          // For strings `data` is the string table and `index` an element position.
                          results.push(context.data[context.index++]);
                          break;
                      default:
                          // Other data types produce no element and do not count towards the limit.
                          decoded = false;
                          break;
                  }
                  if (decoded) {
                      context.count++;
                  }
              }
          }
          else {
              for (let j = 0; j < size; j++) {
                  if (context.count > context.limit) {
                      results.push('...');
                      return results;
                  }
                  results.push(this._decode(context, dimension + 1));
              }
          }
          if (context.shape.length === 0) {
              return results[0];
          }
          return results;
      }
  };
  /**
   * Data type name plus shape of a tensor.
   */
  mslite.TensorType = class {

      /**
       * @param {number} dataType - numeric data type id from the tensor table.
       * @param {Iterable|ArrayLike} dimensions - tensor dimensions, copied into a plain array.
       * @throws {mslite.Error} when the id has no entry in the name table.
       */
      constructor(dataType, dimensions) {
          const names = mslite.TensorType._names();
          if (!names.has(dataType)) {
              throw new mslite.Error("Unsupported data type '" + dataType.toString() + "'.");
          }
          this._dataType = names.get(dataType);
          this._shape = new mslite.TensorShape(Array.from(dimensions));
      }

      /**
       * Id-to-name table, built on first use and cached on the class.
       * The ids 10 and 29 have no entry.
       * @returns {Map} numeric data type id -> name.
       */
      static _names() {
          if (!mslite.TensorType._nameTable) {
              mslite.TensorType._nameTable = new Map([
                  [ 0, "?" ], [ 1, "type" ], [ 2, "any" ], [ 3, "object" ], [ 4, "typetype" ],
                  [ 5, "problem" ], [ 6, "external" ], [ 7, "none" ], [ 8, "null" ], [ 9, "ellipsis" ],
                  [ 11, "number" ], [ 12, "string" ], [ 13, "list" ], [ 14, "tuple" ], [ 15, "slice" ],
                  [ 16, "keyword" ], [ 17, "tensortype" ], [ 18, "rowtensortype" ],
                  [ 19, "sparsetensortype" ], [ 20, "undeterminedtype" ], [ 21, "class" ],
                  [ 22, "dictionary" ], [ 23, "function" ], [ 24, "jtagged" ],
                  [ 25, "symbolickeytype" ], [ 26, "envtype" ], [ 27, "refkey" ], [ 28, "ref" ],
                  [ 30, "boolean" ], [ 31, "int" ], [ 32, "int8" ], [ 33, "int16" ], [ 34, "int32" ],
                  [ 35, "int64" ], [ 36, "uint" ], [ 37, "uint8" ], [ 38, "uint16" ], [ 39, "uint32" ],
                  [ 40, "uint64" ], [ 41, "float" ], [ 42, "float16" ], [ 43, "float32" ],
                  [ 44, "float64" ], [ 45, "complex64" ]
              ]);
          }
          return mslite.TensorType._nameTable;
      }

      /** @returns {string} data type name. */
      get dataType() {
          return this._dataType;
      }

      /** @returns {object} the mslite.TensorShape of this type. */
      get shape() {
          return this._shape;
      }

      /** @returns {string} data type name immediately followed by the shape text. */
      toString() {
          return this.dataType + this._shape.toString();
      }
  };
  /**
   * List of tensor dimensions with a bracketed text form.
   */
  mslite.TensorShape = class {

      /**
       * @param {Array} dimensions - dimension values, stored as given.
       */
      constructor(dimensions) {
          this._dimensions = dimensions;
      }

      /** @returns {Array} the dimensions passed to the constructor. */
      get dimensions() {
          return this._dimensions;
      }

      /**
       * @returns {string} '[d0,d1,...]' with '?' for every falsy dimension;
       *   '' when there are no dimensions.
       */
      toString() {
          const dimensions = this._dimensions;
          if (!dimensions || !(dimensions.length > 0)) {
              return '';
          }
          const labels = dimensions.map((dimension) => dimension ? dimension.toString() : '?');
          return `[${labels.join(',')}]`;
      }
  };
  /**
   * Static helpers.
   */
  mslite.Utility = class {

      /**
       * Translates an enum value into its key name using `mslite.schema[name]`.
       * The value-to-key table of each enum is built once and cached on the class.
       * @param {string} name - name of the enum inside `mslite.schema`.
       * @param {*} value - value to translate.
       * @returns {*} the key name, or `value` unchanged when the schema, the enum
       *   or the value is unknown.
       */
      static enum(name, value) {
          const type = name && mslite.schema ? mslite.schema[name] : undefined;
          if (!type) {
              return value;
          }
          const cache = mslite.Utility._enumKeyMap = mslite.Utility._enumKeyMap || new Map();
          if (!cache.has(name)) {
              cache.set(name, new Map(Object.keys(type).map((key) => [ type[key], key ])));
          }
          const keys = cache.get(name);
          return keys.has(value) ? keys.get(value) : value;
      }
  };
  /**
   * Error type thrown by this module.
   */
  mslite.Error = class extends Error {

      /**
       * @param {string} message - error text.
       * @param {boolean} [context] - stored as `this.context`; anything other
       *   than the literal false is stored as true.
       */
      constructor(message, context) {
          super(message);
          this.name = 'Error loading MindSpore Lite model.';
          this.context = context !== false;
      }
  };
  // CommonJS export: publish the factory only when a `module` object with an
  // object-valued `exports` is present.
  if (typeof module !== 'undefined') {
      if (typeof module.exports === 'object') {
          module.exports.ModelFactory = mslite.ModelFactory;
      }
  }