mnn.js 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598
  1. var mnn = mnn || {};
  2. var flatbuffers = flatbuffers || require('./flatbuffers');
  3. mnn.ModelFactory = class {
  4. match(context) {
  5. const stream = context.stream;
  6. if (stream.length >= 4) {
  7. const extension = context.identifier.split('.').pop().toLowerCase();
  8. if (extension == 'mnn') {
  9. const buffer = stream.peek(4);
  10. const reader = flatbuffers.BinaryReader.open(buffer);
  11. if (reader.root === 0x00000018 || reader.root === 0x0000001C || reader.root === 0x00000020) {
  12. return 'mnn.flatbuffers';
  13. }
  14. }
  15. }
  16. return undefined;
  17. }
  18. open(context) {
  19. return context.require('./mnn-schema').then((/* schema */) => {
  20. let net = null;
  21. try {
  22. mnn.schema = flatbuffers.get('mnn').MNN;
  23. const stream = context.stream;
  24. const reader = flatbuffers.BinaryReader.open(stream);
  25. net = mnn.schema.Net.create(reader);
  26. }
  27. catch (error) {
  28. const message = error && error.message ? error.message : error.toString();
  29. throw new mnn.Error('File format is not mnn.Net (' + message.replace(/\.$/, '') + ').');
  30. }
  31. return context.metadata('mnn-metadata.json').then((metadata) => {
  32. return new mnn.Model(metadata, net);
  33. });
  34. });
  35. }
  36. };
  37. mnn.Model = class {
  38. constructor(metadata, net) {
  39. const sources = new Map([
  40. [ mnn.schema.NetSource.CAFFE, 'Caffe' ],
  41. [ mnn.schema.NetSource.TENSORFLOW, 'TensorFlow' ],
  42. [ mnn.schema.NetSource.TFLITE, 'TensorFlow Lite' ],
  43. [ mnn.schema.NetSource.ONNX, 'ONNX' ],
  44. [ mnn.schema.NetSource.TORCH, 'Torch' ]
  45. ]);
  46. if (!sources.has(net.sourceType)) {
  47. throw new mnn.Error("Unsupported model source '" + net.sourceType + "'.");
  48. }
  49. this._metadata = [
  50. { name: 'source', value: sources.get(net.sourceType) }
  51. ];
  52. this._graphs = [ new mnn.Graph(metadata, net) ];
  53. }
  54. get format() {
  55. return 'MNN v2';
  56. }
  57. get metadata() {
  58. return this._metadata;
  59. }
  60. get graphs() {
  61. return this._graphs;
  62. }
  63. };
  64. mnn.Graph = class {
  65. constructor(metadata, net) {
  66. this._nodes = [];
  67. this._inputs = [];
  68. this._outputs = [];
  69. for (let i = 0; i < net.tensorName.length; i++) {
  70. if (net.tensorName[i] === '') {
  71. net.tensorName[i] = '\n' + i.toString();
  72. }
  73. }
  74. const inputs = new Map();
  75. for (const op of net.oplists) {
  76. for (const input of op.inputIndexes) {
  77. inputs.set(input, (inputs.get(input) || 0) + 1);
  78. }
  79. }
  80. const consts = new Map();
  81. const oplists = net.oplists.filter((op) => {
  82. if (op.type === mnn.schema.OpType.Const &&
  83. op.inputIndexes.length === 0 &&
  84. op.outputIndexes.length === 1 &&
  85. op.main instanceof mnn.schema.Blob &&
  86. inputs.get(op.outputIndexes[0]) === 1) {
  87. consts.set(op.outputIndexes[0], op);
  88. return false;
  89. }
  90. return true;
  91. });
  92. const args = new Map();
  93. const arg = (index) => {
  94. if (!args.has(index)) {
  95. const name = net.tensorName[index];
  96. const op = consts.get(index);
  97. if (op) {
  98. const tensor = op ? mnn.Utility.createTensor(op.main, 'Const') : null;
  99. const argument = new mnn.Argument(name, null, tensor);
  100. args.set(index, argument);
  101. }
  102. else {
  103. const extraTensorDescribe = net.extraTensorDescribe[index];
  104. const blob = extraTensorDescribe ? extraTensorDescribe.blob : null;
  105. const type = blob && blob.dims && blob.dims.length > 0 ? new mnn.TensorType(blob.dataType, new mnn.TensorShape(blob.dims), blob.dataFormat) : null;
  106. const argument = new mnn.Argument(name, type, null);
  107. args.set(index, argument);
  108. }
  109. }
  110. return args.get(index);
  111. };
  112. for (const op of oplists) {
  113. if (op.type === mnn.schema.OpType.Input) {
  114. const args = Array.from(op.outputIndexes).map((index) => arg(index));
  115. this._inputs.push(new mnn.Parameter(op.name, true, args));
  116. }
  117. else {
  118. this._nodes.push(new mnn.Node(metadata, op, net, arg));
  119. }
  120. }
  121. for (let i = 0; i < net.tensorName.length; i++) {
  122. if (!inputs.has(i)) {
  123. const argument = arg(i);
  124. const parameter = new mnn.Parameter(argument.name, true, [ argument ]);
  125. this._outputs.push(parameter);
  126. }
  127. }
  128. }
  129. get name() {
  130. return '';
  131. }
  132. get nodes() {
  133. return this._nodes;
  134. }
  135. get outputs() {
  136. return this._outputs;
  137. }
  138. get inputs() {
  139. return this._inputs;
  140. }
  141. };
  142. mnn.Node = class {
  143. constructor(metadata, op, net, arg) {
  144. const type = mnn.Utility.enum('OpType', op.type) || '(' + op.type.toString() + ')';
  145. this._type = metadata.type(type) || { name: type };
  146. this._name = op.name || '';
  147. this._attributes = [];
  148. this._inputs = [];
  149. this._outputs = [];
  150. this._chains = [];
  151. if (op.inputIndexes && op.inputIndexes.length > 0) {
  152. this._inputs.push(new mnn.Parameter('input', true, Array.from(op.inputIndexes).map((index) => arg(index))));
  153. }
  154. if (op.outputIndexes && op.outputIndexes.length > 0) {
  155. this._outputs.push(new mnn.Parameter('output', true, Array.from(op.outputIndexes).map((index) => arg(index))));
  156. }
  157. const param = op.main;
  158. if (param) {
  159. const parameters = [ param ];
  160. if (param instanceof mnn.schema.Blob) {
  161. const tensor = mnn.Utility.createTensor(param, 'Blob');
  162. const argument = new mnn.Argument('', null, tensor);
  163. const parameter = new mnn.Parameter('value', true, [ argument ]);
  164. this._inputs.push(parameter);
  165. parameters.splice(0, parameters.length);
  166. }
  167. else if (param instanceof mnn.schema.Convolution2D) {
  168. const common = param.common;
  169. const outputCount = common.outputCount;
  170. const inputCount = common.inputCount;
  171. const kernelX = common.kernelX;
  172. const kernelY = common.kernelY;
  173. this._buildTensor('weight', mnn.schema.DataType.DT_FLOAT, [ outputCount, inputCount, kernelX, kernelY ], param.weight);
  174. this._buildTensor('bias', mnn.schema.DataType.DT_FLOAT, [ outputCount ], param.bias);
  175. delete param.weight;
  176. delete param.bias;
  177. delete param.quanParameter;
  178. delete param.symmetricQuan;
  179. }
  180. else if (param instanceof mnn.schema.InnerProduct) {
  181. const outputCount = param.outputCount;
  182. const inputCount = param.weightSize / outputCount;
  183. this._buildTensor('weight', mnn.schema.DataType.DT_FLOAT, [ outputCount, inputCount ], param.weight);
  184. this._buildTensor('bias', mnn.schema.DataType.DT_FLOAT, [ outputCount ], param.bias);
  185. delete param.weight;
  186. delete param.bias;
  187. delete param.quanParameter;
  188. }
  189. else if (param instanceof mnn.schema.Scale) {
  190. const scaleDataCount = param.channels;
  191. this._buildTensor('scale', mnn.schema.DataType.DT_FLOAT, [ scaleDataCount ], param.scaleData);
  192. this._buildTensor('bias', mnn.schema.DataType.DT_FLOAT, [ scaleDataCount ], param.biasData);
  193. delete param.scaleData;
  194. delete param.biasData;
  195. }
  196. else if (param instanceof mnn.schema.BatchNorm) {
  197. const channels = param.channels;
  198. this._buildTensor('mean', mnn.schema.DataType.DT_FLOAT, [ channels ], param.meanData);
  199. this._buildTensor('slope', mnn.schema.DataType.DT_FLOAT, [ channels ], param.slopeData);
  200. this._buildTensor('variance', mnn.schema.DataType.DT_FLOAT, [ channels ], param.varData);
  201. this._buildTensor('bias', mnn.schema.DataType.DT_FLOAT, [ channels ], param.biasData);
  202. delete param.slopeData;
  203. delete param.meanData;
  204. delete param.varData;
  205. delete param.biasData;
  206. }
  207. else if (param instanceof mnn.schema.PRelu) {
  208. this._buildTensor('slope', mnn.schema.DataType.DT_FLOAT, [ param.slopeCount ], param.slope);
  209. delete param.slopeCount;
  210. }
  211. else if (param instanceof mnn.schema.Normalize) {
  212. this._buildTensor('scale', mnn.schema.DataType.DT_FLOAT, [ param.scale.length ], param.scale);
  213. delete param.scale;
  214. }
  215. while (parameters.length > 0) {
  216. const parameter = parameters.shift();
  217. for (const key of Object.keys(parameter)) {
  218. if (Object.prototype.hasOwnProperty.call(parameter, key)) {
  219. const value = parameter[key];
  220. if (Object.keys(mnn.schema).find((key) => mnn.schema[key].prototype && value instanceof mnn.schema[key])) {
  221. parameters.push(value);
  222. continue;
  223. }
  224. const schema = metadata.attribute(this.type, key);
  225. this._attributes.push(new mnn.Attribute(schema, key, value));
  226. }
  227. }
  228. }
  229. }
  230. }
  231. _buildTensor(name, dataType, dimensions, value) {
  232. const shape = new mnn.TensorShape(dimensions);
  233. const type = new mnn.TensorType(dataType, shape);
  234. const tensor = new mnn.Tensor('Weight', type, value);
  235. const argument = new mnn.Argument('', null, tensor);
  236. const parameter = new mnn.Parameter(name, true, [ argument ]);
  237. this._inputs.push(parameter);
  238. }
  239. get type() {
  240. return this._type;
  241. }
  242. get name() {
  243. return this._name;
  244. }
  245. get inputs() {
  246. return this._inputs;
  247. }
  248. get outputs() {
  249. return this._outputs;
  250. }
  251. get chain() {
  252. return this._chains;
  253. }
  254. get attributes() {
  255. return this._attributes;
  256. }
  257. };
  258. mnn.Attribute = class {
  259. constructor(schema, name, value, visible) {
  260. this._type = null;
  261. this._value = ArrayBuffer.isView(value) ? Array.from(value) : value;
  262. this._name = name;
  263. this._visible = visible ? true : false;
  264. if (schema) {
  265. if (schema.type) {
  266. this._type = schema.type;
  267. switch (this._type) {
  268. case 'DataType':
  269. this._value = mnn.Utility.dataType(this._value);
  270. break;
  271. default:
  272. this._value = mnn.Utility.enum(this._type, this._value);
  273. break;
  274. }
  275. }
  276. }
  277. }
  278. get name() {
  279. return this._name;
  280. }
  281. get type() {
  282. return this._type;
  283. }
  284. get value() {
  285. return this._value;
  286. }
  287. get visible() {
  288. return this._visible == false ? false : true;
  289. }
  290. };
  291. mnn.Parameter = class {
  292. constructor(name, visible, args) {
  293. this._name = name;
  294. this._visible = visible;
  295. this._arguments = args;
  296. }
  297. get name() {
  298. return this._name;
  299. }
  300. get visible() {
  301. return this._visible;
  302. }
  303. get arguments() {
  304. return this._arguments;
  305. }
  306. };
  307. mnn.Argument = class {
  308. constructor(name, type, initializer) {
  309. this._name = name;
  310. this._type = type || null;
  311. this._initializer = initializer || null;
  312. }
  313. get name() {
  314. return this._name;
  315. }
  316. get type() {
  317. if (this._initializer) {
  318. return this._initializer.type;
  319. }
  320. return this._type;
  321. }
  322. get initializer() {
  323. return this._initializer;
  324. }
  325. };
  326. mnn.Tensor = class {
  327. constructor(kind, type, data) {
  328. this._kind = kind;
  329. this._type = type;
  330. this._data = data ? data.slice(0) : null;
  331. }
  332. get kind() {
  333. return this._kind;
  334. }
  335. get type() {
  336. return this._type;
  337. }
  338. get state() {
  339. return this._context().state;
  340. }
  341. get value() {
  342. const context = this._context();
  343. if (context.state) {
  344. return null;
  345. }
  346. context.limit = Number.MAX_SAFE_INTEGER;
  347. return this._decode(context, 0);
  348. }
  349. toString() {
  350. const context = this._context();
  351. if (context.state) {
  352. return '';
  353. }
  354. context.limit = 10000;
  355. const value = this._decode(context, 0);
  356. return JSON.stringify(value, null, 4);
  357. }
  358. _context() {
  359. const context = {};
  360. context.state = null;
  361. if (!this._data || this._data.length === 0) {
  362. context.state = 'Tensor data is empty.';
  363. return context;
  364. }
  365. context.index = 0;
  366. context.count = 0;
  367. context.dataType = this._type.dataType;
  368. context.dimensions = this._type.shape.dimensions;
  369. switch (context.dataType) {
  370. case 'float16':
  371. context.view = new DataView(this._data.buffer, this._data.byteOffset, this._data.byteLength);
  372. break;
  373. default:
  374. context.data = this._data;
  375. break;
  376. }
  377. return context;
  378. }
  379. _decode(context, dimension) {
  380. let shape = context.dimensions;
  381. if (shape.length == 0) {
  382. shape = [ 1 ];
  383. }
  384. const results = [];
  385. const size = shape[dimension];
  386. if (dimension == shape.length - 1) {
  387. for (let i = 0; i < size; i++) {
  388. if (context.count > context.limit) {
  389. results.push('...');
  390. return results;
  391. }
  392. switch (context.dataType) {
  393. case 'float16':
  394. results.push(context.view.getFloat16(context.index, true));
  395. context.index += 2;
  396. break;
  397. default:
  398. results.push(context.data[context.index]);
  399. context.index++;
  400. break;
  401. }
  402. context.count++;
  403. }
  404. }
  405. else {
  406. for (let j = 0; j < size; j++) {
  407. if (context.count > context.limit) {
  408. results.push('...');
  409. return results;
  410. }
  411. results.push(this._decode(context, dimension + 1));
  412. }
  413. }
  414. if (context.dimensions.length == 0) {
  415. return results[0];
  416. }
  417. return results;
  418. }
  419. };
  420. mnn.TensorType = class {
  421. constructor(dataType, shape, format) {
  422. this._dataType = mnn.Utility.dataType(dataType);
  423. this._shape = shape;
  424. if (format) {
  425. switch (format) {
  426. case mnn.schema.MNN_DATA_FORMAT.NCHW: this._denotation = 'NCHW'; break;
  427. case mnn.schema.MNN_DATA_FORMAT.NHWC: this._denotation = 'NHWC'; break;
  428. case mnn.schema.MNN_DATA_FORMAT.NC4HW4: this._denotation = 'NC4HW4'; break;
  429. case mnn.schema.MNN_DATA_FORMAT.NHWC4: this._denotation = 'NHWC4'; break;
  430. default: throw new mnn.Error("Unsupported tensor type format '" + format + "'.");
  431. }
  432. }
  433. }
  434. get dataType() {
  435. return this._dataType;
  436. }
  437. get shape() {
  438. return this._shape;
  439. }
  440. get denotation() {
  441. return this._denotation;
  442. }
  443. toString() {
  444. return this._dataType + this._shape.toString();
  445. }
  446. };
  447. mnn.TensorShape = class {
  448. constructor(dimensions) {
  449. this._dimensions = Array.from(dimensions);
  450. }
  451. get dimensions() {
  452. return this._dimensions;
  453. }
  454. toString() {
  455. if (this._dimensions && this._dimensions.length > 0) {
  456. return '[' + this._dimensions.map((dimension) => dimension ? dimension.toString() : '?').join(',') + ']';
  457. }
  458. return '';
  459. }
  460. };
  461. mnn.Utility = class {
  462. static dataType(type) {
  463. switch (type) {
  464. case mnn.schema.DataType.DT_INVALID: return '?';
  465. case mnn.schema.DataType.DT_FLOAT: return 'float32';
  466. case mnn.schema.DataType.DT_DOUBLE: return 'float64';
  467. case mnn.schema.DataType.DT_INT32: return 'int32';
  468. case mnn.schema.DataType.DT_UINT8: return 'uint8';
  469. case mnn.schema.DataType.DT_INT16: return 'int16';
  470. case mnn.schema.DataType.DT_INT8: return 'int8';
  471. case mnn.schema.DataType.DT_STRING: return 'string';
  472. case mnn.schema.DataType.DT_COMPLEX64: return 'complex64';
  473. case mnn.schema.DataType.DT_INT64: return 'int64';
  474. case mnn.schema.DataType.DT_BOOL: return 'boolean';
  475. case mnn.schema.DataType.DT_QINT8: return 'qint8';
  476. case mnn.schema.DataType.DT_QUINT8: return 'quint8';
  477. case mnn.schema.DataType.DT_QINT32: return 'qint32';
  478. case mnn.schema.DataType.DT_BFLOAT16: return 'bfloat16';
  479. case mnn.schema.DataType.DT_QINT16: return 'qint16';
  480. case mnn.schema.DataType.DT_QUINT16: return 'quint16';
  481. case mnn.schema.DataType.DT_UINT16: return 'uint16';
  482. case mnn.schema.DataType.DT_COMPLEX128: return 'complex128';
  483. case mnn.schema.DataType.DT_HALF: return 'float16';
  484. case mnn.schema.DataType.DT_RESOURCE: return 'resource';
  485. case mnn.schema.DataType.DT_VARIANT: return 'variant';
  486. default: throw new mnn.Error("Unsupported data type '" + JSON.stringify(type) + "'.");
  487. }
  488. }
  489. static enum(name, value) {
  490. const type = name && mnn.schema ? mnn.schema[name] : undefined;
  491. if (type) {
  492. mnn.Utility._enumKeyMap = mnn.Utility._enumKeyMap || new Map();
  493. if (!mnn.Utility._enumKeyMap.has(name)) {
  494. const map = new Map();
  495. for (const key of Object.keys(type)) {
  496. map.set(type[key], key);
  497. }
  498. mnn.Utility._enumKeyMap.set(name, map);
  499. }
  500. const map = mnn.Utility._enumKeyMap.get(name);
  501. if (map.has(value)) {
  502. return map.get(value);
  503. }
  504. }
  505. return value.toString();
  506. }
  507. static createTensor(param, kind) {
  508. const type = new mnn.TensorType(param.dataType, new mnn.TensorShape(param.dims), param.dataFormat);
  509. let data = null;
  510. switch (type.dataType) {
  511. case 'uint8': data = param.uint8s; break;
  512. case 'int8': data = param.int8s; break;
  513. case 'int32': data = param.int32s; break;
  514. case 'int64': data = param.int64s; break;
  515. case 'float16': data = param.uint8s; break;
  516. case 'float32': data = param.float32s; break;
  517. default: throw new mnn.Error("Unsupported blob data type '" + JSON.stringify(type.dataType) + "'.");
  518. }
  519. return new mnn.Tensor(kind, type, data);
  520. }
  521. };
  522. mnn.Error = class extends Error {
  523. constructor(message) {
  524. super(message);
  525. this.name = 'Error loading MNN model.';
  526. }
  527. };
  528. if (typeof module !== 'undefined' && typeof module.exports === 'object') {
  529. module.exports.ModelFactory = mnn.ModelFactory;
  530. }