mnn.js 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623
  /* jshint esversion: 6 */
  // Shared namespace object: keep an `mnn` value that is already defined, otherwise start empty.
  var mnn;
  if (!mnn) {
    mnn = {};
  }
  // Reuse an existing `flatbuffers` binding when present; otherwise load the sibling module.
  var flatbuffers;
  if (!flatbuffers) {
    flatbuffers = require('./flatbuffers');
  }
  mnn.ModelFactory = class {

    /**
     * Cheap format probe: accepts files named '*.mnn' whose FlatBuffers `root` value, read from
     * the first four bytes, is 0x18, 0x1C or 0x20.
     * @param {*} context - provides `identifier` (file name) and `stream` (`length`, `peek()`).
     * @returns {boolean} true when the file looks like an MNN model.
     */
    match(context) {
      const stream = context.stream;
      // Negated `>=` so that a non-numeric length is rejected exactly as before.
      if (!(stream.length >= 4)) {
        return false;
      }
      const extension = context.identifier.split('.').pop().toLowerCase();
      if (extension !== 'mnn') {
        return false;
      }
      const root = flatbuffers.BinaryReader.open(stream.peek(4)).root;
      return [ 0x00000018, 0x0000001C, 0x00000020 ].includes(root);
    }

    /**
     * Loads the MNN schema, decodes the `Net` table from the stream and wraps it in an mnn.Model.
     * Side effect: stores the schema namespace in `mnn.schema`, which the other classes read.
     * @param {*} context - provides `require()`, `stream` and whatever mnn.Metadata.open() needs.
     * @returns {Promise<mnn.Model>} rejects with mnn.Error when the stream is not an mnn.Net.
     */
    open(context) {
      const decode = () => {
        try {
          mnn.schema = flatbuffers.get('mnn').MNN;
          return mnn.schema.Net.create(flatbuffers.BinaryReader.open(context.stream));
        }
        catch (error) {
          const reason = error && error.message ? error.message : error.toString();
          throw new mnn.Error('File format is not mnn.Net (' + reason.replace(/\.$/, '') + ').');
        }
      };
      return context.require('./mnn-schema').then(() => {
        const net = decode();
        return mnn.Metadata.open(context).then((metadata) => new mnn.Model(metadata, net));
      });
    }
  };
  mnn.Model = class {

    /**
     * Top-level model wrapper around a decoded MNN net.
     * @param {mnn.Metadata} metadata - operator metadata, passed through to the graph.
     * @param {*} net - decoded mnn.schema.Net; `net.sourceType` selects the reported source framework.
     */
    constructor(metadata, net) {
      const NetSource = mnn.schema.NetSource;
      // Ordered so that the first matching entry wins, like the cases of a switch.
      const sources = [
        [ NetSource.CAFFE, 'Caffe' ],
        [ NetSource.TENSORFLOW, 'TensorFlow' ],
        [ NetSource.TFLITE, 'TensorFlow Lite' ],
        [ NetSource.ONNX, 'ONNX' ],
        [ NetSource.TORCH, 'Torch' ]
      ];
      const match = sources.find((entry) => entry[0] === net.sourceType);
      if (match) {
        this._source = match[1];
      }
      this._graphs = [ new mnn.Graph(metadata, net) ];
    }
    get format() {
      return 'MNN v2';
    }
    // Empty string when the source type is not one of the known NetSource values.
    get source() {
      return this._source || '';
    }
    get graphs() {
      return this._graphs;
    }
  };
  mnn.Graph = class {

    /**
     * Builds the single graph of an MNN net.
     *
     * Ops whose type is OpType.Input become graph inputs and every other op becomes an mnn.Node.
     * Every tensor index that no op lists in `inputIndexes` is reported as a graph output.
     * Side effect: empty entries of `net.tensorName` are replaced in place with '\n' + index, so
     * each tensor index gets its own name. The mnn.Node instances created here read that same array.
     *
     * @param {mnn.Metadata} metadata - passed through to every mnn.Node.
     * @param {*} net - the decoded mnn.schema.Net.
     */
    constructor(metadata, net) {
      const names = net.tensorName;
      // Type of tensor `index` from its extraTensorDescribe blob, or null when none is described.
      const tensorType = (index) => {
        const describe = net.extraTensorDescribe[index];
        const blob = describe ? describe.blob : null;
        return blob ? new mnn.TensorType(blob.dataType, new mnn.TensorShape(blob.dims)) : null;
      };
      names.forEach((name, index) => {
        if (name === '') {
          names[index] = '\n' + index.toString();
        }
      });
      this._nodes = [];
      this._inputs = [];
      this._outputs = [];
      const consumed = new Set();
      for (const op of net.oplists) {
        if (op.type !== mnn.schema.OpType.Input) {
          this._nodes.push(new mnn.Node(metadata, op, net));
        }
        else {
          const args = Array.from(op.outputIndexes, (index) => new mnn.Argument(names[index], tensorType(index), null));
          this._inputs.push(new mnn.Parameter(op.name, true, args));
        }
        for (const index of op.inputIndexes) {
          consumed.add(index);
        }
      }
      for (let index = 0; index < names.length; index++) {
        if (consumed.has(index)) {
          continue;
        }
        const name = names[index];
        this._outputs.push(new mnn.Parameter(name, true, [ new mnn.Argument(name, tensorType(index), null) ]));
      }
    }
    get name() {
      return '';
    }
    get groups() {
      return false;
    }
    get nodes() {
      return this._nodes;
    }
    get outputs() {
      return this._outputs;
    }
    get inputs() {
      return this._inputs;
    }
  };
  mnn.Node = class {

    /**
     * View-model for one MNN operator (an entry of `net.oplists`).
     *
     * Input/output tensor indexes are resolved to names through `net.tensorName`. Constant data in
     * `op.main` is exposed as extra initializer inputs: a Blob, or the weight arrays of the known
     * parameter tables. The remaining fields of `op.main`, including fields of nested schema tables,
     * become attributes.
     *
     * @param {mnn.Metadata} metadata - operator / attribute metadata lookup.
     * @param {*} op - one operator of the decoded net.
     * @param {*} net - the decoded mnn.schema.Net the operator belongs to.
     * @throws {mnn.Error} when a Blob parameter has a data type whose values cannot be extracted here.
     */
    constructor(metadata, op, net) {
      const type = mnn.Utility.enum('OpType', op.type) || '(' + op.type.toString() + ')';
      this._type = metadata.type(type) || { name: type };
      this._name = op.name || '';
      this._attributes = [];
      this._inputs = [];
      this._outputs = [];
      this._chains = [];
      const argument = (tensorIndex) => new mnn.Argument(net.tensorName[tensorIndex], null, null);
      this._inputs.push(new mnn.Parameter('input', true, Array.from(op.inputIndexes, argument)));
      this._outputs.push(new mnn.Parameter('output', true, Array.from(op.outputIndexes, argument)));
      const parameter = op.main;
      if (parameter) {
        if (parameter instanceof mnn.schema.Blob) {
          // A constant blob is shown only as a 'value' tensor; its fields are not listed as attributes.
          this._buildBlob(parameter);
        }
        else {
          const ignoreAttributes = this._buildWeights(parameter);
          // Pass the operator type *name*: mnn.Metadata looks attribute schemas up by name,
          // so passing the type object (this.type) never matched anything.
          this._buildAttributes(metadata, type, parameter, ignoreAttributes);
        }
      }
    }

    // Adds the data of a constant Blob parameter as a 'value' initializer input.
    _buildBlob(blob) {
      const type = new mnn.TensorType(blob.dataType, new mnn.TensorShape(blob.dims));
      let data = null;
      switch (type.dataType) {
        case 'uint8': data = blob.uint8s; break;
        case 'int8': data = blob.int8s; break;
        case 'int32': data = blob.int32s; break;
        case 'int64': data = blob.int64s; break;
        // float16 values are taken from the raw bytes in uint8s; mnn.Tensor decodes them through a DataView.
        case 'float16': data = blob.uint8s; break;
        case 'float32': data = blob.float32s; break;
        default:
          throw new mnn.Error("Unknown blob data type '" + JSON.stringify(type.dataType) + "'.");
      }
      this._inputs.push(new mnn.Parameter('value', true, [
        new mnn.Argument('', null, new mnn.Tensor('Blob', type, data))
      ]));
    }

    // Adds float32 initializer inputs for the weight arrays of known parameter tables and returns
    // the field names that are therefore not repeated as attributes.
    _buildWeights(parameter) {
      const schema = mnn.schema;
      const float32 = (name, dimensions, value) => this._buildTensor(schema.DataType.DT_FLOAT, name, dimensions, value);
      if (parameter instanceof schema.Convolution2D) {
        const common = parameter.common;
        // NOTE(review): dimensions are listed as [out, in, kernelX, kernelY]; confirm against MNN's weight layout.
        float32('weight', [ common.outputCount, common.inputCount, common.kernelX, common.kernelY ], parameter.weight);
        float32('bias', [ common.outputCount ], parameter.bias);
        return new Set([ 'weight', 'bias', 'quanParameter', 'symmetricQuan' ]);
      }
      if (parameter instanceof schema.InnerProduct) {
        const outputCount = parameter.outputCount;
        float32('weight', [ outputCount, parameter.weightSize / outputCount ], parameter.weight);
        float32('bias', [ outputCount ], parameter.bias);
        return new Set([ 'weight', 'bias', 'quanParameter' ]);
      }
      if (parameter instanceof schema.Scale) {
        const channels = parameter.channels;
        float32('scale', [ channels ], parameter.scaleData);
        float32('bias', [ channels ], parameter.biasData);
        return new Set([ 'scaleData', 'biasData' ]);
      }
      if (parameter instanceof schema.BatchNorm) {
        const channels = parameter.channels;
        float32('mean', [ channels ], parameter.meanData);
        float32('slope', [ channels ], parameter.slopeData);
        float32('variance', [ channels ], parameter.varData);
        float32('bias', [ channels ], parameter.biasData);
        return new Set([ 'slopeData', 'meanData', 'varData', 'biasData' ]);
      }
      if (parameter instanceof schema.PRelu) {
        float32('slope', [ parameter.slopeCount ], parameter.slope);
        return new Set([ 'slope' ]);
      }
      if (parameter instanceof schema.Normalize) {
        float32('scale', [ parameter.scale.length ], parameter.scale);
        return new Set([ 'scale' ]);
      }
      return new Set();
    }

    // Turns the own fields of `parameter` into attributes, descending into nested schema tables.
    _buildAttributes(metadata, type, parameter, ignoreAttributes) {
      // Computed once instead of once per field: schema entries that have a prototype (classes).
      // Entries without one, e.g. enum value maps, are never matched with instanceof.
      const tableTypes = Object.keys(mnn.schema).map((key) => mnn.schema[key]).filter((entry) => entry.prototype);
      const pending = [ parameter ];
      while (pending.length > 0) {
        const current = pending.shift();
        for (const key of Object.keys(current)) {
          if (ignoreAttributes.has(key)) {
            continue;
          }
          const value = current[key];
          if (tableTypes.some((tableType) => value instanceof tableType)) {
            // Nested table: flatten its fields into this node's attributes.
            pending.push(value);
            continue;
          }
          this._attributes.push(new mnn.Attribute(metadata.attribute(type, key), key, value));
        }
      }
    }

    _buildTensor(dataType, name, dimensions, value) {
      this._inputs.push(new mnn.Parameter(name, true, [
        new mnn.Argument('', null, new mnn.Tensor('Weight', new mnn.TensorType(dataType, new mnn.TensorShape(dimensions)), value))
      ]));
    }
    get type() {
      return this._type;
    }
    get name() {
      return this._name;
    }
    get group() {
      return null;
    }
    get inputs() {
      return this._inputs;
    }
    get outputs() {
      return this._outputs;
    }
    get chain() {
      return this._chains;
    }
    get attributes() {
      return this._attributes;
    }
  };
  mnn.Attribute = class {

    /**
     * A named operator attribute.
     * @param {Object|null} schema - metadata entry for the attribute; its `type` selects value conversion.
     * @param {string} name - attribute (field) name.
     * @param {*} value - raw field value; typed arrays are copied into plain arrays.
     * @param {boolean} [visible] - pass `false` to hide the attribute; any other value, including
     *     omitting the argument, keeps it visible.
     * @throws {mnn.Error} via mnn.Utility.dataType() when a 'DataType' value is unknown.
     */
    constructor(schema, name, value, visible) {
      this._type = null;
      this._value = ArrayBuffer.isView(value) ? Array.from(value) : value;
      this._name = name;
      // Default to visible. mnn.Node creates attributes without this argument, and coercing
      // `undefined` to false used to hide every attribute.
      this._visible = visible !== false;
      if (schema && schema.type) {
        this._type = schema.type;
        if (this._type === 'DataType') {
          this._value = mnn.Utility.dataType(this._value);
        }
        else if (mnn.schema && mnn.schema[this._type]) {
          // Only types that name a schema entry (enums) are converted to key names. Other types
          // keep their raw value; the old fallback stringified them and threw on null values.
          this._value = mnn.Utility.enum(this._type, this._value);
        }
      }
    }
    get name() {
      return this._name;
    }
    get type() {
      return this._type;
    }
    get value() {
      return this._value;
    }
    get visible() {
      return this._visible;
    }
  };
  mnn.Parameter = class {

    /**
     * Named slot (graph or node input/output) holding a list of arguments.
     * @param {string} name - slot name.
     * @param {boolean} visible - stored as given and returned unchanged by the `visible` getter.
     * @param {mnn.Argument[]} args - arguments carried by this slot.
     */
    constructor(name, visible, args) {
      this._name = name;
      this._visible = visible;
      this._arguments = args;
    }

    get name() {
      return this._name;
    }

    get visible() {
      return this._visible;
    }

    get arguments() {
      return this._arguments;
    }
  };
  mnn.Argument = class {

    /**
     * A named value flowing between nodes, optionally backed by constant data.
     * @param {string} name - tensor name ('' for anonymous constants).
     * @param {mnn.TensorType|null} type - declared type; falsy values are stored as null.
     * @param {mnn.Tensor|null} initializer - constant data; falsy values are stored as null.
     */
    constructor(name, type, initializer) {
      this._name = name;
      this._type = type || null;
      this._initializer = initializer || null;
    }

    get name() {
      return this._name;
    }

    // An initializer carries its own type, which takes precedence over the declared one.
    get type() {
      return this._initializer ? this._initializer.type : this._type;
    }

    get initializer() {
      return this._initializer;
    }
  };
  mnn.Tensor = class {

    /**
     * Constant tensor (weights or blob data) attached to an operator.
     * @param {string} kind - display kind, e.g. 'Weight' or 'Blob'.
     * @param {mnn.TensorType} type - element type and shape used for decoding.
     * @param {*} data - element storage, copied with slice(0). For 'float16' it must expose
     *     buffer/byteOffset/byteLength, because it is read through a DataView.
     */
    constructor(kind, type, data) {
      this._kind = kind;
      this._type = type;
      this._data = data ? data.slice(0) : null;
    }

    get kind() {
      return this._kind;
    }

    get type() {
      return this._type;
    }

    // Non-null message (e.g. 'Tensor data is empty.') when the data cannot be decoded.
    get state() {
      return this._context().state;
    }

    // Fully decoded nested array (or a single element for rank-0 tensors); null when undecodable.
    get value() {
      const context = this._context();
      if (!context.state) {
        context.limit = Number.MAX_SAFE_INTEGER;
        return this._decode(context, 0);
      }
      return null;
    }

    // JSON rendering, truncated with '...' after roughly 10000 elements; '' when undecodable.
    toString() {
      const context = this._context();
      if (context.state) {
        return '';
      }
      context.limit = 10000;
      return JSON.stringify(this._decode(context, 0), null, 4);
    }

    // Builds the mutable cursor shared by the recursive _decode calls.
    _context() {
      const context = { state: null };
      if (!this._data || this._data.length === 0) {
        context.state = 'Tensor data is empty.';
        return context;
      }
      context.index = 0;
      context.count = 0;
      context.dataType = this._type.dataType;
      context.dimensions = this._type.shape.dimensions;
      if (context.dataType === 'float16') {
        context.view = new DataView(this._data.buffer, this._data.byteOffset, this._data.byteLength);
      }
      else {
        context.data = this._data;
      }
      return context;
    }

    // Decodes dimension `dimension` onward; a rank-0 tensor is treated as shape [1] and unwrapped.
    _decode(context, dimension) {
      const scalar = context.dimensions.length === 0;
      const shape = scalar ? [ 1 ] : context.dimensions;
      const size = shape[dimension];
      const innermost = dimension === shape.length - 1;
      const results = [];
      for (let i = 0; i < size; i++) {
        if (context.count > context.limit) {
          results.push('...');
          return results;
        }
        if (!innermost) {
          results.push(this._decode(context, dimension + 1));
          continue;
        }
        if (context.dataType === 'float16') {
          results.push(context.view.getFloat16(context.index, true));
          context.index += 2;
        }
        else {
          results.push(context.data[context.index]);
          context.index++;
        }
        context.count++;
      }
      return scalar ? results[0] : results;
    }
  };
  mnn.TensorType = class {

    /**
     * Element type plus shape of a tensor.
     * @param {number} dataType - mnn.schema.DataType value, converted by mnn.Utility.dataType().
     * @param {mnn.TensorShape} shape - tensor shape.
     */
    constructor(dataType, shape) {
      this._dataType = mnn.Utility.dataType(dataType);
      this._shape = shape;
    }

    get dataType() {
      return this._dataType;
    }

    get shape() {
      return this._shape;
    }

    // e.g. 'float32[1,3,224,224]'.
    toString() {
      return `${this._dataType}${this._shape.toString()}`;
    }
  };
  mnn.TensorShape = class {

    /**
     * Tensor dimensions.
     * @param {Iterable<number>|ArrayLike<number>} dimensions - copied into a plain array.
     */
    constructor(dimensions) {
      this._dimensions = Array.from(dimensions);
    }

    get dimensions() {
      return this._dimensions;
    }

    // '[d0,d1,...]' with falsy dimensions shown as '?'; '' for rank 0.
    toString() {
      const dimensions = this._dimensions;
      if (!dimensions || dimensions.length === 0) {
        return '';
      }
      const parts = dimensions.map((dimension) => (dimension ? dimension.toString() : '?'));
      return '[' + parts.join(',') + ']';
    }
  };
  mnn.Metadata = class {

    /**
     * Loads 'mnn-metadata.json' once and caches the result in mnn.Metadata._metadata.
     * Missing or unparsable metadata is tolerated on purpose: it yields an empty instance.
     * @param {*} context - provides `request(path, encoding, base)` returning a Promise of the text.
     * @returns {Promise<mnn.Metadata>}
     */
    static open(context) {
      if (mnn.Metadata._metadata) {
        return Promise.resolve(mnn.Metadata._metadata);
      }
      return context.request('mnn-metadata.json', 'utf-8', null).then((data) => {
        mnn.Metadata._metadata = new mnn.Metadata(data);
        return mnn.Metadata._metadata;
      }).catch(() => {
        mnn.Metadata._metadata = new mnn.Metadata(null);
        return mnn.Metadata._metadata;
      });
    }

    /**
     * @param {string|null} data - JSON text holding an array of operator entries keyed by `name`.
     */
    constructor(data) {
      this._map = new Map();
      if (data) {
        const metadata = JSON.parse(data);
        this._map = new Map(metadata.map((item) => [ item.name, item ]));
      }
    }

    // Operator entry for `name`, or undefined when unknown.
    type(name) {
      return this._map.get(name);
    }

    /**
     * Attribute schema `name` of an operator type.
     * @param {string|Object} type - operator name, or a type object with a `name` property
     *     (such as an entry returned by type(), which is what mnn.Node passes).
     * @param {string} name - attribute name.
     * @returns {Object|null} the attribute entry, or null when the type or attribute is unknown.
     */
    attribute(type, name) {
      const typeName = type && typeof type === 'object' ? type.name : type;
      const schema = this.type(typeName);
      if (!schema) {
        return null;
      }
      let attributeMap = schema.attributeMap;
      if (!attributeMap) {
        attributeMap = {};
        if (schema.attributes) {
          for (const attribute of schema.attributes) {
            attributeMap[attribute.name] = attribute;
          }
        }
        // Cached on the entry so later lookups are O(1).
        schema.attributeMap = attributeMap;
      }
      // Own-property check so names like 'toString' do not resolve to Object.prototype members.
      if (!Object.prototype.hasOwnProperty.call(attributeMap, name)) {
        return null;
      }
      return attributeMap[name] || null;
    }
  };
  mnn.Utility = class {

    /**
     * Maps an mnn.schema.DataType value to the viewer's data type name.
     * @param {number} type - a DataType enum value.
     * @returns {string} e.g. 'float32', or '?' for DT_INVALID.
     * @throws {mnn.Error} when the value matches none of the listed DataType entries.
     */
    static dataType(type) {
      const DataType = mnn.schema.DataType;
      // Checked in this order; the first DataType entry equal to `type` wins.
      const names = {
        DT_INVALID: '?',
        DT_FLOAT: 'float32',
        DT_DOUBLE: 'float64',
        DT_INT32: 'int32',
        DT_UINT8: 'uint8',
        DT_INT16: 'int16',
        DT_INT8: 'int8',
        DT_STRING: 'string',
        DT_COMPLEX64: 'complex64',
        DT_INT64: 'int64',
        DT_BOOL: 'boolean',
        DT_QINT8: 'qint8',
        DT_QUINT8: 'quint8',
        DT_QINT32: 'qint32',
        DT_BFLOAT16: 'bfloat16',
        DT_QINT16: 'qint16',
        DT_QUINT16: 'quint16',
        DT_UINT16: 'uint16',
        DT_COMPLEX128: 'complex128',
        DT_HALF: 'float16',
        DT_RESOURCE: 'resource',
        DT_VARIANT: 'variant'
      };
      for (const key of Object.keys(names)) {
        if (DataType[key] === type) {
          return names[key];
        }
      }
      throw new mnn.Error("Unknown data type '" + JSON.stringify(type) + "'.");
    }

    /**
     * Key name of `value` in the schema enum `mnn.schema[name]`.
     * Falls back to `value.toString()` when the enum or the value is unknown.
     * Reverse lookup maps are cached per enum name in mnn.Utility._enumKeyMap.
     * @param {string} name - enum name in mnn.schema.
     * @param {*} value - enum value.
     * @returns {string}
     */
    static enum(name, value) {
      const enumType = name && mnn.schema ? mnn.schema[name] : undefined;
      if (!enumType) {
        return value.toString();
      }
      if (!mnn.Utility._enumKeyMap) {
        mnn.Utility._enumKeyMap = new Map();
      }
      const cache = mnn.Utility._enumKeyMap;
      if (!cache.has(name)) {
        // Later keys overwrite earlier ones that share a value.
        cache.set(name, new Map(Object.keys(enumType).map((key) => [ enumType[key], key ])));
      }
      const keys = cache.get(name);
      return keys.has(value) ? keys.get(value) : value.toString();
    }
  };
  mnn.Error = class extends Error {

    /**
     * Error raised while loading or decoding an MNN model.
     * The `name` is a fixed, user-facing headline; `message` carries the specific cause.
     * @param {string} message - description of the failure.
     */
    constructor(message) {
      super(message);
      this.name = 'Error loading MNN model.';
    }
  };
  // CommonJS export; skipped entirely where `module` is not defined.
  if (typeof module !== 'undefined') {
    if (typeof module.exports === 'object') {
      module.exports.ModelFactory = mnn.ModelFactory;
    }
  }