mnn.js 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643
  // Reuse an `mnn` namespace object if one already holds a value (e.g. declared by
  // another script sharing this scope); otherwise start from an empty object.
  // `var` is deliberate: it permits redeclaring a binding that may already exist.
  var mnn = mnn ? mnn : {};
  // Reuse an existing `flatbuffers` binding when it holds a value; otherwise load the
  // runtime through CommonJS `require`.
  var flatbuffers = flatbuffers ? flatbuffers : require('./flatbuffers');
  mnn.ModelFactory = class {

    /**
     * Decide whether the opened file looks like an MNN model.
     * Requires a `.mnn` extension and at least 4 bytes. The first 4 bytes are handed
     * to the flatbuffers reader, and its `root` value must be one of the accepted constants.
     * @param {object} context - open context exposing `stream` and `identifier`.
     * @returns {string|undefined} 'mnn.flatbuffers' on a match, otherwise undefined.
     */
    match(context) {
      const stream = context.stream;
      if (stream.length >= 4) {
        const extension = context.identifier.split('.').pop().toLowerCase();
        if (extension === 'mnn') {
          const buffer = stream.peek(4);
          const reader = flatbuffers.BinaryReader.open(buffer);
          // NOTE(review): `root` is presumably the flatbuffers root-table offset; these are
          // the only offsets this loader accepts — confirm against the flatbuffers reader.
          if (reader.root === 0x00000018 || reader.root === 0x0000001C || reader.root === 0x00000020) {
            return 'mnn.flatbuffers';
          }
        }
      }
      return undefined;
    }

    /**
     * Load the MNN schema, decode the stream as an `mnn.Net`, then build the model.
     * Also publishes the schema namespace as `mnn.schema` for the other classes.
     * @param {object} context - open context exposing `require`, `stream` and `request`.
     * @returns {Promise<mnn.Model>}
     * @throws {mnn.Error} (as a rejection) when the stream cannot be decoded as `mnn.Net`.
     */
    open(context) {
      return context.require('./mnn-schema').then((/* schema */) => {
        let net = null;
        try {
          mnn.schema = flatbuffers.get('mnn').MNN;
          const stream = context.stream;
          const reader = flatbuffers.BinaryReader.open(stream);
          net = mnn.schema.Net.create(reader);
        }
        catch (error) {
          // String(...) also copes with a thrown null/undefined, where error.toString()
          // would itself throw a TypeError and hide the original failure.
          const message = error && error.message ? error.message : String(error);
          throw new mnn.Error('File format is not mnn.Net (' + message.replace(/\.$/, '') + ').');
        }
        return mnn.Metadata.open(context).then((metadata) => {
          return new mnn.Model(metadata, net);
        });
      });
    }
  };
  mnn.Model = class {

    /**
     * Wrap a decoded `mnn.Net` as a model containing exactly one graph.
     * @param {mnn.Metadata} metadata - operator metadata used to build nodes.
     * @param {object} net - decoded `mnn.Net` table.
     * @throws {mnn.Error} when `net.sourceType` matches none of the NetSource values below.
     */
    constructor(metadata, net) {
      const NetSource = mnn.schema.NetSource;
      // First entry whose value is strictly equal to the source type wins.
      const knownSources = [
        [ NetSource.CAFFE, 'Caffe' ],
        [ NetSource.TENSORFLOW, 'TensorFlow' ],
        [ NetSource.TFLITE, 'TensorFlow Lite' ],
        [ NetSource.ONNX, 'ONNX' ],
        [ NetSource.TORCH, 'Torch' ]
      ];
      const found = knownSources.find((entry) => entry[0] === net.sourceType);
      if (found === undefined) {
        throw new mnn.Error("Unsupported model source '" + net.sourceType + "'.");
      }
      this._source = found[1];
      this._graphs = [ new mnn.Graph(metadata, net) ];
    }

    /** Fixed format label for this loader. */
    get format() {
      return 'MNN v2';
    }

    /** Name of the framework the model was converted from, or ''. */
    get source() {
      return this._source || '';
    }

    /** Array holding the single graph. */
    get graphs() {
      return this._graphs;
    }
  };
  mnn.Graph = class {

    /**
     * Build the graph of an MNN net.
     *
     * - Empty entries in `net.tensorName` are replaced in place with '\n' + index, so
     *   every argument gets a distinct name.
     * - Const ops with no inputs, exactly one output, a `Blob` payload and an output read
     *   by exactly one op are folded away; their data becomes an initializer on the
     *   argument of that output.
     * - Input ops become graph inputs (named after the op); all other ops become nodes.
     * - Every tensor index in `net.tensorName` that no op reads becomes a graph output.
     * @param {mnn.Metadata} metadata
     * @param {object} net - decoded `mnn.Net` table.
     */
    constructor(metadata, net) {
      this._nodes = [];
      this._inputs = [];
      this._outputs = [];
      for (let i = 0; i < net.tensorName.length; i++) {
        if (net.tensorName[i] === '') {
          net.tensorName[i] = '\n' + i.toString();
        }
      }
      // Number of ops reading each tensor index.
      const consumers = new Map();
      for (const op of net.oplists) {
        for (const index of op.inputIndexes) {
          consumers.set(index, (consumers.get(index) || 0) + 1);
        }
      }
      const consts = new Map();
      const oplists = net.oplists.filter((op) => {
        if (op.type === mnn.schema.OpType.Const &&
          op.inputIndexes.length === 0 &&
          op.outputIndexes.length === 1 &&
          op.main instanceof mnn.schema.Blob &&
          consumers.get(op.outputIndexes[0]) === 1) {
          consts.set(op.outputIndexes[0], op);
          return false;
        }
        return true;
      });
      // NOTE(review): assumes entry i of extraTensorDescribe describes tensor i — confirm
      // against the schema. Guarded so a missing vector does not throw.
      const tensorDescribes = net.extraTensorDescribe || [];
      // One shared mnn.Argument per tensor index.
      const argumentCache = new Map();
      const arg = (index) => {
        if (!argumentCache.has(index)) {
          const name = net.tensorName[index];
          const op = consts.get(index);
          let argument = null;
          if (op) {
            const tensor = mnn.Utility.createTensor(op.main, 'Const');
            argument = new mnn.Argument(name, null, tensor);
          }
          else {
            const describe = tensorDescribes[index];
            const blob = describe ? describe.blob : null;
            const type = blob && blob.dims && blob.dims.length > 0 ? new mnn.TensorType(blob.dataType, new mnn.TensorShape(blob.dims), blob.dataFormat) : null;
            argument = new mnn.Argument(name, type, null);
          }
          argumentCache.set(index, argument);
        }
        return argumentCache.get(index);
      };
      for (const op of oplists) {
        if (op.type === mnn.schema.OpType.Input) {
          const inputArguments = Array.from(op.outputIndexes).map((index) => arg(index));
          this._inputs.push(new mnn.Parameter(op.name, true, inputArguments));
        }
        else {
          this._nodes.push(new mnn.Node(metadata, op, net, arg));
        }
      }
      for (let i = 0; i < net.tensorName.length; i++) {
        if (!consumers.has(i)) {
          const argument = arg(i);
          this._outputs.push(new mnn.Parameter(argument.name, true, [ argument ]));
        }
      }
    }

    get name() {
      return '';
    }

    get nodes() {
      return this._nodes;
    }

    get outputs() {
      return this._outputs;
    }

    get inputs() {
      return this._inputs;
    }
  };
  mnn.Node = class {

    /**
     * Build a graph node from an MNN op.
     *
     * Weight-like fields of known parameter tables (Blob, Convolution2D, InnerProduct,
     * Scale, BatchNorm, PRelu, Normalize) become constant inputs and are removed from the
     * parameter table. All remaining fields, including those of nested schema tables,
     * become attributes.
     * @param {mnn.Metadata} metadata - operator metadata, looked up by op type name.
     * @param {object} op - decoded `mnn.Op` table. Note: `op.main` is mutated (fields deleted).
     * @param {object} net - decoded `mnn.Net` table (not used here).
     * @param {function(number): mnn.Argument} arg - maps a tensor index to its shared argument.
     */
    constructor(metadata, op, net, arg) {
      const type = mnn.Utility.enum('OpType', op.type) || '(' + op.type.toString() + ')';
      this._type = metadata.type(type) || { name: type };
      this._name = op.name || '';
      this._attributes = [];
      this._inputs = [];
      this._outputs = [];
      this._chains = [];
      if (op.inputIndexes && op.inputIndexes.length > 0) {
        this._inputs.push(new mnn.Parameter('input', true, Array.from(op.inputIndexes).map((index) => arg(index))));
      }
      if (op.outputIndexes && op.outputIndexes.length > 0) {
        this._outputs.push(new mnn.Parameter('output', true, Array.from(op.outputIndexes).map((index) => arg(index))));
      }
      const param = op.main;
      if (param) {
        const parameters = [ param ];
        if (param instanceof mnn.schema.Blob) {
          const tensor = mnn.Utility.createTensor(param, 'Blob');
          const argument = new mnn.Argument('', null, tensor);
          this._inputs.push(new mnn.Parameter('value', true, [ argument ]));
          // The whole Blob is shown as a tensor; no attributes are derived from it.
          parameters.length = 0;
        }
        else if (param instanceof mnn.schema.Convolution2D) {
          const common = param.common;
          const outputCount = common.outputCount;
          const inputCount = common.inputCount;
          const kernelX = common.kernelX;
          const kernelY = common.kernelY;
          this._buildTensor('weight', mnn.schema.DataType.DT_FLOAT, [ outputCount, inputCount, kernelX, kernelY ], param.weight);
          this._buildTensor('bias', mnn.schema.DataType.DT_FLOAT, [ outputCount ], param.bias);
          delete param.weight;
          delete param.bias;
          delete param.quanParameter;
          delete param.symmetricQuan;
        }
        else if (param instanceof mnn.schema.InnerProduct) {
          const outputCount = param.outputCount;
          const inputCount = param.weightSize / outputCount;
          this._buildTensor('weight', mnn.schema.DataType.DT_FLOAT, [ outputCount, inputCount ], param.weight);
          this._buildTensor('bias', mnn.schema.DataType.DT_FLOAT, [ outputCount ], param.bias);
          delete param.weight;
          delete param.bias;
          delete param.quanParameter;
        }
        else if (param instanceof mnn.schema.Scale) {
          const scaleDataCount = param.channels;
          this._buildTensor('scale', mnn.schema.DataType.DT_FLOAT, [ scaleDataCount ], param.scaleData);
          this._buildTensor('bias', mnn.schema.DataType.DT_FLOAT, [ scaleDataCount ], param.biasData);
          delete param.scaleData;
          delete param.biasData;
        }
        else if (param instanceof mnn.schema.BatchNorm) {
          const channels = param.channels;
          this._buildTensor('mean', mnn.schema.DataType.DT_FLOAT, [ channels ], param.meanData);
          this._buildTensor('slope', mnn.schema.DataType.DT_FLOAT, [ channels ], param.slopeData);
          this._buildTensor('variance', mnn.schema.DataType.DT_FLOAT, [ channels ], param.varData);
          this._buildTensor('bias', mnn.schema.DataType.DT_FLOAT, [ channels ], param.biasData);
          delete param.slopeData;
          delete param.meanData;
          delete param.varData;
          delete param.biasData;
        }
        else if (param instanceof mnn.schema.PRelu) {
          this._buildTensor('slope', mnn.schema.DataType.DT_FLOAT, [ param.slopeCount ], param.slope);
          // The slope values are exposed as the 'slope' input tensor. Drop both fields so
          // the data is not listed a second time as an attribute.
          delete param.slopeCount;
          delete param.slope;
        }
        else if (param instanceof mnn.schema.Normalize) {
          this._buildTensor('scale', mnn.schema.DataType.DT_FLOAT, [ param.scale.length ], param.scale);
          delete param.scale;
        }
        // True when `value` is an instance of any class exported by the schema.
        const isSchemaTable = (value) => Object.keys(mnn.schema).some((name) => mnn.schema[name].prototype && value instanceof mnn.schema[name]);
        // Breadth-first: nested schema tables are queued and flattened into attributes too.
        while (parameters.length > 0) {
          const parameter = parameters.shift();
          for (const key of Object.keys(parameter)) {
            const value = parameter[key];
            if (isSchemaTable(value)) {
              parameters.push(value);
              continue;
            }
            // Look up by the op type *name*. Metadata.attribute resolves its first argument
            // through a name-keyed map, so passing this._type (an object) never matches.
            const schema = metadata.attribute(type, key);
            this._attributes.push(new mnn.Attribute(schema, key, value));
          }
        }
      }
    }

    /**
     * Append a constant input named `name` that holds `value` as a 'Weight' tensor
     * with the given MNN data type and dimensions.
     */
    _buildTensor(name, dataType, dimensions, value) {
      const shape = new mnn.TensorShape(dimensions);
      const type = new mnn.TensorType(dataType, shape);
      const tensor = new mnn.Tensor('Weight', type, value);
      const argument = new mnn.Argument('', null, tensor);
      const parameter = new mnn.Parameter(name, true, [ argument ]);
      this._inputs.push(parameter);
    }

    /** Metadata entry for the op type, or `{ name }` when none is known. */
    get type() {
      return this._type;
    }

    get name() {
      return this._name;
    }

    get inputs() {
      return this._inputs;
    }

    get outputs() {
      return this._outputs;
    }

    /** Always empty in this loader. */
    get chain() {
      return this._chains;
    }

    get attributes() {
      return this._attributes;
    }
  };
  mnn.Attribute = class {

    /**
     * A named attribute value on a node.
     * Typed-array values are copied into plain arrays. When the schema declares a
     * `type`, the value is rendered as a DataType name ('DataType') or as the member
     * name of the schema enum with that name.
     * @param {object|null} schema - attribute metadata; may be null.
     * @param {string} name
     * @param {*} value
     * @param {boolean} [visible] - coerced to a boolean (omitted means false).
     */
    constructor(schema, name, value, visible) {
      this._type = null;
      this._value = ArrayBuffer.isView(value) ? Array.from(value) : value;
      this._name = name;
      this._visible = !!visible;
      const declaredType = schema ? schema.type : null;
      if (declaredType) {
        this._type = declaredType;
        this._value = declaredType === 'DataType' ?
          mnn.Utility.dataType(this._value) :
          mnn.Utility.enum(declaredType, this._value);
      }
    }

    get name() {
      return this._name;
    }

    get type() {
      return this._type;
    }

    get value() {
      return this._value;
    }

    /**
     * NOTE(review): the constructor coerces an omitted `visible` to false, so the
     * "not false means visible" fallback below never fires for constructor-built
     * instances — confirm this is intended.
     */
    get visible() {
      return this._visible != false;
    }
  };
  mnn.Parameter = class {

    /**
     * A named group of arguments: a node input/output slot or a graph input/output.
     * @param {string} name
     * @param {boolean} visible
     * @param {mnn.Argument[]} args
     */
    constructor(name, visible, args) {
      Object.assign(this, {
        _name: name,
        _visible: visible,
        _arguments: args
      });
    }

    get name() {
      return this._name;
    }

    get visible() {
      return this._visible;
    }

    get arguments() {
      return this._arguments;
    }
  };
  mnn.Argument = class {

    /**
     * A value flowing along a graph edge, optionally backed by constant data.
     * Falsy `type` / `initializer` are normalized to null.
     * @param {string} name
     * @param {mnn.TensorType|null} type
     * @param {mnn.Tensor|null} initializer
     */
    constructor(name, type, initializer) {
      this._name = name;
      this._type = type ? type : null;
      this._initializer = initializer ? initializer : null;
    }

    get name() {
      return this._name;
    }

    /** The initializer's type when constant data is attached, else the declared type. */
    get type() {
      return this._initializer ? this._initializer.type : this._type;
    }

    get initializer() {
      return this._initializer;
    }
  };
  mnn.Tensor = class {

    /**
     * Constant tensor data.
     * @param {string} kind - display label (callers pass 'Const', 'Blob' or 'Weight').
     * @param {mnn.TensorType} type
     * @param {ArrayLike|null} data - element storage (raw bytes for 'float16'); copied with slice(0).
     */
    constructor(kind, type, data) {
      this._kind = kind;
      this._type = type;
      this._data = data ? data.slice(0) : null;
    }

    get kind() {
      return this._kind;
    }

    get type() {
      return this._type;
    }

    /** 'Tensor data is empty.' when there is nothing to decode, otherwise null. */
    get state() {
      return this._context().state;
    }

    /** The fully decoded (nested array) value, or null when there is no data. */
    get value() {
      const context = this._context();
      if (context.state) {
        return null;
      }
      context.limit = Number.MAX_SAFE_INTEGER;
      return this._decode(context, 0);
    }

    /** Pretty-printed JSON of the value, cut off with '...' after roughly 10000 elements. */
    toString() {
      const context = this._context();
      if (context.state) {
        return '';
      }
      context.limit = 10000;
      return JSON.stringify(this._decode(context, 0), null, 4);
    }

    // Fresh decoding cursor over the tensor data.
    _context() {
      if (!this._data || this._data.length === 0) {
        return { state: 'Tensor data is empty.' };
      }
      const context = {
        state: null,
        index: 0,
        count: 0,
        dataType: this._type.dataType,
        dimensions: this._type.shape.dimensions
      };
      if (context.dataType === 'float16') {
        // Half floats are read from the raw bytes, two bytes per element.
        context.view = new DataView(this._data.buffer, this._data.byteOffset, this._data.byteLength);
      }
      else {
        context.data = this._data;
      }
      return context;
    }

    // Recursively decode `dimension` and everything below it. A rank-0 tensor is
    // treated as shape [1] and unwrapped to its single element.
    _decode(context, dimension) {
      const shape = context.dimensions.length === 0 ? [ 1 ] : context.dimensions;
      const size = shape[dimension];
      const innermost = dimension === shape.length - 1;
      const results = [];
      for (let i = 0; i < size; i++) {
        if (context.count > context.limit) {
          results.push('...');
          return results;
        }
        if (!innermost) {
          results.push(this._decode(context, dimension + 1));
          continue;
        }
        if (context.dataType === 'float16') {
          results.push(context.view.getFloat16(context.index, true));
          context.index += 2;
        }
        else {
          results.push(context.data[context.index]);
          context.index++;
        }
        context.count++;
      }
      return context.dimensions.length === 0 ? results[0] : results;
    }
  };
  mnn.TensorType = class {

    /**
     * Element type + shape of a tensor, with an optional layout denotation.
     * @param {number} dataType - `mnn.schema.DataType` value, converted via mnn.Utility.dataType.
     * @param {mnn.TensorShape} shape
     * @param {number} [format] - `mnn.schema.MNN_DATA_FORMAT` value. Only a truthy value
     *   sets a denotation.
     * @throws {mnn.Error} for a truthy format that matches none of the listed layouts.
     */
    constructor(dataType, shape, format) {
      this._dataType = mnn.Utility.dataType(dataType);
      this._shape = shape;
      if (format) {
        const layouts = mnn.schema.MNN_DATA_FORMAT;
        const denotations = [
          [ layouts.NCHW, 'NCHW' ],
          [ layouts.NHWC, 'NHWC' ],
          [ layouts.NC4HW4, 'NC4HW4' ],
          [ layouts.NHWC4, 'NHWC4' ]
        ];
        const hit = denotations.find((entry) => entry[0] === format);
        if (!hit) {
          throw new mnn.Error("Unsupported tensor type format '" + format + "'.");
        }
        this._denotation = hit[1];
      }
    }

    get dataType() {
      return this._dataType;
    }

    get shape() {
      return this._shape;
    }

    get denotation() {
      return this._denotation;
    }

    toString() {
      return `${this._dataType}${this._shape.toString()}`;
    }
  };
  mnn.TensorShape = class {

    /** @param {Iterable<number>|ArrayLike<number>} dimensions - copied into a plain array. */
    constructor(dimensions) {
      this._dimensions = Array.from(dimensions);
    }

    get dimensions() {
      return this._dimensions;
    }

    /** '[d0,d1,...]' with falsy dimensions (e.g. 0) shown as '?'; '' for rank 0. */
    toString() {
      if (!this._dimensions || this._dimensions.length === 0) {
        return '';
      }
      const parts = this._dimensions.map((dimension) => (dimension ? dimension.toString() : '?'));
      return '[' + parts.join(',') + ']';
    }
  };
  mnn.Metadata = class {

    /**
     * Load and cache the operator metadata from 'mnn-metadata.json'. Only the first
     * call hits `context.request`. Any failure while requesting or parsing yields an
     * empty metadata set.
     * @param {object} context - exposes `request(...)` returning a Promise of the file text.
     * @returns {Promise<mnn.Metadata>}
     */
    static open(context) {
      if (mnn.Metadata._metadata) {
        return Promise.resolve(mnn.Metadata._metadata);
      }
      return context.request('mnn-metadata.json', 'utf-8', null).then((data) => {
        mnn.Metadata._metadata = new mnn.Metadata(data);
        return mnn.Metadata._metadata;
      }).catch(() => {
        // Best effort: fall back to empty metadata instead of failing the model load.
        mnn.Metadata._metadata = new mnn.Metadata(null);
        return mnn.Metadata._metadata;
      });
    }

    /** @param {string|null} data - JSON text holding an array of `{ name, ... }` entries. */
    constructor(data) {
      this._map = new Map();
      if (data) {
        const metadata = JSON.parse(data);
        this._map = new Map(metadata.map((item) => [ item.name, item ]));
      }
    }

    /** Metadata entry for an op type name, or undefined. */
    type(name) {
      return this._map.get(name);
    }

    /**
     * Attribute metadata for `name` on op `type`, or null.
     * @param {string|{name: string}} type - op type name, or a type object as returned by
     *   type() (its `name` is used).
     * @param {string} name - attribute name.
     */
    attribute(type, name) {
      const typeName = type !== null && typeof type === 'object' ? type.name : type;
      const schema = this.type(typeName);
      if (schema) {
        let attributeMap = schema.attributeMap;
        if (!attributeMap) {
          // Null-prototype object: names such as 'constructor' or 'toString' must not
          // resolve to inherited Object.prototype members.
          attributeMap = Object.create(null);
          if (schema.attributes) {
            for (const attribute of schema.attributes) {
              attributeMap[attribute.name] = attribute;
            }
          }
          // Cached on the metadata entry for subsequent lookups.
          schema.attributeMap = attributeMap;
        }
        const attributeSchema = attributeMap[name];
        if (attributeSchema) {
          return attributeSchema;
        }
      }
      return null;
    }
  };
  mnn.Utility = class {

    /**
     * Map an `mnn.schema.DataType` value to a data type name.
     * @param {number} type
     * @returns {string}
     * @throws {mnn.Error} for values not listed below.
     */
    static dataType(type) {
      switch (type) {
        case mnn.schema.DataType.DT_INVALID: return '?';
        case mnn.schema.DataType.DT_FLOAT: return 'float32';
        case mnn.schema.DataType.DT_DOUBLE: return 'float64';
        case mnn.schema.DataType.DT_INT32: return 'int32';
        case mnn.schema.DataType.DT_UINT8: return 'uint8';
        case mnn.schema.DataType.DT_INT16: return 'int16';
        case mnn.schema.DataType.DT_INT8: return 'int8';
        case mnn.schema.DataType.DT_STRING: return 'string';
        case mnn.schema.DataType.DT_COMPLEX64: return 'complex64';
        case mnn.schema.DataType.DT_INT64: return 'int64';
        case mnn.schema.DataType.DT_BOOL: return 'boolean';
        case mnn.schema.DataType.DT_QINT8: return 'qint8';
        case mnn.schema.DataType.DT_QUINT8: return 'quint8';
        case mnn.schema.DataType.DT_QINT32: return 'qint32';
        case mnn.schema.DataType.DT_BFLOAT16: return 'bfloat16';
        case mnn.schema.DataType.DT_QINT16: return 'qint16';
        case mnn.schema.DataType.DT_QUINT16: return 'quint16';
        case mnn.schema.DataType.DT_UINT16: return 'uint16';
        case mnn.schema.DataType.DT_COMPLEX128: return 'complex128';
        case mnn.schema.DataType.DT_HALF: return 'float16';
        case mnn.schema.DataType.DT_RESOURCE: return 'resource';
        case mnn.schema.DataType.DT_VARIANT: return 'variant';
        default: throw new mnn.Error("Unsupported data type '" + JSON.stringify(type) + "'.");
      }
    }

    /**
     * Reverse-lookup the member name of `value` in the schema enum `mnn.schema[name]`.
     * The value -> name map of each enum is built once and cached. If several members
     * share a value, the last one in key order wins.
     * @param {string} name - schema enum name, e.g. 'OpType'.
     * @param {*} value
     * @returns {string} member name, or String(value) when the enum or value is unknown.
     */
    static enum(name, value) {
      const type = name && mnn.schema ? mnn.schema[name] : undefined;
      if (type) {
        mnn.Utility._enumKeyMap = mnn.Utility._enumKeyMap || new Map();
        if (!mnn.Utility._enumKeyMap.has(name)) {
          const map = new Map();
          for (const key of Object.keys(type)) {
            map.set(type[key], key);
          }
          mnn.Utility._enumKeyMap.set(name, map);
        }
        const map = mnn.Utility._enumKeyMap.get(name);
        if (map.has(value)) {
          return map.get(value);
        }
      }
      // String() rather than value.toString(): a missing (null/undefined) value must not throw.
      return String(value);
    }

    /**
     * Build an mnn.Tensor from a schema `Blob`, choosing the payload vector that
     * matches the blob's data type.
     * @param {object} param - decoded `mnn.Blob` table.
     * @param {string} kind - tensor kind label.
     * @returns {mnn.Tensor}
     * @throws {mnn.Error} for data types without a payload mapping.
     */
    static createTensor(param, kind) {
      const type = new mnn.TensorType(param.dataType, new mnn.TensorShape(param.dims), param.dataFormat);
      let data = null;
      switch (type.dataType) {
        case 'uint8': data = param.uint8s; break;
        case 'int8': data = param.int8s; break;
        case 'int32': data = param.int32s; break;
        case 'int64': data = param.int64s; break;
        // Half floats are stored as raw bytes; mnn.Tensor decodes them through a DataView.
        case 'float16': data = param.uint8s; break;
        case 'float32': data = param.float32s; break;
        // NOTE(review): assumes string payloads live in the Blob's `strings` vector — confirm
        // against mnn-schema. If the field is absent, the tensor reports empty data instead
        // of aborting the whole model load.
        case 'string': data = param.strings; break;
        default: throw new mnn.Error("Unsupported blob data type '" + JSON.stringify(type.dataType) + "'.");
      }
      return new mnn.Tensor(kind, type, data);
    }
  };
  /**
   * Error raised when an MNN file cannot be decoded or uses an unsupported feature.
   * The `name` is a fixed, human-readable title.
   */
  mnn.Error = class extends Error {

    /** @param {string} message */
    constructor(message) {
      super(message);
      this.name = 'Error loading MNN model.';
    }
  };
  // Under a CommonJS host, expose the model factory as this module's public entry point.
  if (typeof module !== 'undefined') {
    if (typeof module.exports === 'object') {
      module.exports.ModelFactory = mnn.ModelFactory;
    }
  }