mslite.js 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349
  /**
   * Namespace for the MindSpore Lite loader; every class below is attached to it.
   */
  const mslite = {};

  /**
   * Detects and opens MindSpore Lite flatbuffer models.
   */
  mslite.ModelFactory = class {

      /**
       * Claims the context when it holds a flatbuffer whose file identifier is
       * '', 'MSL1' or 'MSL2'. The deprecated identifiers are accepted here so
       * that open() can report a precise deprecation error.
       * @param {object} context - host context; `peek('flatbuffers.binary')`
       *     returns a reader with an `identifier`, or a falsy value.
       */
      match(context) {
          const reader = context.peek('flatbuffers.binary');
          if (reader && (reader.identifier === '' || reader.identifier === 'MSL1' || reader.identifier === 'MSL2')) {
              context.type = 'mslite';
              context.target = reader;
          }
      }

      /**
       * Decodes the flatbuffer held in `context.target` into an mslite.Model.
       * @param {object} context - a context previously accepted by match().
       * @returns {Promise<object>} the loaded mslite.Model.
       * @throws {mslite.Error} for deprecated or unsupported identifiers, or when
       *     the buffer cannot be decoded as a MetaGraph (original error in `cause`).
       */
      async open(context) {
          const reader = context.target;
          switch (reader.identifier) {
              case '':
                  throw new mslite.Error('MSL0 format is deprecated.');
              case 'MSL1':
                  throw new mslite.Error('MSL1 format is deprecated.');
              case 'MSL2':
                  break;
              default:
                  throw new mslite.Error(`Unsupported file identifier '${reader.identifier}'.`);
          }
          // Stored on the namespace because mslite.Utility.enum reads mslite.schema later.
          const schemaModule = await context.require('./mslite-schema');
          mslite.schema = schemaModule.mindspore.schema;
          let model = null;
          try {
              model = mslite.schema.MetaGraph.create(reader);
          } catch (error) {
              // String() also copes with a thrown null/undefined; error.toString()
              // would raise a TypeError there and mask the decoding failure.
              const message = error && error.message ? error.message : String(error);
              const wrapped = new mslite.Error(`File format is not mslite.MetaGraph (${message.replace(/\.$/, '')}).`);
              // Keep the underlying failure reachable for debugging.
              wrapped.cause = error;
              throw wrapped;
          }
          const metadata = await context.metadata('mslite-metadata.json');
          return new mslite.Model(metadata, model);
      }
  };
  /**
   * In-memory view of a decoded MindSpore Lite MetaGraph.
   */
  mslite.Model = class {

      /**
       * @param {object} metadata - operator metadata, forwarded to every graph.
       * @param {object} model - decoded MetaGraph table.
       */
      constructor(metadata, model) {
          this.name = model.name || '';
          this.graphs = [];
          // Take the trailing "major.minor.patch" triple. \d+ keeps multi-digit
          // components (e.g. 2.10.1) instead of dropping or truncating them.
          const version = model.version ? model.version.match(/(\d+\.\d+\.\d+)$/) : null;
          this.format = `MindSpore Lite${version ? ` v${version[1]}` : ''}`;
          const subgraphs = model.subGraph;
          if (Array.isArray(subgraphs) && subgraphs.length > 0) {
              for (const subgraph of subgraphs) {
                  this.graphs.push(new mslite.Graph(metadata, subgraph, model));
              }
          } else {
              // No sub-graphs (missing or empty list): mslite.Graph then reads the
              // MetaGraph's own inputIndex/outputIndex/nodes fields.
              this.graphs.push(new mslite.Graph(metadata, model, model));
          }
      }
  };
  /**
   * One graph of a MindSpore Lite model.
   *
   * Two layouts are supported:
   * - When `subgraph === model`, the MetaGraph's own inputIndex, outputIndex
   *   and nodes fields describe the graph.
   * - Otherwise, the SubGraph table's inputIndices, outputIndices and
   *   nodeIndices point into the model-wide tensor and node lists.
   */
  mslite.Graph = class {

      /**
       * @param {object} metadata - operator metadata passed to every node.
       * @param {object} subgraph - a SubGraph table, or the MetaGraph itself.
       * @param {object} model - the MetaGraph owning allTensors and nodes.
       */
      constructor(metadata, subgraph, model) {
          this.name = subgraph.name || '';
          this.inputs = [];
          this.outputs = [];
          this.nodes = [];
          // One value per model tensor; array positions equal the tensor indices
          // referenced by the graph and node index tables.
          const values = model.allTensors.map((tensor, position) => {
              const label = tensor.name || position.toString();
              const payload = tensor.data;
              const tensorType = new mslite.TensorType(tensor.dataType, tensor.dims);
              const hasPayload = Boolean(payload) && payload.length > 0;
              const constant = hasPayload ? new mslite.Tensor(tensorType, tensor.data) : null;
              return new mslite.Value(label, tensor, constant);
          });
          const isLegacy = subgraph === model;
          // Graph-level arguments are named by their ordinal ("0", "1", ...).
          // Array.from keeps the result a plain Array even for typed-array indices.
          const toArguments = (indices) => Array.from(indices, (tensorIndex, ordinal) =>
              new mslite.Argument(ordinal.toString(), [ values[tensorIndex] ]));
          this.inputs = toArguments(isLegacy ? subgraph.inputIndex : subgraph.inputIndices);
          this.outputs = toArguments(isLegacy ? subgraph.outputIndex : subgraph.outputIndices);
          const ops = isLegacy ?
              subgraph.nodes :
              Array.from(subgraph.nodeIndices, (nodeIndex) => model.nodes[nodeIndex]);
          for (const op of ops) {
              this.nodes.push(new mslite.Node(metadata, op, values));
          }
      }
  };
  /**
   * A single operator (CNode) of a MindSpore Lite graph.
   */
  mslite.Node = class {

      /**
       * @param {object} metadata - operator metadata; `type(name)` and
       *     `attribute(type, key)` are queried.
       * @param {object} op - decoded CNode (name, primitive, inputIndex, outputIndex).
       * @param {Array<object>} values - graph values indexed by tensor index.
       */
      constructor(metadata, op, values) {
          this.name = op.name || '';
          this.type = { name: '?' };
          this.attributes = [];
          this.inputs = [];
          this.outputs = [];
          // A CNode without a primitive keeps the '?' placeholder instead of
          // failing the whole model on `op.primitive.value`.
          const data = op.primitive ? op.primitive.value : null;
          if (data && data.constructor) {
              // The decoded primitive's class name is used as the operator type name.
              const type = data.constructor.name;
              // Fall back to a bare type when the metadata has no entry for it.
              this.type = metadata.type(type) || { name: type };
              this.attributes = Object.keys(data).map((key) => new mslite.Attribute(metadata.attribute(type, key), key, data[key]));
          }
          // Pairs each tensor index with the metadata port declared at the same
          // position; indices beyond the declared ports are named by position.
          const bind = (indices, ports) => {
              const declared = ports || [];
              return Array.from(indices, (index, position) => {
                  const name = position < declared.length ? declared[position].name : position.toString();
                  return new mslite.Argument(name, [ values[index] ]);
              });
          };
          this.inputs = bind(op.inputIndex, this.type.inputs);
          this.outputs = bind(op.outputIndex, this.type.outputs);
      }
  };
  /**
   * A named operator attribute.
   *
   * Typed-array values are copied into plain arrays. When the metadata supplies
   * a `type`, the value is passed through mslite.Utility.enum, which maps it to
   * a schema enum key when one exists.
   */
  mslite.Attribute = class {

      /**
       * @param {object|undefined} metadata - attribute metadata; only `type` is read.
       * @param {string} name - attribute name.
       * @param {*} value - raw decoded value.
       */
      constructor(metadata, name, value) {
          this.type = null;
          this.name = name;
          this.visible = false;
          const plain = ArrayBuffer.isView(value) ? Array.from(value) : value;
          const declaredType = metadata ? metadata.type : undefined;
          if (declaredType) {
              this.type = declaredType;
          }
          this.value = declaredType ? mslite.Utility.enum(declaredType, plain) : plain;
      }
  };
  /**
   * A named input or output slot. Callers in this file pass a one-element list
   * of values.
   */
  mslite.Argument = class {

      /**
       * @param {string} name - slot name.
       * @param {Array<object>} value - values bound to the slot.
       */
      constructor(name, value) {
          Object.assign(this, { name, value });
      }
  };
  /**
   * A tensor value in a graph: name, type, optional constant initializer and,
   * for quantized tensors, a readable dequantization formula.
   */
  mslite.Value = class {

      /**
       * @param {string} name - display name.
       * @param {object} tensor - decoded Tensor table (dataType, dims, quantParams).
       * @param {object|null} initializer - constant data, whose `type` is reused.
       */
      constructor(name, tensor, initializer) {
          this.name = name;
          this.type = initializer ? initializer.type : new mslite.TensorType(tensor.dataType, tensor.dims);
          this.initializer = initializer || null;
          if (!tensor.quantParams) {
              return;
          }
          // Renders "scale * (q - zeroPoint)", omitting a scale of 1 and a zero point of 0.
          const describe = ({ scale, zeroPoint }) => {
              const prefix = scale === 1 ? '' : `${scale} * `;
              if (zeroPoint === 0) {
                  return `${prefix}q`;
              }
              if (zeroPoint < 0) {
                  return `${prefix}(q + ${-zeroPoint})`;
              }
              if (zeroPoint > 0) {
                  return `${prefix}(q - ${zeroPoint})`;
              }
              return prefix;
          };
          // Parameters whose scale and zero point are both 0 are skipped.
          const formulas = Array.from(tensor.quantParams)
              .filter((param) => param.scale !== 0 || param.zeroPoint !== 0)
              .map(describe);
          // Only annotate when some formula differs from the identity 'q'.
          if (formulas.some((formula) => formula !== 'q')) {
              this.quantization = formulas.length === 1 ? formulas[0] : formulas;
          }
      }
  };
  /**
   * Constant tensor data attached to a value.
   */
  mslite.Tensor = class {

      /**
       * @param {object} type - tensor type; `dataType` selects the decoding.
       * @param {Uint8Array|null} data - raw bytes as stored in the model.
       */
      constructor(type, data) {
          this.type = type;
          this.encoding = type.dataType === 'string' ? '|' : '<';
          this._data = data || null;
      }

      /**
       * Decoded tensor contents.
       *
       * Non-string tensors return the raw bytes unchanged. String tensors are
       * decoded from this layout:
       * - a little-endian int32 count;
       * - `count` int32 start offsets, measured from the start of the data;
       * - the UTF-8 payload.
       * Each string ends where the next one starts, and the last runs to the end
       * of the data.
       */
      get values() {
          if (this.type.dataType !== 'string') {
              return this._data;
          }
          const bytes = this._data;
          const view = new DataView(bytes.buffer, bytes.byteOffset, bytes.byteLength);
          const count = view.getInt32(0, true);
          const bounds = Array.from({ length: count }, (_, position) => view.getInt32(4 + 4 * position, true));
          bounds.push(bytes.length);
          const decoder = new TextDecoder('utf-8');
          const strings = [];
          for (let position = 0; position < count; position++) {
              strings.push(decoder.decode(bytes.subarray(bounds[position], bounds[position + 1])));
          }
          return strings;
      }
  };
  /**
   * Element type plus shape of a tensor. The numeric type id stored in the
   * model is translated to a display name via a lookup table.
   */
  mslite.TensorType = class {

      /**
       * @param {number} dataType - numeric type id from the model.
       * @param {Iterable<number>} dimensions - tensor dimensions.
       * @throws {mslite.Error} when the id is not in the lookup table.
       */
      constructor(dataType, dimensions) {
          const name = mslite.TensorType._dataTypeNames.get(dataType);
          if (name === undefined) {
              throw new mslite.Error(`Unsupported data type '${dataType}'.`);
          }
          this.dataType = name;
          this.shape = new mslite.TensorShape(Array.from(dimensions));
      }

      toString() {
          return this.dataType + this.shape.toString();
      }
  };

  // Type id -> display name. Ids 10 and 29 have no entry, so they are rejected
  // like any other unknown id.
  mslite.TensorType._dataTypeNames = new Map([
      [ 0, '?' ], [ 1, 'type' ], [ 2, 'any' ], [ 3, 'object' ], [ 4, 'typetype' ],
      [ 5, 'problem' ], [ 6, 'external' ], [ 7, 'none' ], [ 8, 'null' ], [ 9, 'ellipsis' ],
      [ 11, 'number' ], [ 12, 'string' ], [ 13, 'list' ], [ 14, 'tuple' ], [ 15, 'slice' ],
      [ 16, 'keyword' ], [ 17, 'tensortype' ], [ 18, 'rowtensortype' ], [ 19, 'sparsetensortype' ],
      [ 20, 'undeterminedtype' ], [ 21, 'class' ], [ 22, 'dictionary' ], [ 23, 'function' ],
      [ 24, 'jtagged' ], [ 25, 'symbolickeytype' ], [ 26, 'envtype' ], [ 27, 'refkey' ], [ 28, 'ref' ],
      [ 30, 'boolean' ], [ 31, 'int' ], [ 32, 'int8' ], [ 33, 'int16' ], [ 34, 'int32' ], [ 35, 'int64' ],
      [ 36, 'uint' ], [ 37, 'uint8' ], [ 38, 'uint16' ], [ 39, 'uint32' ], [ 40, 'uint64' ],
      [ 41, 'float' ], [ 42, 'float16' ], [ 43, 'float32' ], [ 44, 'float64' ], [ 45, 'complex64' ]
  ]);
  /**
   * Tensor dimensions with a compact "[d0,d1,...]" rendering.
   */
  mslite.TensorShape = class {

      /**
       * @param {Array<number>} dimensions - dimension sizes.
       */
      constructor(dimensions) {
          this.dimensions = dimensions;
      }

      /**
       * Renders e.g. "[1,3,?]": falsy dimensions (0, null, undefined) print as
       * '?'. A missing or empty dimension list renders as ''.
       */
      toString() {
          const dims = this.dimensions;
          if (!(dims && dims.length > 0)) {
              return '';
          }
          const parts = dims.map((dimension) => (dimension ? dimension.toString() : '?'));
          return `[${parts.join(',')}]`;
      }
  };
  /**
   * Helpers shared by the loader.
   */
  mslite.Utility = class {

      /**
       * Translates `value` to its key name in the enum object `mslite.schema[name]`.
       * The value -> key table is built on first use and cached per enum name.
       * Unknown enums or values are returned unchanged.
       * @param {string} name - enum name looked up in mslite.schema.
       * @param {*} value - value to translate.
       * @returns {*} the enum key, or `value` itself.
       */
      static enum(name, value) {
          const cache = mslite.Utility._enumKeyMap || new Map();
          mslite.Utility._enumKeyMap = cache;
          if (!cache.has(name)) {
              const schemaEnum = name && mslite.schema ? mslite.schema[name] : undefined;
              if (schemaEnum) {
                  const reverse = new Map();
                  // Later keys win when several keys share a value.
                  for (const [key, entry] of Object.entries(schemaEnum)) {
                      reverse.set(entry, key);
                  }
                  cache.set(name, reverse);
              }
          }
          const lookup = cache.get(name);
          return lookup && lookup.has(value) ? lookup.get(value) : value;
      }
  };
  /**
   * Error raised when a MindSpore Lite file cannot be loaded.
   */
  mslite.Error = class extends Error {

      /**
       * @param {string} message - description of the failure.
       */
      constructor(message) {
          super(message);
          Object.assign(this, { name: 'Error loading MindSpore Lite model.' });
      }
  };
  // Module export: the factory class for the MindSpore Lite format.
  const { ModelFactory } = mslite;
  export { ModelFactory };