mnn.js 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401
  1. const mnn = {};
  2. mnn.ModelFactory = class {
  3. async match(context) {
  4. const reader = await context.peek('flatbuffers.binary');
  5. if (reader) {
  6. return context.set('mnn.flatbuffers', reader);
  7. }
  8. const obj = await context.peek('json');
  9. if (obj && obj.sourceType && Array.isArray(obj.oplists) && Array.isArray(obj.tensorName)) {
  10. return context.set('mnn.flatbuffers.json', obj);
  11. }
  12. return null;
  13. }
  14. async open(context) {
  15. mnn.schema = await context.require('./mnn-schema');
  16. mnn.schema = mnn.schema.MNN;
  17. let net = null;
  18. switch (context.type) {
  19. case 'mnn.flatbuffers': {
  20. try {
  21. const reader = context.value;
  22. net = mnn.schema.Net.create(reader);
  23. } catch (error) {
  24. const message = error && error.message ? error.message : error.toString();
  25. throw new mnn.Error(`File format is not mnn.Net (${message.replace(/\.$/, '')}).`);
  26. }
  27. break;
  28. }
  29. case 'mnn.flatbuffers.json': {
  30. try {
  31. const reader = await context.read('flatbuffers.text');
  32. net = mnn.schema.Net.createText(reader);
  33. } catch (error) {
  34. const message = error && error.message ? error.message : error.toString();
  35. throw new mnn.Error(`File format is not mnn.Net (${message.replace(/\.$/, '')}).`);
  36. }
  37. break;
  38. }
  39. default: {
  40. throw new mnn.Error(`Unsupported TensorFlow Lite format '${context.type}'.`);
  41. }
  42. }
  43. const metadata = await context.metadata('mnn-metadata.json');
  44. return new mnn.Model(metadata, net);
  45. }
  46. };
  47. mnn.Model = class {
  48. constructor(metadata, net) {
  49. this.format = 'MNN v2';
  50. const sources = new Map([
  51. [mnn.schema.NetSource.CAFFE, 'Caffe'],
  52. [mnn.schema.NetSource.TENSORFLOW, 'TensorFlow'],
  53. [mnn.schema.NetSource.TFLITE, 'TensorFlow Lite'],
  54. [mnn.schema.NetSource.ONNX, 'ONNX'],
  55. [mnn.schema.NetSource.TORCH, 'Torch']
  56. ]);
  57. if (!sources.has(net.sourceType)) {
  58. throw new mnn.Error(`Unsupported model source '${net.sourceType}'.`);
  59. }
  60. this.source = sources.get(net.sourceType);
  61. this.modules = [new mnn.Graph(metadata, net)];
  62. }
  63. };
  64. mnn.Graph = class {
  65. constructor(metadata, net) {
  66. this.name = '';
  67. this.nodes = [];
  68. this.inputs = [];
  69. this.outputs = [];
  70. for (let i = 0; i < net.tensorName.length; i++) {
  71. if (net.tensorName[i] === '') {
  72. net.tensorName[i] = `\n${i}`;
  73. }
  74. }
  75. const inputs = new Map();
  76. for (const op of net.oplists) {
  77. for (const input of op.inputIndexes) {
  78. inputs.set(input, (inputs.get(input) || 0) + 1);
  79. }
  80. }
  81. const consts = new Map();
  82. const oplists = net.oplists.filter((op) => {
  83. if (op.type === mnn.schema.OpType.Const &&
  84. op.inputIndexes.length === 0 &&
  85. op.outputIndexes.length === 1 &&
  86. op.main instanceof mnn.schema.Blob &&
  87. inputs.get(op.outputIndexes[0]) === 1) {
  88. consts.set(op.outputIndexes[0], op);
  89. return false;
  90. }
  91. return true;
  92. });
  93. const values = new Map();
  94. values.map = (index) => {
  95. if (!values.has(index)) {
  96. const name = net.tensorName[index];
  97. const op = consts.get(index);
  98. if (op) {
  99. const tensor = op ? mnn.Utility.createTensor(op.main, 'Const') : null;
  100. values.set(index, new mnn.Value(name, null, tensor));
  101. } else {
  102. const extraTensorDescribe = net.extraTensorDescribe[index];
  103. const blob = extraTensorDescribe ? extraTensorDescribe.blob : null;
  104. const type = blob && blob.dims && blob.dims.length > 0 ? new mnn.TensorType(blob.dataType, new mnn.TensorShape(blob.dims), blob.dataFormat) : null;
  105. values.set(index, new mnn.Value(name, type, null));
  106. }
  107. }
  108. return values.get(index);
  109. };
  110. for (const op of oplists) {
  111. if (op.type === mnn.schema.OpType.Input) {
  112. const args = Array.from(op.outputIndexes).map((index) => values.map(index));
  113. const argument = new mnn.Argument(op.name, args);
  114. this.inputs.push(argument);
  115. } else {
  116. const node = new mnn.Node(metadata, op, values);
  117. this.nodes.push(node);
  118. }
  119. }
  120. for (let i = 0; i < net.tensorName.length; i++) {
  121. if (!inputs.has(i)) {
  122. const value = values.map(i);
  123. const argument = new mnn.Argument(value.name, [value]);
  124. this.outputs.push(argument);
  125. }
  126. }
  127. }
  128. };
  129. mnn.Node = class {
  130. constructor(metadata, op, values) {
  131. const type = mnn.Utility.enum('OpType', op.type) || `(${op.type})`;
  132. this.type = metadata.type(type) || { name: type };
  133. this.name = op.name || '';
  134. this.attributes = [];
  135. this.inputs = [];
  136. this.outputs = [];
  137. this.chains = [];
  138. if (op.inputIndexes && op.inputIndexes.length > 0) {
  139. const argument = new mnn.Argument('input', Array.from(op.inputIndexes).map((index) => values.map(index)));
  140. this.inputs.push(argument);
  141. }
  142. if (op.outputIndexes && op.outputIndexes.length > 0) {
  143. const argument = new mnn.Argument('output', Array.from(op.outputIndexes).map((index) => values.map(index)));
  144. this.outputs.push(argument);
  145. }
  146. const param = op.main;
  147. if (param) {
  148. const parameters = [param];
  149. if (param instanceof mnn.schema.Blob) {
  150. const tensor = mnn.Utility.createTensor(param, 'Blob');
  151. const value = new mnn.Value('', null, tensor);
  152. const argument = new mnn.Argument('value', [value]);
  153. this.inputs.push(argument);
  154. parameters.splice(0, parameters.length);
  155. } else if (param instanceof mnn.schema.Convolution2D) {
  156. const common = param.common;
  157. const outputCount = common.outputCount;
  158. const inputCount = common.inputCount;
  159. const kernelX = common.kernelX;
  160. const kernelY = common.kernelY;
  161. this._buildTensor('weight', mnn.schema.DataType.DT_FLOAT, [outputCount, inputCount, kernelX, kernelY], param.weight);
  162. this._buildTensor('bias', mnn.schema.DataType.DT_FLOAT, [outputCount], param.bias);
  163. delete param.weight;
  164. delete param.bias;
  165. delete param.quanParameter;
  166. delete param.symmetricQuan;
  167. } else if (param instanceof mnn.schema.InnerProduct) {
  168. const outputCount = param.outputCount;
  169. const inputCount = outputCount > 0 ? param.weightSize / outputCount : 0;
  170. this._buildTensor('weight', mnn.schema.DataType.DT_FLOAT, [outputCount, inputCount], param.weight);
  171. this._buildTensor('bias', mnn.schema.DataType.DT_FLOAT, [outputCount], param.bias);
  172. delete param.weight;
  173. delete param.bias;
  174. delete param.quanParameter;
  175. } else if (param instanceof mnn.schema.Scale) {
  176. const scaleDataCount = param.channels;
  177. this._buildTensor('scale', mnn.schema.DataType.DT_FLOAT, [scaleDataCount], param.scaleData);
  178. this._buildTensor('bias', mnn.schema.DataType.DT_FLOAT, [scaleDataCount], param.biasData);
  179. delete param.scaleData;
  180. delete param.biasData;
  181. } else if (param instanceof mnn.schema.BatchNorm) {
  182. const channels = param.channels;
  183. this._buildTensor('mean', mnn.schema.DataType.DT_FLOAT, [channels], param.meanData);
  184. this._buildTensor('slope', mnn.schema.DataType.DT_FLOAT, [channels], param.slopeData);
  185. this._buildTensor('variance', mnn.schema.DataType.DT_FLOAT, [channels], param.varData);
  186. this._buildTensor('bias', mnn.schema.DataType.DT_FLOAT, [channels], param.biasData);
  187. delete param.slopeData;
  188. delete param.meanData;
  189. delete param.varData;
  190. delete param.biasData;
  191. } else if (param instanceof mnn.schema.PRelu) {
  192. this._buildTensor('slope', mnn.schema.DataType.DT_FLOAT, [param.slopeCount], param.slope);
  193. delete param.slopeCount;
  194. } else if (param instanceof mnn.schema.Normalize) {
  195. this._buildTensor('scale', mnn.schema.DataType.DT_FLOAT, [param.scale.length], param.scale);
  196. delete param.scale;
  197. }
  198. while (parameters.length > 0) {
  199. const parameter = parameters.shift();
  200. const node_type = type;
  201. for (const [key, obj] of Object.entries(parameter)) {
  202. if (Object.keys(mnn.schema).find((key) => mnn.schema[key].prototype && obj instanceof mnn.schema[key])) {
  203. parameters.push(obj);
  204. continue;
  205. }
  206. const schema = metadata.attribute(node_type, key);
  207. let value = ArrayBuffer.isView(obj) ? Array.from(obj) : obj;
  208. let type = null;
  209. if (schema && schema.type) {
  210. type = schema.type;
  211. switch (type) {
  212. case 'DataType':
  213. value = mnn.Utility.dataType(value);
  214. break;
  215. default:
  216. value = mnn.Utility.enum(type, value);
  217. break;
  218. }
  219. }
  220. const attribute = new mnn.Argument(key, value, type);
  221. this.attributes.push(attribute);
  222. }
  223. }
  224. }
  225. }
  226. _buildTensor(name, dataType, dimensions, value) {
  227. const shape = new mnn.TensorShape(dimensions);
  228. const type = new mnn.TensorType(dataType, shape);
  229. const tensor = new mnn.Tensor('Weight', type, value);
  230. const argument = new mnn.Argument(name, [new mnn.Value('', null, tensor)]);
  231. this.inputs.push(argument);
  232. }
  233. };
  234. mnn.Argument = class {
  235. constructor(name, value, type = null) {
  236. this.name = name;
  237. this.value = value;
  238. this.type = type;
  239. }
  240. };
  241. mnn.Value = class {
  242. constructor(name, type, initializer = null) {
  243. this.name = name;
  244. this.type = !type && initializer ? initializer.type : type;
  245. this.initializer = initializer;
  246. }
  247. };
  248. mnn.Tensor = class {
  249. constructor(category, type, data) {
  250. this.category = category;
  251. this.type = type;
  252. switch (type.dataType) {
  253. case 'int32':
  254. case 'float32':
  255. this.encoding = '|';
  256. this.values = data ? data.slice(0) : null;
  257. break;
  258. case 'int8':
  259. case 'uint8':
  260. case 'float16':
  261. case 'bfloat16':
  262. this.encoding = '<';
  263. this.values = data ? data.slice(0) : null;
  264. break;
  265. default:
  266. throw new mnn.Error(`Unsupported data type '${type.dataType}'.`);
  267. }
  268. }
  269. };
  270. mnn.TensorType = class {
  271. constructor(dataType, shape, format) {
  272. this.dataType = mnn.Utility.dataType(dataType);
  273. this.shape = shape;
  274. if (format) {
  275. switch (format) {
  276. case mnn.schema.MNN_DATA_FORMAT.NCHW: this.denotation = 'NCHW'; break;
  277. case mnn.schema.MNN_DATA_FORMAT.NHWC: this.denotation = 'NHWC'; break;
  278. case mnn.schema.MNN_DATA_FORMAT.NC4HW4: this.denotation = 'NC4HW4'; break;
  279. case mnn.schema.MNN_DATA_FORMAT.NHWC4: this.denotation = 'NHWC4'; break;
  280. default: throw new mnn.Error(`Unsupported tensor type format '${format}'.`);
  281. }
  282. }
  283. }
  284. toString() {
  285. return this.dataType + this.shape.toString();
  286. }
  287. };
  288. mnn.TensorShape = class {
  289. constructor(dimensions) {
  290. this.dimensions = Array.from(dimensions);
  291. }
  292. toString() {
  293. if (this.dimensions && this.dimensions.length > 0) {
  294. return `[${this.dimensions.map((dimension) => dimension ? dimension.toString() : '?').join(',')}]`;
  295. }
  296. return '';
  297. }
  298. };
  299. mnn.Utility = class {
  300. static dataType(type) {
  301. switch (type) {
  302. case mnn.schema.DataType.DT_INVALID: return '?';
  303. case mnn.schema.DataType.DT_FLOAT: return 'float32';
  304. case mnn.schema.DataType.DT_DOUBLE: return 'float64';
  305. case mnn.schema.DataType.DT_INT32: return 'int32';
  306. case mnn.schema.DataType.DT_UINT8: return 'uint8';
  307. case mnn.schema.DataType.DT_INT16: return 'int16';
  308. case mnn.schema.DataType.DT_INT8: return 'int8';
  309. case mnn.schema.DataType.DT_STRING: return 'string';
  310. case mnn.schema.DataType.DT_COMPLEX64: return 'complex64';
  311. case mnn.schema.DataType.DT_INT64: return 'int64';
  312. case mnn.schema.DataType.DT_BOOL: return 'boolean';
  313. case mnn.schema.DataType.DT_QINT8: return 'qint8';
  314. case mnn.schema.DataType.DT_QUINT8: return 'quint8';
  315. case mnn.schema.DataType.DT_QINT32: return 'qint32';
  316. case mnn.schema.DataType.DT_BFLOAT16: return 'bfloat16';
  317. case mnn.schema.DataType.DT_QINT16: return 'qint16';
  318. case mnn.schema.DataType.DT_QUINT16: return 'quint16';
  319. case mnn.schema.DataType.DT_UINT16: return 'uint16';
  320. case mnn.schema.DataType.DT_COMPLEX128: return 'complex128';
  321. case mnn.schema.DataType.DT_HALF: return 'float16';
  322. case mnn.schema.DataType.DT_RESOURCE: return 'resource';
  323. case mnn.schema.DataType.DT_VARIANT: return 'variant';
  324. default: throw new mnn.Error(`Unsupported data type '${JSON.stringify(type)}'.`);
  325. }
  326. }
  327. static enum(name, value) {
  328. const type = name && mnn.schema ? mnn.schema[name] : undefined;
  329. if (type) {
  330. mnn.Utility._enumKeyMap = mnn.Utility._enumKeyMap || new Map();
  331. if (!mnn.Utility._enumKeyMap.has(name)) {
  332. const map = new Map();
  333. for (const key of Object.keys(type)) {
  334. map.set(type[key], key);
  335. }
  336. mnn.Utility._enumKeyMap.set(name, map);
  337. }
  338. const map = mnn.Utility._enumKeyMap.get(name);
  339. if (map.has(value)) {
  340. return map.get(value);
  341. }
  342. }
  343. return value.toString();
  344. }
  345. static createTensor(param, category) {
  346. const shape = new mnn.TensorShape(param.dims);
  347. const type = new mnn.TensorType(param.dataType, shape, param.dataFormat);
  348. let data = null;
  349. switch (type.dataType) {
  350. case 'uint8': data = param.uint8s; break;
  351. case 'int8': data = param.int8s; break;
  352. case 'int32': data = param.int32s; break;
  353. case 'int64': data = param.int64s; break;
  354. case 'float16': data = param.uint8s; break;
  355. case 'float32': data = param.float32s; break;
  356. case 'bfloat16': data = param.uint8s; break;
  357. default: throw new mnn.Error(`Unsupported blob data type '${JSON.stringify(type.dataType)}'.`);
  358. }
  359. return new mnn.Tensor(category, type, data);
  360. }
  361. };
  362. mnn.Error = class extends Error {
  363. constructor(message) {
  364. super(message);
  365. this.name = 'Error loading MNN model.';
  366. }
  367. };
  368. export const ModelFactory = mnn.ModelFactory;