mnn.js 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382
  const mnn = {};

  /**
   * Recognizes MNN models stored as FlatBuffers and decodes them into an
   * `mnn.Model`.
   */
  mnn.ModelFactory = class {

      /**
       * Claims the context when a FlatBuffers reader is available for it.
       * @param {object} context - host context; `peek` is read, `type` and
       *   `target` are written when a reader is found.
       */
      match(context) {
          const reader = context.peek('flatbuffers.binary');
          if (reader) {
              context.type = 'mnn.flatbuffers';
              context.target = reader;
          }
      }

      /**
       * Loads the MNN schema, decodes `context.target` as an `MNN.Net` and builds
       * the model with the operator metadata from `mnn-metadata.json`.
       * @param {object} context - host context previously populated by `match`.
       * @returns {Promise<object>} the constructed `mnn.Model`.
       * @throws {mnn.Error} when the buffer cannot be decoded as `mnn.Net`. The
       *   original failure is passed as the `cause` option (retained whenever
       *   `mnn.Error` forwards its constructor options to `Error`).
       */
      async open(context) {
          mnn.schema = await context.require('./mnn-schema');
          mnn.schema = mnn.schema.MNN;
          let net = null;
          try {
              const reader = context.target;
              net = mnn.schema.Net.create(reader);
          } catch (error) {
              // `String(error)` also copes with `null`/`undefined` throw values,
              // where `error.toString()` would itself raise a TypeError.
              const message = error && error.message ? error.message : String(error);
              throw new mnn.Error(`File format is not mnn.Net (${message.replace(/\.$/, '')}).`, { cause: error });
          }
          const metadata = await context.metadata('mnn-metadata.json');
          return new mnn.Model(metadata, net);
      }
  };
  /**
   * Top-level model view: records the framework the net was converted from and
   * wraps the net in a single graph.
   */
  mnn.Model = class {

      /**
       * @param {object} metadata - operator metadata, handed to the graph.
       * @param {object} net - decoded `MNN.Net`.
       * @throws {mnn.Error} when `net.sourceType` is not one of the known sources.
       */
      constructor(metadata, net) {
          this.format = 'MNN v2';
          const kinds = mnn.schema.NetSource;
          const sourceNames = new Map();
          sourceNames.set(kinds.CAFFE, 'Caffe');
          sourceNames.set(kinds.TENSORFLOW, 'TensorFlow');
          sourceNames.set(kinds.TFLITE, 'TensorFlow Lite');
          sourceNames.set(kinds.ONNX, 'ONNX');
          sourceNames.set(kinds.TORCH, 'Torch');
          const { sourceType } = net;
          if (!sourceNames.has(sourceType)) {
              throw new mnn.Error(`Unsupported model source '${sourceType}'.`);
          }
          this.metadata = new Map([['source', sourceNames.get(sourceType)]]);
          this.graphs = [new mnn.Graph(metadata, net)];
      }
  };
  /**
   * Graph view over a decoded `MNN.Net`: `Input` ops become graph inputs,
   * tensors that no op consumes become graph outputs, and every other op
   * becomes an `mnn.Node`.
   */
  mnn.Graph = class {

      /**
       * @param {object} metadata - operator metadata, passed through to `mnn.Node`.
       * @param {object} net - decoded `MNN.Net`; read here but not modified.
       */
      constructor(metadata, net) {
          this.name = '';
          this.nodes = [];
          this.inputs = [];
          this.outputs = [];
          // Unnamed tensors get a synthetic name derived from their index. The
          // names are computed on a copy so the caller's decoded `net` is untouched.
          const names = Array.from(net.tensorName, (name, index) => (name === '' ? `\n${index}` : name));
          // Number of ops consuming each tensor index.
          const consumers = new Map();
          for (const op of net.oplists) {
              for (const input of op.inputIndexes) {
                  consumers.set(input, (consumers.get(input) || 0) + 1);
              }
          }
          // A Const op with no inputs, one Blob output and exactly one consumer is
          // folded into that consumer as an initializer instead of being a node.
          const consts = new Map();
          const oplists = net.oplists.filter((op) => {
              if (op.type === mnn.schema.OpType.Const &&
                  op.inputIndexes.length === 0 &&
                  op.outputIndexes.length === 1 &&
                  op.main instanceof mnn.schema.Blob &&
                  consumers.get(op.outputIndexes[0]) === 1) {
                  consts.set(op.outputIndexes[0], op);
                  return false;
              }
              return true;
          });
          // Each TensorDescribe carries the `index` of the tensor it describes, so
          // key by that field rather than by list position (the two coincide only
          // when the list covers every tensor in order). Entries without a numeric
          // `index` fall back to their position.
          const describes = new Map();
          const extraTensorDescribe = net.extraTensorDescribe || [];
          for (let position = 0; position < extraTensorDescribe.length; position++) {
              const describe = extraTensorDescribe[position];
              if (describe) {
                  describes.set(Number.isInteger(describe.index) ? describe.index : position, describe);
              }
          }
          // Lazily creates one shared mnn.Value per tensor index. `mnn.Node`
          // receives this map and resolves its indexes through `values.map`.
          const values = new Map();
          values.map = (index) => {
              if (!values.has(index)) {
                  const name = names[index];
                  const op = consts.get(index);
                  if (op) {
                      const tensor = mnn.Utility.createTensor(op.main, 'Const');
                      values.set(index, new mnn.Value(name, null, tensor));
                  } else {
                      const describe = describes.get(index);
                      const blob = describe ? describe.blob : null;
                      const type = blob && blob.dims && blob.dims.length > 0 ?
                          new mnn.TensorType(blob.dataType, new mnn.TensorShape(blob.dims), blob.dataFormat) :
                          null;
                      values.set(index, new mnn.Value(name, type, null));
                  }
              }
              return values.get(index);
          };
          for (const op of oplists) {
              if (op.type === mnn.schema.OpType.Input) {
                  const args = Array.from(op.outputIndexes).map((index) => values.map(index));
                  this.inputs.push(new mnn.Argument(op.name, args));
              } else {
                  this.nodes.push(new mnn.Node(metadata, op, net, values));
              }
          }
          // Tensors that no op reads are reported as graph outputs.
          for (let i = 0; i < names.length; i++) {
              if (!consumers.has(i)) {
                  const value = values.map(i);
                  this.outputs.push(new mnn.Argument(value.name, [value]));
              }
          }
      }
  };
  /**
   * View of a single MNN op: its input/output tensors, any weight data carried
   * in `op.main` (exposed as extra tensor inputs), and the remaining `op.main`
   * fields as attributes.
   */
  mnn.Node = class {

      /**
       * @param {object} metadata - provides `type(name)` and `attribute(type, name)`.
       * @param {object} op - decoded MNN op. Data fields of `op.main` that are
       *   turned into tensors are deleted from it.
       * @param {object} net - the enclosing decoded net (not read here).
       * @param {Map} values - tensor lookup exposing `values.map(index)`.
       */
      constructor(metadata, op, net, values) {
          const type = mnn.Utility.enum('OpType', op.type) || `(${op.type})`;
          this.type = metadata.type(type) || { name: type };
          this.name = op.name || '';
          this.attributes = [];
          this.inputs = [];
          this.outputs = [];
          this.chains = [];
          if (op.inputIndexes && op.inputIndexes.length > 0) {
              const argument = new mnn.Argument('input', Array.from(op.inputIndexes).map((index) => values.map(index)));
              this.inputs.push(argument);
          }
          if (op.outputIndexes && op.outputIndexes.length > 0) {
              const argument = new mnn.Argument('output', Array.from(op.outputIndexes).map((index) => values.map(index)));
              this.outputs.push(argument);
          }
          const param = op.main;
          if (param) {
              const parameters = [param];
              if (param instanceof mnn.schema.Blob) {
                  // A bare blob is the op's value itself: show it as a tensor input
                  // and expose no attributes.
                  const tensor = mnn.Utility.createTensor(param, 'Blob');
                  const value = new mnn.Value('', null, tensor);
                  this.inputs.push(new mnn.Argument('value', [value]));
                  parameters.length = 0;
              } else if (param instanceof mnn.schema.Convolution2D) {
                  const common = param.common;
                  const outputCount = common.outputCount;
                  const inputCount = common.inputCount;
                  const kernelX = common.kernelX;
                  const kernelY = common.kernelY;
                  this._buildTensor('weight', mnn.schema.DataType.DT_FLOAT, [outputCount, inputCount, kernelX, kernelY], param.weight);
                  this._buildTensor('bias', mnn.schema.DataType.DT_FLOAT, [outputCount], param.bias);
                  delete param.weight;
                  delete param.bias;
                  delete param.quanParameter;
                  delete param.symmetricQuan;
              } else if (param instanceof mnn.schema.InnerProduct) {
                  const outputCount = param.outputCount;
                  const inputCount = param.weightSize / outputCount;
                  this._buildTensor('weight', mnn.schema.DataType.DT_FLOAT, [outputCount, inputCount], param.weight);
                  this._buildTensor('bias', mnn.schema.DataType.DT_FLOAT, [outputCount], param.bias);
                  delete param.weight;
                  delete param.bias;
                  delete param.quanParameter;
              } else if (param instanceof mnn.schema.Scale) {
                  const scaleDataCount = param.channels;
                  this._buildTensor('scale', mnn.schema.DataType.DT_FLOAT, [scaleDataCount], param.scaleData);
                  this._buildTensor('bias', mnn.schema.DataType.DT_FLOAT, [scaleDataCount], param.biasData);
                  delete param.scaleData;
                  delete param.biasData;
              } else if (param instanceof mnn.schema.BatchNorm) {
                  const channels = param.channels;
                  this._buildTensor('mean', mnn.schema.DataType.DT_FLOAT, [channels], param.meanData);
                  this._buildTensor('slope', mnn.schema.DataType.DT_FLOAT, [channels], param.slopeData);
                  this._buildTensor('variance', mnn.schema.DataType.DT_FLOAT, [channels], param.varData);
                  this._buildTensor('bias', mnn.schema.DataType.DT_FLOAT, [channels], param.biasData);
                  delete param.slopeData;
                  delete param.meanData;
                  delete param.varData;
                  delete param.biasData;
              } else if (param instanceof mnn.schema.PRelu) {
                  this._buildTensor('slope', mnn.schema.DataType.DT_FLOAT, [param.slopeCount], param.slope);
                  // As in the other branches, drop the raw data once it is exposed
                  // as a tensor so it is not listed a second time as an attribute.
                  delete param.slopeCount;
                  delete param.slope;
              } else if (param instanceof mnn.schema.Normalize) {
                  this._buildTensor('scale', mnn.schema.DataType.DT_FLOAT, [param.scale.length], param.scale);
                  delete param.scale;
              }
              // Schema entries that have a `prototype` (i.e. classes). Field values
              // that are instances of one are expanded into their own fields instead
              // of becoming a single attribute. Computed once, not per field.
              const schemaTypes = Object.values(mnn.schema).filter((schemaType) => schemaType.prototype);
              while (parameters.length > 0) {
                  const parameter = parameters.shift();
                  for (const [key, value] of Object.entries(parameter)) {
                      if (schemaTypes.some((schemaType) => value instanceof schemaType)) {
                          parameters.push(value);
                          continue;
                      }
                      const attribute = new mnn.Attribute(metadata.attribute(type, key), key, value);
                      this.attributes.push(attribute);
                  }
              }
          }
      }

      /**
       * Appends a named input holding a 'Weight' tensor built from raw data.
       * @param {string} name - argument name, e.g. 'weight' or 'bias'.
       * @param {number} dataType - an `mnn.schema.DataType` value.
       * @param {number[]} dimensions - tensor shape.
       * @param {*} value - tensor data, or null when absent.
       */
      _buildTensor(name, dataType, dimensions, value) {
          const shape = new mnn.TensorShape(dimensions);
          const type = new mnn.TensorType(dataType, shape);
          const tensor = new mnn.Tensor('Weight', type, value);
          const argument = new mnn.Argument(name, [new mnn.Value('', null, tensor)]);
          this.inputs.push(argument);
      }
  };
  /**
   * Named op attribute. Typed-array values are copied into plain arrays. When
   * the metadata declares a type, the value is translated for display:
   * 'DataType' through `mnn.Utility.dataType`, any other type name through
   * `mnn.Utility.enum`.
   */
  mnn.Attribute = class {

      /**
       * @param {object|null} metadata - attribute metadata; only `type` is read.
       * @param {string} name - attribute name.
       * @param {*} value - raw attribute value.
       * @param {*} visible - coerced to a boolean.
       */
      constructor(metadata, name, value, visible) {
          this.type = null;
          let plain = value;
          if (ArrayBuffer.isView(value)) {
              plain = Array.from(value);
          }
          this.value = plain;
          this.name = name;
          this.visible = Boolean(visible);
          const declared = metadata ? metadata.type : undefined;
          if (!declared) {
              return;
          }
          this.type = declared;
          const translate = declared === 'DataType' ?
              (raw) => mnn.Utility.dataType(raw) :
              (raw) => mnn.Utility.enum(declared, raw);
          this.value = translate(this.value);
      }
  };
  /** Named group of values attached to a graph or node input/output slot. */
  mnn.Argument = class {

      /**
       * @param {string} name - slot name.
       * @param {Array} value - the values bound to the slot.
       */
      constructor(name, value) {
          Object.assign(this, { name, value });
      }
  };
  /**
   * A tensor value in the graph. Its type is the explicit `type` when given;
   * otherwise it is taken from the initializer, if any; otherwise null.
   */
  mnn.Value = class {

      /**
       * @param {string} name - value name.
       * @param {object|null} type - explicit tensor type.
       * @param {object|null} initializer - constant tensor backing the value.
       */
      constructor(name, type, initializer) {
          this.name = name;
          if (type) {
              this.type = type;
          } else if (initializer) {
              this.type = initializer.type;
          } else {
              this.type = null;
          }
          this.initializer = initializer ? initializer : null;
      }
  };
  /**
   * Tensor data together with an `encoding` marker for the viewer: '|' for the
   * element-array types and '<' for 'uint8'/'float16'. NOTE(review): presumably
   * '|' means element-wise values and '<' little-endian raw bytes; confirm
   * against the viewer.
   */
  mnn.Tensor = class {

      /**
       * @param {string} category - e.g. 'Const', 'Blob' or 'Weight'.
       * @param {object} type - tensor type; its `dataType` string selects the encoding.
       * @param {*} data - array-like data (copied with `slice`) or null.
       * @throws {mnn.Error} for data types this viewer cannot present.
       */
      constructor(category, type, data) {
          this.category = category;
          this.type = type;
          switch (type.dataType) {
              // 'int8' and 'int64' are produced by mnn.Utility.createTensor from the
              // element arrays `int8s`/`int64s`, just as 'int32'/'float32' come from
              // `int32s`/`float32s`; previously they were rejected here.
              case 'int8':
              case 'int32':
              case 'int64':
              case 'float32':
                  this.encoding = '|';
                  break;
              case 'uint8':
              case 'float16':
                  this.encoding = '<';
                  break;
              default:
                  throw new mnn.Error(`Unsupported data type '${type.dataType}'.`);
          }
          this.values = data ? data.slice(0) : null;
      }
  };
  /**
   * Tensor element type plus shape, with an optional layout `denotation`
   * derived from an `mnn.schema.MNN_DATA_FORMAT` value.
   */
  mnn.TensorType = class {

      /**
       * @param {number} dataType - an `mnn.schema.DataType` value.
       * @param {object} shape - tensor shape providing `toString()`.
       * @param {number} [format] - an `MNN_DATA_FORMAT` value; falsy means none.
       * @throws {mnn.Error} for a truthy format outside the known layouts.
       */
      constructor(dataType, shape, format) {
          this.dataType = mnn.Utility.dataType(dataType);
          this.shape = shape;
          // NOTE(review): if MNN_DATA_FORMAT.NCHW is 0, this guard skips it and the
          // NCHW entry below is unreachable. Confirm whether that is intended.
          if (!format) {
              return;
          }
          const layouts = mnn.schema.MNN_DATA_FORMAT;
          const known = [
              [layouts.NCHW, 'NCHW'],
              [layouts.NHWC, 'NHWC'],
              [layouts.NC4HW4, 'NC4HW4'],
              [layouts.NHWC4, 'NHWC4']
          ];
          const match = known.find(([code]) => code === format);
          if (!match) {
              throw new mnn.Error(`Unsupported tensor type format '${format}'.`);
          }
          this.denotation = match[1];
      }

      /** @returns {string} element type followed by the shape, e.g. 'float32[1,3]'. */
      toString() {
          return `${this.dataType}${this.shape.toString()}`;
      }
  };
  /** Tensor dimensions, rendered as '[d0,d1,...]' with falsy entries shown as '?'. */
  mnn.TensorShape = class {

      /** @param {ArrayLike<number>} dimensions - copied into a plain array. */
      constructor(dimensions) {
          this.dimensions = Array.from(dimensions);
      }

      /** @returns {string} '' when there are no dimensions, otherwise '[...]'. */
      toString() {
          const dims = this.dimensions;
          const hasDims = Boolean(dims) && dims.length > 0;
          if (!hasDims) {
              return '';
          }
          const parts = [];
          for (const dimension of dims) {
              parts.push(dimension ? dimension.toString() : '?');
          }
          return `[${parts.join(',')}]`;
      }
  };
  /** Helpers translating MNN schema values into viewer-friendly forms. */
  mnn.Utility = class {

      /**
       * Maps an `mnn.schema.DataType` value to a type name such as 'float32'.
       * @param {number} type - DataType value.
       * @returns {string}
       * @throws {mnn.Error} for values not listed below.
       */
      static dataType(type) {
          const DataType = mnn.schema.DataType;
          switch (type) {
              case DataType.DT_INVALID: return '?';
              case DataType.DT_FLOAT: return 'float32';
              case DataType.DT_DOUBLE: return 'float64';
              case DataType.DT_INT32: return 'int32';
              case DataType.DT_UINT8: return 'uint8';
              case DataType.DT_INT16: return 'int16';
              case DataType.DT_INT8: return 'int8';
              case DataType.DT_STRING: return 'string';
              case DataType.DT_COMPLEX64: return 'complex64';
              case DataType.DT_INT64: return 'int64';
              case DataType.DT_BOOL: return 'boolean';
              case DataType.DT_QINT8: return 'qint8';
              case DataType.DT_QUINT8: return 'quint8';
              case DataType.DT_QINT32: return 'qint32';
              case DataType.DT_BFLOAT16: return 'bfloat16';
              case DataType.DT_QINT16: return 'qint16';
              case DataType.DT_QUINT16: return 'quint16';
              case DataType.DT_UINT16: return 'uint16';
              case DataType.DT_COMPLEX128: return 'complex128';
              case DataType.DT_HALF: return 'float16';
              case DataType.DT_RESOURCE: return 'resource';
              case DataType.DT_VARIANT: return 'variant';
              default: throw new mnn.Error(`Unsupported data type '${JSON.stringify(type)}'.`);
          }
      }

      /**
       * Looks up the key name of `value` in the schema object `mnn.schema[name]`.
       * Falls back to `value.toString()` when the schema object or the value is
       * unknown. Reverse lookups are cached per name in `mnn.Utility._enumKeyMap`.
       * @param {string} name - schema object name, e.g. 'OpType'.
       * @param {*} value - value to translate.
       * @returns {string}
       */
      static enum(name, value) {
          const enumType = name && mnn.schema ? mnn.schema[name] : undefined;
          if (!enumType) {
              return value.toString();
          }
          if (!mnn.Utility._enumKeyMap) {
              mnn.Utility._enumKeyMap = new Map();
          }
          const cache = mnn.Utility._enumKeyMap;
          let lookup = cache.get(name);
          if (!lookup) {
              // When several keys share a value, the last one enumerated wins.
              lookup = new Map(Object.keys(enumType).map((key) => [enumType[key], key]));
              cache.set(name, lookup);
          }
          return lookup.has(value) ? lookup.get(value) : value.toString();
      }

      /**
       * Builds an `mnn.Tensor` from a schema Blob, picking the data array that
       * matches the blob's element type. Half-precision data is read from the
       * `uint8s` byte array.
       * @param {object} param - decoded Blob.
       * @param {string} category - tensor category, e.g. 'Const' or 'Blob'.
       * @returns {object} the new `mnn.Tensor`.
       * @throws {mnn.Error} for element types without a matching data array.
       */
      static createTensor(param, category) {
          const shape = new mnn.TensorShape(param.dims);
          const type = new mnn.TensorType(param.dataType, shape, param.dataFormat);
          const fields = new Map([
              ['uint8', 'uint8s'],
              ['int8', 'int8s'],
              ['int32', 'int32s'],
              ['int64', 'int64s'],
              ['float16', 'uint8s'],
              ['float32', 'float32s']
          ]);
          const field = fields.get(type.dataType);
          if (field === undefined) {
              throw new mnn.Error(`Unsupported blob data type '${JSON.stringify(type.dataType)}'.`);
          }
          return new mnn.Tensor(category, type, param[field]);
      }
  };
  /** Error raised when an MNN model cannot be loaded. */
  mnn.Error = class extends Error {

      /**
       * @param {string} message - description of the failure.
       * @param {{cause?: *}} [options] - forwarded to `Error`, so callers can
       *   attach the underlying failure as `cause`. Optional; omitting it keeps
       *   the previous behavior.
       */
      constructor(message, options) {
          super(message, options);
          this.name = 'Error loading MNN model.';
      }
  };

  export const ModelFactory = mnn.ModelFactory;