mnn.js 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392
  1. import * as flatbuffers from './flatbuffers.js';
  2. const mnn = {};
  3. mnn.ModelFactory = class {
  4. match(context) {
  5. const identifier = context.identifier;
  6. const extension = identifier.split('.').pop().toLowerCase();
  7. if (extension == 'mnn') {
  8. const stream = context.stream;
  9. if (stream && stream.length >= 4) {
  10. const buffer = stream.peek(4);
  11. const reader = flatbuffers.BinaryReader.open(buffer);
  12. if (reader.root === 0x00000018 || reader.root === 0x0000001C || reader.root === 0x00000020) {
  13. return 'mnn.flatbuffers';
  14. }
  15. }
  16. }
  17. return null;
  18. }
  19. async open(context) {
  20. await context.require('./mnn-schema');
  21. let net = null;
  22. try {
  23. mnn.schema = flatbuffers.get('mnn').MNN;
  24. const stream = context.stream;
  25. const reader = flatbuffers.BinaryReader.open(stream);
  26. net = mnn.schema.Net.create(reader);
  27. } catch (error) {
  28. const message = error && error.message ? error.message : error.toString();
  29. throw new mnn.Error('File format is not mnn.Net (' + message.replace(/\.$/, '') + ').');
  30. }
  31. const metadata = await context.metadata('mnn-metadata.json');
  32. return new mnn.Model(metadata, net);
  33. }
  34. };
  35. mnn.Model = class {
  36. constructor(metadata, net) {
  37. this.format = 'MNN v2';
  38. const sources = new Map([
  39. [ mnn.schema.NetSource.CAFFE, 'Caffe' ],
  40. [ mnn.schema.NetSource.TENSORFLOW, 'TensorFlow' ],
  41. [ mnn.schema.NetSource.TFLITE, 'TensorFlow Lite' ],
  42. [ mnn.schema.NetSource.ONNX, 'ONNX' ],
  43. [ mnn.schema.NetSource.TORCH, 'Torch' ]
  44. ]);
  45. if (!sources.has(net.sourceType)) {
  46. throw new mnn.Error("Unsupported model source '" + net.sourceType + "'.");
  47. }
  48. this.metadata = new Map();
  49. this.metadata.set('source', sources.get(net.sourceType));
  50. this.graphs = [ new mnn.Graph(metadata, net) ];
  51. }
  52. };
  53. mnn.Graph = class {
  54. constructor(metadata, net) {
  55. this.name = '';
  56. this.nodes = [];
  57. this.inputs = [];
  58. this.outputs = [];
  59. for (let i = 0; i < net.tensorName.length; i++) {
  60. if (net.tensorName[i] === '') {
  61. net.tensorName[i] = '\n' + i.toString();
  62. }
  63. }
  64. const inputs = new Map();
  65. for (const op of net.oplists) {
  66. for (const input of op.inputIndexes) {
  67. inputs.set(input, (inputs.get(input) || 0) + 1);
  68. }
  69. }
  70. const consts = new Map();
  71. const oplists = net.oplists.filter((op) => {
  72. if (op.type === mnn.schema.OpType.Const &&
  73. op.inputIndexes.length === 0 &&
  74. op.outputIndexes.length === 1 &&
  75. op.main instanceof mnn.schema.Blob &&
  76. inputs.get(op.outputIndexes[0]) === 1) {
  77. consts.set(op.outputIndexes[0], op);
  78. return false;
  79. }
  80. return true;
  81. });
  82. const values = new Map();
  83. values.map = (index) => {
  84. if (!values.has(index)) {
  85. const name = net.tensorName[index];
  86. const op = consts.get(index);
  87. if (op) {
  88. const tensor = op ? mnn.Utility.createTensor(op.main, 'Const') : null;
  89. values.set(index, new mnn.Value(name, null, tensor));
  90. } else {
  91. const extraTensorDescribe = net.extraTensorDescribe[index];
  92. const blob = extraTensorDescribe ? extraTensorDescribe.blob : null;
  93. const type = blob && blob.dims && blob.dims.length > 0 ? new mnn.TensorType(blob.dataType, new mnn.TensorShape(blob.dims), blob.dataFormat) : null;
  94. values.set(index, new mnn.Value(name, type, null));
  95. }
  96. }
  97. return values.get(index);
  98. };
  99. for (const op of oplists) {
  100. if (op.type === mnn.schema.OpType.Input) {
  101. const args = Array.from(op.outputIndexes).map((index) => values.map(index));
  102. const argument = new mnn.Argument(op.name, args);
  103. this.inputs.push(argument);
  104. } else {
  105. const node = new mnn.Node(metadata, op, net, values);
  106. this.nodes.push(node);
  107. }
  108. }
  109. for (let i = 0; i < net.tensorName.length; i++) {
  110. if (!inputs.has(i)) {
  111. const value = values.map(i);
  112. const argument = new mnn.Argument(value.name, [ value ]);
  113. this.outputs.push(argument);
  114. }
  115. }
  116. }
  117. };
  118. mnn.Node = class {
  119. constructor(metadata, op, net, values) {
  120. const type = mnn.Utility.enum('OpType', op.type) || '(' + op.type.toString() + ')';
  121. this.type = metadata.type(type) || { name: type };
  122. this.name = op.name || '';
  123. this.attributes = [];
  124. this.inputs = [];
  125. this.outputs = [];
  126. this.chains = [];
  127. if (op.inputIndexes && op.inputIndexes.length > 0) {
  128. const argument = new mnn.Argument('input', Array.from(op.inputIndexes).map((index) => values.map(index)));
  129. this.inputs.push(argument);
  130. }
  131. if (op.outputIndexes && op.outputIndexes.length > 0) {
  132. const argument = new mnn.Argument('output', Array.from(op.outputIndexes).map((index) => values.map(index)));
  133. this.outputs.push(argument);
  134. }
  135. const param = op.main;
  136. if (param) {
  137. const parameters = [ param ];
  138. if (param instanceof mnn.schema.Blob) {
  139. const tensor = mnn.Utility.createTensor(param, 'Blob');
  140. const value = new mnn.Value('', null, tensor);
  141. const argument = new mnn.Argument('value', [ value ]);
  142. this.inputs.push(argument);
  143. parameters.splice(0, parameters.length);
  144. } else if (param instanceof mnn.schema.Convolution2D) {
  145. const common = param.common;
  146. const outputCount = common.outputCount;
  147. const inputCount = common.inputCount;
  148. const kernelX = common.kernelX;
  149. const kernelY = common.kernelY;
  150. this._buildTensor('weight', mnn.schema.DataType.DT_FLOAT, [ outputCount, inputCount, kernelX, kernelY ], param.weight);
  151. this._buildTensor('bias', mnn.schema.DataType.DT_FLOAT, [ outputCount ], param.bias);
  152. delete param.weight;
  153. delete param.bias;
  154. delete param.quanParameter;
  155. delete param.symmetricQuan;
  156. } else if (param instanceof mnn.schema.InnerProduct) {
  157. const outputCount = param.outputCount;
  158. const inputCount = param.weightSize / outputCount;
  159. this._buildTensor('weight', mnn.schema.DataType.DT_FLOAT, [ outputCount, inputCount ], param.weight);
  160. this._buildTensor('bias', mnn.schema.DataType.DT_FLOAT, [ outputCount ], param.bias);
  161. delete param.weight;
  162. delete param.bias;
  163. delete param.quanParameter;
  164. } else if (param instanceof mnn.schema.Scale) {
  165. const scaleDataCount = param.channels;
  166. this._buildTensor('scale', mnn.schema.DataType.DT_FLOAT, [ scaleDataCount ], param.scaleData);
  167. this._buildTensor('bias', mnn.schema.DataType.DT_FLOAT, [ scaleDataCount ], param.biasData);
  168. delete param.scaleData;
  169. delete param.biasData;
  170. } else if (param instanceof mnn.schema.BatchNorm) {
  171. const channels = param.channels;
  172. this._buildTensor('mean', mnn.schema.DataType.DT_FLOAT, [ channels ], param.meanData);
  173. this._buildTensor('slope', mnn.schema.DataType.DT_FLOAT, [ channels ], param.slopeData);
  174. this._buildTensor('variance', mnn.schema.DataType.DT_FLOAT, [ channels ], param.varData);
  175. this._buildTensor('bias', mnn.schema.DataType.DT_FLOAT, [ channels ], param.biasData);
  176. delete param.slopeData;
  177. delete param.meanData;
  178. delete param.varData;
  179. delete param.biasData;
  180. } else if (param instanceof mnn.schema.PRelu) {
  181. this._buildTensor('slope', mnn.schema.DataType.DT_FLOAT, [ param.slopeCount ], param.slope);
  182. delete param.slopeCount;
  183. } else if (param instanceof mnn.schema.Normalize) {
  184. this._buildTensor('scale', mnn.schema.DataType.DT_FLOAT, [ param.scale.length ], param.scale);
  185. delete param.scale;
  186. }
  187. while (parameters.length > 0) {
  188. const parameter = parameters.shift();
  189. for (const [key, value] of Object.entries(parameter)) {
  190. if (Object.keys(mnn.schema).find((key) => mnn.schema[key].prototype && value instanceof mnn.schema[key])) {
  191. parameters.push(value);
  192. continue;
  193. }
  194. const attribute = new mnn.Attribute(metadata.attribute(type, key), key, value);
  195. this.attributes.push(attribute);
  196. }
  197. }
  198. }
  199. }
  200. _buildTensor(name, dataType, dimensions, value) {
  201. const shape = new mnn.TensorShape(dimensions);
  202. const type = new mnn.TensorType(dataType, shape);
  203. const tensor = new mnn.Tensor('Weight', type, value);
  204. const argument = new mnn.Argument(name, [ new mnn.Value('', null, tensor) ]);
  205. this.inputs.push(argument);
  206. }
  207. };
  208. mnn.Attribute = class {
  209. constructor(metadata, name, value, visible) {
  210. this.type = null;
  211. this.value = ArrayBuffer.isView(value) ? Array.from(value) : value;
  212. this.name = name;
  213. this.visible = visible ? true : false;
  214. if (metadata && metadata.type) {
  215. this.type = metadata.type;
  216. switch (this.type) {
  217. case 'DataType':
  218. this.value = mnn.Utility.dataType(this.value);
  219. break;
  220. default:
  221. this.value = mnn.Utility.enum(this.type, this.value);
  222. break;
  223. }
  224. }
  225. }
  226. };
  227. mnn.Argument = class {
  228. constructor(name, value) {
  229. this.name = name;
  230. this.value = value;
  231. }
  232. };
  233. mnn.Value = class {
  234. constructor(name, type, initializer) {
  235. this.name = name;
  236. this.type = type ? type : initializer ? initializer.type : null;
  237. this.initializer = initializer || null;
  238. }
  239. };
  240. mnn.Tensor = class {
  241. constructor(category, type, data) {
  242. this.category = category;
  243. this.type = type;
  244. switch (type.dataType) {
  245. case 'int32':
  246. case 'float32':
  247. this.encoding = '|';
  248. this.values = data ? data.slice(0) : null;
  249. break;
  250. case 'float16':
  251. this.encoding = '<';
  252. this.values = data ? data.slice(0) : null;
  253. break;
  254. default:
  255. throw new mnn.Error("Unsupported data type '" + type.dataType + "'.");
  256. }
  257. }
  258. };
  259. mnn.TensorType = class {
  260. constructor(dataType, shape, format) {
  261. this.dataType = mnn.Utility.dataType(dataType);
  262. this.shape = shape;
  263. if (format) {
  264. switch (format) {
  265. case mnn.schema.MNN_DATA_FORMAT.NCHW: this.denotation = 'NCHW'; break;
  266. case mnn.schema.MNN_DATA_FORMAT.NHWC: this.denotation = 'NHWC'; break;
  267. case mnn.schema.MNN_DATA_FORMAT.NC4HW4: this.denotation = 'NC4HW4'; break;
  268. case mnn.schema.MNN_DATA_FORMAT.NHWC4: this.denotation = 'NHWC4'; break;
  269. default: throw new mnn.Error("Unsupported tensor type format '" + format + "'.");
  270. }
  271. }
  272. }
  273. toString() {
  274. return this.dataType + this.shape.toString();
  275. }
  276. };
  277. mnn.TensorShape = class {
  278. constructor(dimensions) {
  279. this.dimensions = Array.from(dimensions);
  280. }
  281. toString() {
  282. if (this.dimensions && this.dimensions.length > 0) {
  283. return '[' + this.dimensions.map((dimension) => dimension ? dimension.toString() : '?').join(',') + ']';
  284. }
  285. return '';
  286. }
  287. };
  288. mnn.Utility = class {
  289. static dataType(type) {
  290. switch (type) {
  291. case mnn.schema.DataType.DT_INVALID: return '?';
  292. case mnn.schema.DataType.DT_FLOAT: return 'float32';
  293. case mnn.schema.DataType.DT_DOUBLE: return 'float64';
  294. case mnn.schema.DataType.DT_INT32: return 'int32';
  295. case mnn.schema.DataType.DT_UINT8: return 'uint8';
  296. case mnn.schema.DataType.DT_INT16: return 'int16';
  297. case mnn.schema.DataType.DT_INT8: return 'int8';
  298. case mnn.schema.DataType.DT_STRING: return 'string';
  299. case mnn.schema.DataType.DT_COMPLEX64: return 'complex64';
  300. case mnn.schema.DataType.DT_INT64: return 'int64';
  301. case mnn.schema.DataType.DT_BOOL: return 'boolean';
  302. case mnn.schema.DataType.DT_QINT8: return 'qint8';
  303. case mnn.schema.DataType.DT_QUINT8: return 'quint8';
  304. case mnn.schema.DataType.DT_QINT32: return 'qint32';
  305. case mnn.schema.DataType.DT_BFLOAT16: return 'bfloat16';
  306. case mnn.schema.DataType.DT_QINT16: return 'qint16';
  307. case mnn.schema.DataType.DT_QUINT16: return 'quint16';
  308. case mnn.schema.DataType.DT_UINT16: return 'uint16';
  309. case mnn.schema.DataType.DT_COMPLEX128: return 'complex128';
  310. case mnn.schema.DataType.DT_HALF: return 'float16';
  311. case mnn.schema.DataType.DT_RESOURCE: return 'resource';
  312. case mnn.schema.DataType.DT_VARIANT: return 'variant';
  313. default: throw new mnn.Error("Unsupported data type '" + JSON.stringify(type) + "'.");
  314. }
  315. }
  316. static enum(name, value) {
  317. const type = name && mnn.schema ? mnn.schema[name] : undefined;
  318. if (type) {
  319. mnn.Utility._enumKeyMap = mnn.Utility._enumKeyMap || new Map();
  320. if (!mnn.Utility._enumKeyMap.has(name)) {
  321. const map = new Map();
  322. for (const key of Object.keys(type)) {
  323. map.set(type[key], key);
  324. }
  325. mnn.Utility._enumKeyMap.set(name, map);
  326. }
  327. const map = mnn.Utility._enumKeyMap.get(name);
  328. if (map.has(value)) {
  329. return map.get(value);
  330. }
  331. }
  332. return value.toString();
  333. }
  334. static createTensor(param, category) {
  335. const shape = new mnn.TensorShape(param.dims);
  336. const type = new mnn.TensorType(param.dataType, shape, param.dataFormat);
  337. let data = null;
  338. switch (type.dataType) {
  339. case 'uint8': data = param.uint8s; break;
  340. case 'int8': data = param.int8s; break;
  341. case 'int32': data = param.int32s; break;
  342. case 'int64': data = param.int64s; break;
  343. case 'float16': data = param.uint8s; break;
  344. case 'float32': data = param.float32s; break;
  345. default: throw new mnn.Error("Unsupported blob data type '" + JSON.stringify(type.dataType) + "'.");
  346. }
  347. return new mnn.Tensor(category, type, data);
  348. }
  349. };
  350. mnn.Error = class extends Error {
  351. constructor(message) {
  352. super(message);
  353. this.name = 'Error loading MNN model.';
  354. }
  355. };
  356. export const ModelFactory = mnn.ModelFactory;