mslite.js 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340
  const mslite = {};

  /**
   * Factory that recognizes MindSpore Lite FlatBuffers content and opens it
   * as an `mslite.Model`.
   */
  mslite.ModelFactory = class {

      /**
       * Decides whether the content should be handled by this factory.
       *
       * The content is peeked as 'flatbuffers.binary'. It is accepted when the
       * reader's file identifier is 'MSL1' or 'MSL2', or when the identifier is
       * empty and the file extension is 'ms'.
       *
       * @param {object} context - Host context exposing `identifier`, `peek` and `set`.
       * @returns {Promise<*>} The result of `context.set('mslite', reader)` on a
       *     match, otherwise `null`.
       */
      async match(context) {
          const extension = context.identifier.split('.').pop().toLowerCase();
          const reader = await context.peek('flatbuffers.binary');
          if (reader) {
              const identifier = reader.identifier;
              if (identifier === 'MSL1' || identifier === 'MSL2' || (identifier === '' && extension === 'ms')) {
                  return context.set('mslite', reader);
              }
          }
          return null;
      }

      /**
       * Opens content previously accepted by `match`.
       *
       * Only the 'MSL2' identifier is decoded. An empty identifier and 'MSL1'
       * are rejected as deprecated, and anything else as unsupported.
       *
       * @param {object} context - Host context; `context.value` is the reader
       *     stored by `match`.
       * @returns {Promise<mslite.Model>} The decoded model.
       * @throws {mslite.Error} For deprecated or unsupported identifiers, or
       *     when the reader cannot be decoded as a MetaGraph.
       */
      async open(context) {
          const reader = context.value;
          switch (reader.identifier) {
              case '': {
                  throw new mslite.Error('MSL0 format is deprecated.');
              }
              case 'MSL1': {
                  throw new mslite.Error('MSL1 format is deprecated.');
              }
              case 'MSL2':
                  break;
              default:
                  throw new mslite.Error(`Unsupported file identifier '${reader.identifier}'.`);
          }
          // Publish only the final schema namespace so that mslite.schema never
          // temporarily holds the whole schema module.
          const module = await context.require('./mslite-schema');
          mslite.schema = module.mindspore.schema;
          let model = null;
          try {
              model = mslite.schema.MetaGraph.create(reader);
          } catch (error) {
              // Anything can be thrown (including null or undefined), so never
              // call methods on the thrown value directly.
              const message = String(error && error.message ? error.message : error);
              throw new mslite.Error(`File format is not mslite.MetaGraph (${message.replace(/\.$/, '')}).`);
          }
          const metadata = await context.metadata('mslite-metadata.json');
          return new mslite.Model(metadata, model);
      }
  };
  /**
   * Root of a decoded model: exposes the model name, a display format string
   * and one `mslite.Graph` per entry of `model.subGraph`.
   */
  mslite.Model = class {

      /**
       * @param {object} metadata - Operator metadata passed through to each graph.
       * @param {object} model - Decoded MetaGraph; `name`, `version` and
       *     `subGraph` are read here.
       */
      constructor(metadata, model) {
          this.name = model.name || '';
          this.modules = [];
          // Take the trailing 'major.minor.patch' of the version string. Each
          // component may have several digits (for example '1.10.0'); the
          // unanchored search returns the leftmost match, so a leading
          // component such as the '12' in '12.3.4' is not truncated.
          const version = model.version ? model.version.match(/(\d+\.\d+\.\d+)$/) : null;
          this.format = `MindSpore Lite${version ? ` v${version[1]}` : ''}`;
          const subgraphs = model.subGraph;
          if (Array.isArray(subgraphs)) {
              for (const subgraph of subgraphs) {
                  this.modules.push(new mslite.Graph(metadata, subgraph, model));
              }
          } else {
              // No sub-graph list: the model object itself is used as the graph.
              this.modules.push(new mslite.Graph(metadata, model, model));
          }
      }
  };
  /**
   * One graph of a model: its name, graph-level inputs and outputs, and nodes.
   */
  mslite.Graph = class {

      /**
       * @param {object} metadata - Operator metadata handed to every node.
       * @param {object} subgraph - Either a sub-graph entry or the model itself.
       * @param {object} model - Decoded model owning `allTensors` and `nodes`.
       */
      constructor(metadata, subgraph, model) {
          this.name = subgraph.name || '';
          // One value per model tensor; tensors carrying data get an initializer.
          const values = model.allTensors.map((tensor, position) => {
              const identifier = tensor.name || position.toString();
              const content = tensor.data;
              const type = new mslite.TensorType(tensor.dataType, tensor.dims);
              let initializer = null;
              if (content && content.length > 0) {
                  initializer = new mslite.Tensor(type, tensor.data);
              }
              return new mslite.Value(identifier, tensor, initializer);
          });
          // Graph-level arguments are named by their position in the index list.
          const toArguments = (indices) => Array.from(indices, (index, position) => {
              return new mslite.Argument(position.toString(), [values[index]]);
          });
          // When the model doubles as the graph, the fields are named
          // inputIndex / outputIndex / nodes. A sub-graph entry uses
          // inputIndices / outputIndices and refers to model.nodes by index.
          const whole = subgraph === model;
          this.inputs = toArguments(whole ? subgraph.inputIndex : subgraph.inputIndices);
          this.outputs = toArguments(whole ? subgraph.outputIndex : subgraph.outputIndices);
          const operators = whole ?
              Array.from(subgraph.nodes) :
              Array.from(subgraph.nodeIndices, (index) => model.nodes[index]);
          this.nodes = operators.map((op) => new mslite.Node(metadata, op, values));
      }
  };
  /**
   * A single operator: its name, type, attributes and input/output arguments.
   */
  mslite.Node = class {

      /**
       * @param {object} metadata - Provides `type(name)` and `attribute(type, key)`.
       * @param {object} op - Decoded operator; `name`, `primitive`, `inputIndex`
       *     and `outputIndex` are read.
       * @param {Array} values - Values of the owning graph, addressed by the
       *     indices in `op.inputIndex` / `op.outputIndex`.
       */
      constructor(metadata, op, values) {
          this.name = op.name || '';
          this.type = { name: '?' };
          this.attributes = [];
          this.inputs = [];
          this.outputs = [];
          // An operator without a primitive keeps the '?' placeholder type
          // instead of failing on a null dereference.
          const data = op.primitive ? op.primitive.value : null;
          if (data && data.constructor) {
              // The operator type is named after the class of the primitive value.
              this.type = metadata.type(data.constructor.name);
              this.attributes = Object.entries(data).map(([key, obj]) => {
                  // Typed arrays are exposed as plain arrays.
                  let value = ArrayBuffer.isView(obj) ? Array.from(obj) : obj;
                  let type = null;
                  const schema = metadata.attribute(this.type.name, key);
                  if (schema && schema.type) {
                      type = schema.type;
                      value = mslite.Utility.enum(type, value);
                  }
                  return new mslite.Argument(key, value, type);
              });
          }
          // Pairs value indices with argument names. The leading entries take
          // their names from the metadata signature, when there is one; the
          // remaining entries are named by their position.
          const bind = (indices, signature) => {
              const names = [];
              for (const entry of signature || []) {
                  if (names.length >= indices.length) {
                      break;
                  }
                  names.push(entry.name);
              }
              const list = [];
              for (let i = 0; i < indices.length; i++) {
                  const name = i < names.length ? names[i] : i.toString();
                  list.push(new mslite.Argument(name, [values[indices[i]]]));
              }
              return list;
          };
          this.inputs = bind(op.inputIndex, this.type && this.type.inputs);
          this.outputs = bind(op.outputIndex, this.type && this.type.outputs);
      }
  };
  /**
   * Named slot holding a value, with an optional type label. Used for
   * graph and node inputs/outputs as well as for node attributes.
   */
  mslite.Argument = class {

      /**
       * @param {string} name - Display name of the argument.
       * @param {*} value - Stored as given.
       * @param {*} [type=null] - Optional type label, stored as given.
       */
      constructor(name, value, type = null) {
          Object.assign(this, { name, value, type });
      }
  };
  /**
   * A tensor-backed value: name, type, optional initializer and, when the
   * tensor lists quantization parameters, a linear quantization description.
   */
  mslite.Value = class {

      /**
       * @param {string} name - Display name of the value.
       * @param {object} tensor - Decoded tensor; `dataType`, `dims` and
       *     `quantParams` are read.
       * @param {?object} [initializer=null] - Constant tensor; when present its
       *     type is reused rather than built again.
       */
      constructor(name, tensor, initializer = null) {
          this.name = name;
          if (initializer) {
              this.type = initializer.type;
          } else {
              this.type = new mslite.TensorType(tensor.dataType, tensor.dims);
          }
          this.initializer = initializer;
          const params = tensor.quantParams;
          if (!Array.isArray(params) || params.length === 0) {
              // No quantization parameters: the property is left undefined.
              return;
          }
          const quantization = { type: 'linear', scale: [], offset: [] };
          for (const param of params) {
              quantization.scale.push(param.scale);
              quantization.offset.push(param.zeroPoint);
          }
          this.quantization = quantization;
      }
  };
  /**
   * Constant tensor content together with its type.
   */
  mslite.Tensor = class {

      /**
       * @param {mslite.TensorType} type - Type of the tensor.
       * @param {?Uint8Array} [data=null] - Raw tensor bytes.
       */
      constructor(type, data = null) {
          this.type = type;
          if (type.dataType === 'string') {
              this.encoding = '|';
          } else {
              this.encoding = '<';
          }
          this._data = data;
      }

      /**
       * Tensor content. Non-string tensors return the raw bytes unchanged.
       *
       * String tensors are decoded from a little-endian int32 count, followed
       * by that many int32 start offsets, followed by the UTF-8 bytes. Each
       * string ends where the next one starts, and the last one ends at the
       * end of the buffer.
       *
       * @returns {Uint8Array|string[]|null}
       */
      get values() {
          if (this.type.dataType !== 'string') {
              return this._data;
          }
          const bytes = this._data;
          const view = new DataView(bytes.buffer, bytes.byteOffset, bytes.byteLength);
          const count = view.getInt32(0, true);
          const boundaries = [];
          for (let position = 4, i = 0; i < count; i++, position += 4) {
              boundaries.push(view.getInt32(position, true));
          }
          // Sentinel so that the last string has an end offset as well.
          boundaries.push(bytes.length);
          const decoder = new TextDecoder('utf-8');
          const strings = [];
          for (let i = 0; i < count; i++) {
              const slice = bytes.subarray(boundaries[i], boundaries[i + 1]);
              strings.push(decoder.decode(slice));
          }
          return strings;
      }
  };
  /**
   * Data type name plus shape of a tensor.
   */
  mslite.TensorType = (() => {
      // Numeric type id -> display name. The generic ids 31 (int), 36 (uint),
      // 41 (float), 46 (double) and 47 (complex) are deliberately not listed
      // and are therefore rejected like any other unknown id.
      const names = new Map([
          [0, "?"],
          [1, "type"],
          [2, "any"],
          [3, "object"],
          [4, "typetype"],
          [5, "problem"],
          [6, "external"],
          [7, "none"],
          [8, "null"],
          [9, "ellipsis"],
          [11, "number"],
          [12, "string"],
          [13, "list"],
          [14, "tuple"],
          [15, "slice"],
          [16, "keyword"],
          [17, "tensortype"],
          [18, "rowtensortype"],
          [19, "sparsetensortype"],
          [20, "undeterminedtype"],
          [21, "class"],
          [22, "dictionary"],
          [23, "function"],
          [24, "jtagged"],
          [25, "symbolickeytype"],
          [26, "envtype"],
          [27, "refkey"],
          [28, "ref"],
          [30, "boolean"],
          [32, "int8"],
          [33, "int16"],
          [34, "int32"],
          [35, "int64"],
          [37, "uint8"],
          [38, "uint16"],
          [39, "uint32"],
          [40, "uint64"],
          [42, "float16"],
          [43, "float32"],
          [44, "float64"],
          [45, "bfloat16"],
          [48, "complex64"],
          [49, "complex128"],
          [50, "int4"]
      ]);
      return class {

          /**
           * @param {number} dataType - Numeric type id from the decoded tensor.
           * @param {Iterable<number>|ArrayLike<number>} dimensions - Tensor
           *     dimensions; copied into a plain array.
           * @throws {mslite.Error} When the type id is not listed.
           */
          constructor(dataType, dimensions) {
              const name = names.get(dataType);
              if (name === undefined) {
                  throw new mslite.Error(`Unsupported data type '${dataType}'.`);
              }
              this.dataType = name;
              this.shape = new mslite.TensorShape(Array.from(dimensions));
          }

          /**
           * @returns {string} Type name immediately followed by the shape text.
           */
          toString() {
              return `${this.dataType}${this.shape.toString()}`;
          }
      };
  })();
  /**
   * Ordered list of tensor dimensions.
   */
  mslite.TensorShape = class {

      /**
       * @param {Array} dimensions - Dimension sizes, stored as given.
       */
      constructor(dimensions) {
          this.dimensions = dimensions;
      }

      /**
       * Formats the shape as '[d0,d1,...]'. Falsy dimensions (such as 0 or
       * null) are printed as '?'. A missing or empty dimension list yields ''.
       *
       * @returns {string}
       */
      toString() {
          const dimensions = this.dimensions;
          if (!dimensions || !(dimensions.length > 0)) {
              return '';
          }
          const parts = dimensions.map((dimension) => {
              return dimension ? dimension.toString() : '?';
          });
          return `[${parts.join(',')}]`;
      }
  };
  /**
   * Static helpers shared by the mslite classes.
   */
  mslite.Utility = class {

      /**
       * Translates an enum value into its symbolic name.
       *
       * The enum is looked up as `mslite.schema[name]`. A reverse map
       * (value -> key) is built from its entries on first use and cached per
       * enum name in `mslite.Utility._enumKeyMap`. Enum names that cannot be
       * resolved are not cached, so they are looked up again on the next call.
       *
       * @param {string} name - Enum type name within `mslite.schema`.
       * @param {*} value - Raw value to translate.
       * @returns {*} The matching key, or `value` unchanged when the enum or
       *     the value is unknown.
       */
      static enum(name, value) {
          mslite.Utility._enumKeyMap = mslite.Utility._enumKeyMap || new Map();
          const cache = mslite.Utility._enumKeyMap;
          if (!cache.has(name)) {
              const type = name && mslite.schema ? mslite.schema[name] : undefined;
              if (type) {
                  const reverse = new Map(Object.entries(type).map(([key, number]) => [number, key]));
                  cache.set(name, reverse);
              }
          }
          const names = cache.get(name);
          return names && names.has(value) ? names.get(value) : value;
      }
  };
  /**
   * Error raised while detecting or loading a MindSpore Lite model.
   */
  mslite.Error = class extends Error {

      /**
       * @param {string} message - Human-readable description of the failure.
       * @param {object} [options] - Standard `Error` options, forwarded to the
       *     base constructor so that `{ cause }` can carry the underlying
       *     error. Optional; existing one-argument callers are unaffected.
       */
      constructor(message, options) {
          super(message, options);
          this.name = 'Error loading MindSpore Lite model.';
      }
  };
  312. export const ModelFactory = mslite.ModelFactory;