armnn.js 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538
  // Reuse an existing global namespace/helper when one is already defined. `var`
  // (not const/let) lets this declaration merge with a same-named global —
  // presumably needed when loaded as a classic browser script; verify before changing.
  var armnn = armnn ? armnn : {};
  var flatbuffers = flatbuffers ? flatbuffers : require('./flatbuffers');
  /**
   * Recognizes Arm NN models (binary `.armnn` flatbuffers or their JSON text
   * form) and opens them as armnn.Model instances.
   */
  armnn.ModelFactory = class {
    /**
     * @param {object} context - host context exposing `identifier` and `open(type)`.
     * @returns {string|undefined} match tag for `open`, or undefined when the file is not Arm NN.
     */
    match(context) {
      const extension = context.identifier.split('.').pop().toLowerCase();
      switch (extension) {
        case 'armnn':
          return 'armnn.flatbuffers';
        case 'json': {
          // A JSON document qualifies only when it has the SerializedGraph top-level keys.
          const obj = context.open('json');
          const isGraph = obj && obj.layers && obj.inputIds && obj.outputIds;
          return isGraph ? 'armnn.flatbuffers.json' : undefined;
        }
        default:
          return undefined;
      }
    }

    /**
     * Loads the schema module, parses the model selected by `match` and wraps it.
     * @param {object} context - host context (`require`, `stream`, `open`, `metadata`).
     * @param {string} match - tag previously returned by `match`.
     * @returns {Promise<armnn.Model>} rejects with armnn.Error on parse failure or unknown tag.
     */
    open(context, match) {
      // Text of a caught error with any trailing period removed.
      const reason = (error) => (error && error.message ? error.message : error.toString()).replace(/\.$/, '');
      const read = () => {
        if (match === 'armnn.flatbuffers') {
          try {
            const stream = context.stream;
            const binaryReader = flatbuffers.BinaryReader.open(stream);
            return armnn.schema.SerializedGraph.create(binaryReader);
          }
          catch (error) {
            throw new armnn.Error('File format is not armnn.SerializedGraph (' + reason(error) + ').');
          }
        }
        if (match === 'armnn.flatbuffers.json') {
          try {
            const json = context.open('json');
            const textReader = flatbuffers.TextReader.open(json);
            return armnn.schema.SerializedGraph.createText(textReader);
          }
          catch (error) {
            throw new armnn.Error('File text format is not armnn.SerializedGraph (' + reason(error) + ').');
          }
        }
        throw new armnn.Error("Unsupported Arm NN '" + match + "'.");
      };
      return context.require('./armnn-schema').then((/* schema */) => {
        armnn.schema = flatbuffers.get('armnn').armnnSerializer;
        const model = read();
        return context.metadata('armnn-metadata.json').then((metadata) => new armnn.Model(metadata, model));
      });
    }
  };
  /**
   * Top-level Arm NN model: a single graph built from the serialized model.
   */
  armnn.Model = class {
    /**
     * @param {object} metadata - operator metadata passed through to the graph.
     * @param {object} model - deserialized SerializedGraph.
     */
    constructor(metadata, model) {
      this._graphs = [ new armnn.Graph(metadata, model) ];
    }

    get format() {
      return 'Arm NN';
    }

    get description() {
      return '';
    }

    get graphs() {
      return this._graphs;
    }
  };
  /**
   * Graph view of a SerializedGraph: Input layers become graph inputs, Output
   * layers become graph outputs, and every other layer becomes an armnn.Node.
   * Constant layers whose single output feeds exactly one input slot are folded
   * into that argument as an initializer instead of appearing as nodes.
   */
  armnn.Graph = class {
    /**
     * @param {object} metadata - operator metadata handed to each armnn.Node.
     * @param {object} graph - deserialized SerializedGraph (uses `graph.layers`).
     */
    constructor(metadata, graph) {
      this._name = '';
      this._nodes = [];
      this._inputs = [];
      this._outputs = [];

      // Arguments are identified by "<sourceLayerIndex>:<outputSlotIndex>".
      const key = (layerIndex, slotIndex) => layerIndex.toString() + ':' + slotIndex.toString();

      // How many input slots consume each output slot; used to decide constant folding.
      const counts = new Map();
      for (const layer of graph.layers) {
        for (const slot of armnn.Node.getBase(layer).inputSlots) {
          const name = key(slot.connection.sourceLayerIndex, slot.connection.outputSlotIndex);
          counts.set(name, (counts.get(name) || 0) + 1);
        }
      }

      // Memoized argument factory: one shared armnn.Argument per producing output slot.
      // A source layer/slot outside the serialized graph yields a null tensorInfo.
      const args = new Map();
      const arg = (layerIndex, slotIndex, tensor) => {
        const name = key(layerIndex, slotIndex);
        if (!args.has(name)) {
          const base = layerIndex < graph.layers.length ? armnn.Node.getBase(graph.layers[layerIndex]) : null;
          const tensorInfo = base && slotIndex < base.outputSlots.length ? base.outputSlots[slotIndex].tensorInfo : null;
          args.set(name, new armnn.Argument(name, tensorInfo, tensor));
        }
        return args.get(name);
      };

      // Fold single-consumer constants into their argument; keep every other layer.
      const layers = graph.layers.filter((layer) => {
        const base = armnn.Node.getBase(layer);
        if (base.layerType === armnn.schema.LayerType.Constant && base.outputSlots.length === 1 && layer.layer.input) {
          const slot = base.outputSlots[0];
          if (counts.get(key(base.index, slot.index)) === 1) {
            arg(base.index, slot.index, new armnn.Tensor(layer.layer.input, 'Constant'));
            return false;
          }
        }
        return true;
      });

      for (const layer of layers) {
        const base = armnn.Node.getBase(layer);
        switch (base.layerType) {
          case armnn.schema.LayerType.Input: {
            for (const slot of base.outputSlots) {
              this._inputs.push(new armnn.Parameter(base.layerName, [ arg(base.index, slot.index) ]));
            }
            break;
          }
          case armnn.schema.LayerType.Output: {
            for (const slot of base.inputSlots) {
              const argument = arg(slot.connection.sourceLayerIndex, slot.connection.outputSlotIndex);
              this._outputs.push(new armnn.Parameter(base.layerName, [ argument ]));
            }
            break;
          }
          default:
            this._nodes.push(new armnn.Node(metadata, layer, arg));
            break;
        }
      }
    }

    get name() {
      return this._name;
    }

    get inputs() {
      return this._inputs;
    }

    get outputs() {
      return this._outputs;
    }

    get nodes() {
      return this._nodes;
    }
  };
  /**
   * Graph node for a serialized layer that is neither an Input nor an Output.
   */
  armnn.Node = class {
    /**
     * @param {object} metadata - provides `type(name)` and `attribute(type, name)` lookups.
     * @param {object} layer - serialized layer wrapper; `layer.layer` is the concrete layer object.
     * @param {function} arg - maps (layerIndex, slotIndex) to a shared armnn.Argument.
     */
    constructor(metadata, layer, arg) {
      const type = layer.layer.constructor.name;
      const known = metadata.type(type) || { name: type };
      // Shallow copy so renaming below does not touch the metadata entry.
      this._type = Object.assign({}, known);
      this._type.name = this._type.name.replace(/Layer$/, '');
      this._name = '';
      this._outputs = [];
      this._inputs = [];
      this._attributes = [];

      // Assigns slots to schemas in order: a `list` schema absorbs every remaining
      // slot, any other schema takes one; slots past the last schema are named '?'.
      const group = (schemas, slots, resolve) => {
        const pending = [ ...slots ];
        const parameters = [];
        while (pending.length > 0) {
          const schema = schemas.length > 0 ? schemas.shift() : { name: '?' };
          const count = schema.list ? pending.length : 1;
          const bound = pending.splice(0, count).map((slot) => resolve(slot));
          parameters.push(new armnn.Parameter(schema.name, bound));
        }
        return parameters;
      };

      const inputSchemas = this._type.inputs ? [ ...this._type.inputs ] : [ { name: 'input' } ];
      const outputSchemas = this._type.outputs ? [ ...this._type.outputs ] : [ { name: 'output' } ];
      const base = armnn.Node.getBase(layer);
      if (base) {
        this._name = base.layerName;
        const fromConnection = (slot) => arg(slot.connection.sourceLayerIndex, slot.connection.outputSlotIndex);
        const fromOwnSlot = (slot) => arg(base.index, slot.index);
        this._inputs.push(...group(inputSchemas, base.inputSlots, fromConnection));
        this._outputs.push(...group(outputSchemas, base.outputSlots, fromOwnSlot));
      }

      // Descriptor fields are surfaced only when the metadata declares attributes.
      if (layer.layer && layer.layer.descriptor && this._type.attributes) {
        for (const [ name, value ] of Object.entries(layer.layer.descriptor)) {
          this._attributes.push(new armnn.Attribute(metadata.attribute(type, name), name, value));
        }
      }

      // Every ConstTensor-valued field on the layer becomes an extra named input.
      if (layer.layer) {
        Object.entries(layer.layer)
          .filter(([ , value ]) => value instanceof armnn.schema.ConstTensor)
          .forEach(([ name, tensor ]) => {
            const argument = new armnn.Argument('', tensor.info, new armnn.Tensor(tensor));
            this._inputs.push(new armnn.Parameter(name, [ argument ]));
          });
      }
    }

    get type() {
      return this._type;
    }

    get name() {
      return this._name;
    }

    get inputs() {
      return this._inputs;
    }

    get outputs() {
      return this._outputs;
    }

    get attributes() {
      return this._attributes;
    }

    /**
     * Returns the common layer base, unwrapping one extra `base` level when present.
     */
    static getBase(layer) {
      const outer = layer.layer.base;
      return outer.base ? outer.base : outer;
    }

    static makeKey(layer_id, index) {
      return layer_id.toString() + '_' + index.toString();
    }
  };
  /**
   * A layer descriptor field shown as a node attribute. Typed-array values are
   * converted to plain arrays; values whose metadata type names a schema enum
   * are translated through armnn.Utility.enum.
   */
  armnn.Attribute = class {
    /**
     * @param {object|null} metadata - attribute metadata; only `type` is read.
     * @param {string} name - descriptor field name.
     * @param {*} value - raw descriptor value.
     */
    constructor(metadata, name, value) {
      this._name = name;
      this._type = metadata ? metadata.type : null;
      let decoded = value;
      if (ArrayBuffer.isView(decoded)) {
        decoded = Array.from(decoded);
      }
      this._value = armnn.schema[this._type] ? armnn.Utility.enum(this._type, decoded) : decoded;
    }

    get name() {
      return this._name;
    }

    get type() {
      return this._type;
    }

    get value() {
      return this._value;
    }

    get visible() {
      if (this._visible == false) {
        return false;
      }
      return true;
    }
  };
  /**
   * A named, always-visible group of armnn.Argument objects on a node or graph.
   */
  armnn.Parameter = class {
    /**
     * @param {string} name - parameter name.
     * @param {Array} args - arguments bound to this parameter (stored by reference).
     */
    constructor(name, args) {
      Object.assign(this, { _name: name, _arguments: args });
    }

    get name() {
      return this._name;
    }

    get visible() {
      return true;
    }

    get arguments() {
      return this._arguments;
    }
  };
  /**
   * A value flowing along a graph edge, optionally carrying constant data.
   */
  armnn.Argument = class {
    /**
     * @param {string} name - identifier ("<layerIndex>:<slotIndex>" for graph edges, '' for layer-owned constants).
     * @param {object|null} tensorInfo - serialized tensor info; null when the producing slot could not be resolved.
     * @param {armnn.Tensor} [initializer] - constant data backing this argument, if any.
     * @throws {armnn.Error} when `name` is not a string.
     */
    constructor(name, tensorInfo, initializer) {
      if (typeof name !== 'string') {
        throw new armnn.Error("Invalid argument identifier '" + JSON.stringify(name) + "'.");
      }
      this._name = name;
      // Unresolvable source slots arrive with a null tensorInfo; leave the type unset
      // instead of dereferencing null inside armnn.TensorType.
      this._type = tensorInfo ? new armnn.TensorType(tensorInfo) : null;
      this._initializer = initializer;
      // Quantization parameters are only meaningful for the quantized ('q...') data types.
      if (this._type && this._type.dataType.startsWith('q')) {
        this._scale = tensorInfo.quantizationScale;
        this._zeroPoint = tensorInfo.quantizationOffset;
      }
    }

    get name() {
      return this._name;
    }

    get type() {
      return this._type;
    }

    /**
     * Human-readable dequantization formula, e.g. "0.5 * (q - 3)", or undefined
     * when the argument is not quantized.
     */
    get quantization() {
      if (this._scale !== undefined && this._zeroPoint !== undefined) {
        const q = this._zeroPoint === 0 ? 'q' : '(q - ' + this._zeroPoint.toString() + ')';
        return this._scale.toString() + ' * ' + q;
      }
      return undefined;
    }

    get initializer() {
      return this._initializer;
    }
  };
  /**
   * Constant tensor data decoded from a serialized Arm NN ConstTensor.
   */
  armnn.Tensor = class {
    /**
     * @param {object} tensor - serialized ConstTensor exposing `info` and `data.data`.
     * @param {string} [kind] - display kind such as 'Constant'; defaults to ''.
     */
    constructor(tensor, kind) {
      this._type = new armnn.TensorType(tensor.info);
      // Copy the payload (slice) when present; a missing payload is reported via `state`.
      this._data = tensor.data && tensor.data.data ? tensor.data.data.slice(0) : null;
      this._kind = kind ? kind : '';
    }

    get kind() {
      return this._kind;
    }

    get type() {
      return this._type;
    }

    /** Reason the data cannot be decoded, or null when it can. */
    get state() {
      return this._context().state;
    }

    /** Fully decoded nested array (scalar for rank-0), or null when undecodable. */
    get value() {
      const context = this._context();
      if (context.state) {
        return null;
      }
      context.limit = Number.MAX_SAFE_INTEGER;
      return this._decode(context, 0);
    }

    /** JSON rendering truncated after roughly 10000 elements. */
    toString() {
      const context = this._context();
      if (context.state) {
        return '';
      }
      context.limit = 10000;
      const value = this._decode(context, 0);
      return JSON.stringify(value, null, 4);
    }

    _context() {
      const context = { state: null, index: 0, count: 0 };
      if (this._data == null) {
        context.state = 'Tensor data is empty.';
        return context;
      }
      const dataType = this._type.dataType;
      // Data types _decode knows how to read; anything else is reported, not silently skipped.
      const supported = [ 'float16', 'float32', 'quint8', 'qint8', 'qint16', 'int32', 'boolean' ];
      if (!supported.includes(dataType)) {
        context.state = "Tensor data type '" + dataType + "' is not supported.";
        return context;
      }
      context.dataType = dataType;
      context.shape = this._type.shape.dimensions;
      context.data = new DataView(this._data.buffer, this._data.byteOffset, this._data.byteLength);
      return context;
    }

    // Recursively decodes one dimension; values are read little-endian from context.data.
    _decode(context, dimension) {
      // Rank-0 tensors are decoded as a single element and unwrapped at the end.
      const shape = context.shape.length === 0 ? [ 1 ] : context.shape;
      const size = shape[dimension];
      const results = [];
      if (dimension === shape.length - 1) {
        for (let i = 0; i < size; i++) {
          if (context.count > context.limit) {
            results.push('...');
            return results;
          }
          switch (context.dataType) {
            case 'float16':
              results.push(context.data.getFloat16(context.index, true));
              context.index += 2;
              break;
            case 'float32':
              results.push(context.data.getFloat32(context.index, true));
              context.index += 4;
              break;
            case 'quint8':
              results.push(context.data.getUint8(context.index));
              context.index += 1;
              break;
            case 'qint8':
              results.push(context.data.getInt8(context.index));
              context.index += 1;
              break;
            case 'qint16':
              results.push(context.data.getInt16(context.index, true));
              context.index += 2;
              break;
            case 'int32':
              results.push(context.data.getInt32(context.index, true));
              context.index += 4;
              break;
            case 'boolean':
              results.push(context.data.getInt8(context.index));
              context.index += 1;
              break;
            default:
              break;
          }
          context.count++;
        }
      }
      else {
        for (let j = 0; j < size; j++) {
          if (context.count > context.limit) {
            results.push('...');
            return results;
          }
          results.push(this._decode(context, dimension + 1));
        }
      }
      if (context.shape.length === 0) {
        return results[0];
      }
      return results;
    }
  };
  /**
   * Element type plus shape of an Arm NN tensor, built from a serialized TensorInfo.
   */
  armnn.TensorType = class {
    /**
     * @param {object} tensorInfo - serialized tensor info (`dataType`, `dimensions`).
     * @throws {armnn.Error} when `dataType` is not one of the known enum values 0-9.
     */
    constructor(tensorInfo) {
      const dataType = tensorInfo.dataType;
      // Serialized DataType enum value -> display name. Map lookups match keys with
      // SameValueZero, so only the exact numeric values below are accepted.
      const names = new Map([
        [ 0, 'float16' ],
        [ 1, 'float32' ],
        [ 2, 'quint8' ], // QuantisedAsymm8
        [ 3, 'int32' ],
        [ 4, 'boolean' ],
        [ 5, 'qint16' ], // QuantisedSymm16
        [ 6, 'quint8' ], // QAsymmU8
        [ 7, 'qint16' ], // QSymmS16
        [ 8, 'qint8' ], // QAsymmS8
        [ 9, 'qint8' ], // QSymmS8
      ]);
      if (!names.has(dataType)) {
        throw new armnn.Error("Unsupported data type '" + JSON.stringify(dataType) + "'.");
      }
      this._dataType = names.get(dataType);
      this._shape = new armnn.TensorShape(tensorInfo.dimensions);
    }

    get dataType() {
      return this._dataType;
    }

    get shape() {
      return this._shape;
    }

    toString() {
      return this.dataType + this._shape.toString();
    }
  };
  /**
   * Tensor dimensions, copied into a plain array.
   */
  armnn.TensorShape = class {
    /**
     * @param {Iterable<number>|ArrayLike<number>|null|undefined} dimensions - serialized
     *   dimensions; a missing value is treated as rank 0 instead of throwing in Array.from.
     */
    constructor(dimensions) {
      this._dimensions = dimensions ? Array.from(dimensions) : [];
    }

    get dimensions() {
      return this._dimensions;
    }

    /** "[d0,d1,...]", or '' for a rank-0 shape. */
    toString() {
      if (this._dimensions.length === 0) {
        return '';
      }
      return '[' + this._dimensions.map((dimension) => dimension.toString()).join(',') + ']';
    }
  };
  /**
   * Helpers shared by the Arm NN classes.
   */
  armnn.Utility = class {
    /**
     * Translates a numeric enum value to its key name using `armnn.schema[name]`.
     * Reverse lookup tables are cached per enum name in `armnn.Utility._enums`.
     * @param {string} name - enum name within armnn.schema.
     * @param {*} value - raw value to translate.
     * @returns {*} the key name, or `value` unchanged when no mapping exists.
     */
    static enum(name, value) {
      const type = name && armnn.schema ? armnn.schema[name] : undefined;
      if (!type) {
        return value;
      }
      if (!armnn.Utility._enums) {
        armnn.Utility._enums = new Map();
      }
      const cache = armnn.Utility._enums;
      let lookup = cache.get(name);
      if (!lookup) {
        lookup = new Map(Object.entries(type).map(([ key, entry ]) => [ entry, key ]));
        cache.set(name, lookup);
      }
      return lookup.has(value) ? lookup.get(value) : value;
    }
  };
  /**
   * Error raised for malformed or unsupported Arm NN files.
   */
  armnn.Error = class extends Error {
    /**
     * @param {string} message - description of the failure.
     */
    constructor(message) {
      super(message);
      // The display name doubles as the error heading.
      Object.assign(this, { name: 'Error loading Arm NN model.' });
    }
  };
  // Under CommonJS, export the factory; otherwise it is reachable only via the `armnn` global.
  if (typeof module !== 'undefined' && typeof module.exports === 'object') {
    Object.assign(module.exports, { ModelFactory: armnn.ModelFactory });
  }