// armnn.js (17 KB)

  1. var armnn = armnn || {};
  2. var flatbuffers = flatbuffers || require('./flatbuffers');
  3. armnn.ModelFactory = class {
  4. match(context) {
  5. const identifier = context.identifier;
  6. const extension = identifier.split('.').pop().toLowerCase();
  7. const stream = context.stream;
  8. if (stream && extension === 'armnn') {
  9. return 'armnn.flatbuffers';
  10. }
  11. if (extension === 'json') {
  12. const obj = context.open('json');
  13. if (obj && obj.layers && obj.inputIds && obj.outputIds) {
  14. return 'armnn.flatbuffers.json';
  15. }
  16. }
  17. return undefined;
  18. }
  19. open(context, match) {
  20. return context.require('./armnn-schema').then((/* schema */) => {
  21. armnn.schema = flatbuffers.get('armnn').armnnSerializer;
  22. let model = null;
  23. switch (match) {
  24. case 'armnn.flatbuffers': {
  25. try {
  26. const stream = context.stream;
  27. const reader = flatbuffers.BinaryReader.open(stream);
  28. model = armnn.schema.SerializedGraph.create(reader);
  29. }
  30. catch (error) {
  31. const message = error && error.message ? error.message : error.toString();
  32. throw new armnn.Error('File format is not armnn.SerializedGraph (' + message.replace(/\.$/, '') + ').');
  33. }
  34. break;
  35. }
  36. case 'armnn.flatbuffers.json': {
  37. try {
  38. const obj = context.open('json');
  39. const reader = flatbuffers.TextReader.open(obj);
  40. model = armnn.schema.SerializedGraph.createText(reader);
  41. }
  42. catch (error) {
  43. const message = error && error.message ? error.message : error.toString();
  44. throw new armnn.Error('File text format is not armnn.SerializedGraph (' + message.replace(/\.$/, '') + ').');
  45. }
  46. break;
  47. }
  48. default: {
  49. throw new armnn.Error("Unsupported Arm NN '" + match + "'.");
  50. }
  51. }
  52. return context.metadata('armnn-metadata.json').then((metadata) => {
  53. return new armnn.Model(metadata, model);
  54. });
  55. });
  56. }
  57. };
  58. armnn.Model = class {
  59. constructor(metadata, model) {
  60. this._graphs = [];
  61. this._graphs.push(new armnn.Graph(metadata, model));
  62. }
  63. get format() {
  64. return 'Arm NN';
  65. }
  66. get description() {
  67. return '';
  68. }
  69. get graphs() {
  70. return this._graphs;
  71. }
  72. };
  73. armnn.Graph = class {
  74. constructor(metadata, graph) {
  75. this._name = '';
  76. this._nodes = [];
  77. this._inputs = [];
  78. this._outputs = [];
  79. const counts = new Map();
  80. for (const layer of graph.layers) {
  81. const base = armnn.Node.getBase(layer);
  82. for (const slot of base.inputSlots) {
  83. const name = slot.connection.sourceLayerIndex.toString() + ':' + slot.connection.outputSlotIndex.toString();
  84. counts.set(name, counts.has(name) ? counts.get(name) + 1 : 1);
  85. }
  86. }
  87. const args = new Map();
  88. const arg = (layerIndex, slotIndex, tensor) => {
  89. const name = layerIndex.toString() + ':' + slotIndex.toString();
  90. if (!args.has(name)) {
  91. const layer = graph.layers[layerIndex];
  92. const base = layerIndex < graph.layers.length ? armnn.Node.getBase(layer) : null;
  93. const tensorInfo = base && slotIndex < base.outputSlots.length ? base.outputSlots[slotIndex].tensorInfo : null;
  94. args.set(name, new armnn.Argument(name, tensorInfo, tensor));
  95. }
  96. return args.get(name);
  97. };
  98. const layers = graph.layers.filter((layer) => {
  99. const base = armnn.Node.getBase(layer);
  100. if (base.layerType == armnn.schema.LayerType.Constant && base.outputSlots.length === 1 && layer.layer.input) {
  101. const slot = base.outputSlots[0];
  102. const name = base.index.toString() + ':' + slot.index.toString();
  103. if (counts.get(name) === 1) {
  104. const tensor = new armnn.Tensor(layer.layer.input, 'Constant');
  105. arg(base.index, slot.index, tensor);
  106. return false;
  107. }
  108. }
  109. return true;
  110. });
  111. for (const layer of layers) {
  112. const base = armnn.Node.getBase(layer);
  113. for (const slot of base.inputSlots) {
  114. arg(slot.connection.sourceLayerIndex, slot.connection.outputSlotIndex);
  115. }
  116. }
  117. for (const layer of layers) {
  118. const base = armnn.Node.getBase(layer);
  119. switch (base.layerType) {
  120. case armnn.schema.LayerType.Input: {
  121. const name = base ? base.layerName : '';
  122. for (const slot of base.outputSlots) {
  123. const argument = arg(base.index, slot.index);
  124. this._inputs.push(new armnn.Parameter(name, [ argument ]));
  125. }
  126. break;
  127. }
  128. case armnn.schema.LayerType.Output: {
  129. const base = armnn.Node.getBase(layer);
  130. const name = base ? base.layerName : '';
  131. for (const slot of base.inputSlots) {
  132. const argument = arg(slot.connection.sourceLayerIndex, slot.connection.outputSlotIndex);
  133. this._outputs.push(new armnn.Parameter(name, [ argument ]));
  134. }
  135. break;
  136. }
  137. default:
  138. this._nodes.push(new armnn.Node(metadata, layer, arg));
  139. break;
  140. }
  141. }
  142. }
  143. get name() {
  144. return this._name;
  145. }
  146. get inputs() {
  147. return this._inputs;
  148. }
  149. get outputs() {
  150. return this._outputs;
  151. }
  152. get nodes() {
  153. return this._nodes;
  154. }
  155. };
  156. armnn.Node = class {
  157. constructor(metadata, layer, arg) {
  158. const type = layer.layer.constructor.name;
  159. this._type = Object.assign({}, metadata.type(type) || { name: type });
  160. this._type.name = this._type.name.replace(/Layer$/, '');
  161. this._name = '';
  162. this._outputs = [];
  163. this._inputs = [];
  164. this._attributes = [];
  165. const inputSchemas = (this._type && this._type.inputs) ? [...this._type.inputs] : [ { name: 'input' } ];
  166. const outputSchemas = (this._type && this._type.outputs) ? [...this._type.outputs] : [ { name: 'output' } ];
  167. const base = armnn.Node.getBase(layer);
  168. if (base) {
  169. this._name = base.layerName;
  170. const inputSlots = [...base.inputSlots];
  171. while (inputSlots.length > 0) {
  172. const inputSchema = inputSchemas.length > 0 ? inputSchemas.shift() : { name: '?' };
  173. const inputCount = inputSchema.list ? inputSlots.length : 1;
  174. this._inputs.push(new armnn.Parameter(inputSchema.name, inputSlots.splice(0, inputCount).map((inputSlot) => {
  175. return arg(inputSlot.connection.sourceLayerIndex, inputSlot.connection.outputSlotIndex);
  176. })));
  177. }
  178. const outputSlots = [...base.outputSlots];
  179. while (outputSlots.length > 0) {
  180. const outputSchema = outputSchemas.length > 0 ? outputSchemas.shift() : { name: '?' };
  181. const outputCount = outputSchema.list ? outputSlots.length : 1;
  182. this._outputs.push(new armnn.Parameter(outputSchema.name, outputSlots.splice(0, outputCount).map((outputSlot) => {
  183. return arg(base.index, outputSlot.index);
  184. })));
  185. }
  186. }
  187. if (layer.layer && layer.layer.descriptor && this._type.attributes) {
  188. for (const pair of Object.entries(layer.layer.descriptor)) {
  189. const name = pair[0];
  190. const value = pair[1];
  191. const attribute = new armnn.Attribute(metadata.attribute(type, name), name, value);
  192. this._attributes.push(attribute);
  193. }
  194. }
  195. if (layer.layer) {
  196. for (const entry of Object.entries(layer.layer).filter((entry) => entry[1] instanceof armnn.schema.ConstTensor)) {
  197. const name = entry[0];
  198. const tensor = entry[1];
  199. const argument = new armnn.Argument('', tensor.info, new armnn.Tensor(tensor));
  200. this._inputs.push(new armnn.Parameter(name, [ argument ]));
  201. }
  202. }
  203. }
  204. get type() {
  205. return this._type;
  206. }
  207. get name() {
  208. return this._name;
  209. }
  210. get inputs() {
  211. return this._inputs;
  212. }
  213. get outputs() {
  214. return this._outputs;
  215. }
  216. get attributes() {
  217. return this._attributes;
  218. }
  219. static getBase(layer) {
  220. return layer.layer.base.base ? layer.layer.base.base : layer.layer.base;
  221. }
  222. static makeKey(layer_id, index) {
  223. return layer_id.toString() + "_" + index.toString();
  224. }
  225. };
  226. armnn.Attribute = class {
  227. constructor(metadata, name, value) {
  228. this._name = name;
  229. this._type = metadata ? metadata.type : null;
  230. this._value = ArrayBuffer.isView(value) ? Array.from(value) : value;
  231. if (armnn.schema[this._type]) {
  232. this._value = armnn.Utility.enum(this._type, this._value);
  233. }
  234. }
  235. get name() {
  236. return this._name;
  237. }
  238. get type() {
  239. return this._type;
  240. }
  241. get value() {
  242. return this._value;
  243. }
  244. get visible() {
  245. return this._visible == false ? false : true;
  246. }
  247. };
  248. armnn.Parameter = class {
  249. constructor(name, args) {
  250. this._name = name;
  251. this._arguments = args;
  252. }
  253. get name() {
  254. return this._name;
  255. }
  256. get visible() {
  257. return true;
  258. }
  259. get arguments() {
  260. return this._arguments;
  261. }
  262. };
  263. armnn.Argument = class {
  264. constructor(name, tensorInfo, initializer) {
  265. if (typeof name !== 'string') {
  266. throw new armnn.Error("Invalid argument identifier '" + JSON.stringify(name) + "'.");
  267. }
  268. this._name = name;
  269. this._type = new armnn.TensorType(tensorInfo);
  270. this._initializer = initializer;
  271. if (this._type.dataType.startsWith('q') && tensorInfo) {
  272. this._scale = tensorInfo.quantizationScale;
  273. this._zeroPoint = tensorInfo.quantizationOffset;
  274. }
  275. }
  276. get name() {
  277. return this._name;
  278. }
  279. get type() {
  280. return this._type;
  281. }
  282. get quantization() {
  283. if (this._scale !== undefined && this._zeroPoint !== undefined) {
  284. return this._scale.toString() + ' * ' + (this._zeroPoint == 0 ? 'q' : ('(q - ' + this._zeroPoint.toString() + ')'));
  285. }
  286. return undefined;
  287. }
  288. get initializer() {
  289. return this._initializer;
  290. }
  291. };
  292. armnn.Tensor = class {
  293. constructor(tensor, kind) {
  294. this._type = new armnn.TensorType(tensor.info);
  295. this._data = tensor.data.data.slice(0);
  296. this._kind = kind ? kind : '';
  297. }
  298. get kind() {
  299. return this._kind;
  300. }
  301. get type() {
  302. return this._type;
  303. }
  304. get state() {
  305. return this._context().state;
  306. }
  307. get value() {
  308. const context = this._context();
  309. if (context.state) {
  310. return null;
  311. }
  312. context.limit = Number.MAX_SAFE_INTEGER;
  313. return this._decode(context, 0);
  314. }
  315. toString() {
  316. const context = this._context();
  317. if (context.state) {
  318. return '';
  319. }
  320. context.limit = 10000;
  321. const value = this._decode(context, 0);
  322. return JSON.stringify(value, null, 4);
  323. }
  324. _context() {
  325. const context = {};
  326. context.state = null;
  327. context.index = 0;
  328. context.count = 0;
  329. if (this._data == null) {
  330. context.state = 'Tensor data is empty.';
  331. return context;
  332. }
  333. context.dataType = this._type.dataType;
  334. context.shape = this._type.shape.dimensions;
  335. context.data = new DataView(this._data.buffer, this._data.byteOffset, this._data.byteLength);
  336. return context;
  337. }
  338. _decode(context, dimension) {
  339. let shape = context.shape;
  340. if (shape.length == 0) {
  341. shape = [ 1 ];
  342. }
  343. const size = shape[dimension];
  344. const results = [];
  345. if (dimension == shape.length - 1) {
  346. for (let i = 0; i < size; i++) {
  347. if (context.count > context.limit) {
  348. results.push('...');
  349. return results;
  350. }
  351. switch (context.dataType) {
  352. case 'float16':
  353. results.push(context.data.getFloat16(context.index, true));
  354. context.index += 2;
  355. context.count++;
  356. break;
  357. case 'float32':
  358. results.push(context.data.getFloat32(context.index, true));
  359. context.index += 4;
  360. context.count++;
  361. break;
  362. case 'quint8':
  363. results.push(context.data.getUint8(context.index));
  364. context.index += 1;
  365. context.count++;
  366. break;
  367. case 'qint16':
  368. results.push(context.data.getInt16(context.index, true));
  369. context.index += 2;
  370. context.count++;
  371. break;
  372. case 'int32':
  373. results.push(context.data.getInt32(context.index, true));
  374. context.index += 4;
  375. context.count++;
  376. break;
  377. case 'boolean':
  378. results.push(context.data.getInt8(context.index));
  379. context.index += 1;
  380. context.count++;
  381. break;
  382. default:
  383. break;
  384. }
  385. }
  386. }
  387. else {
  388. for (let j = 0; j < size; j++) {
  389. if (context.count > context.limit) {
  390. results.push('...');
  391. return results;
  392. }
  393. results.push(this._decode(context, dimension + 1));
  394. }
  395. }
  396. if (context.shape.length == 0) {
  397. return results[0];
  398. }
  399. return results;
  400. }
  401. };
  402. armnn.TensorType = class {
  403. constructor(tensorInfo) {
  404. const dataType = tensorInfo.dataType;
  405. switch (dataType) {
  406. case 0: this._dataType = 'float16'; break;
  407. case 1: this._dataType = 'float32'; break;
  408. case 2: this._dataType = 'quint8'; break; // QuantisedAsymm8
  409. case 3: this._dataType = 'int32'; break;
  410. case 4: this._dataType = 'boolean'; break;
  411. case 5: this._dataType = 'qint16'; break; // QuantisedSymm16
  412. case 6: this._dataType = 'quint8'; break; // QAsymmU8
  413. case 7: this._dataType = 'qint16'; break; // QSymmS16
  414. case 8: this._dataType = 'qint8'; break; // QAsymmS8
  415. case 9: this._dataType = 'qint8'; break; // QSymmS8
  416. default:
  417. throw new armnn.Error("Unsupported data type '" + JSON.stringify(dataType) + "'.");
  418. }
  419. this._shape = new armnn.TensorShape(tensorInfo.dimensions);
  420. }
  421. get dataType() {
  422. return this._dataType;
  423. }
  424. get shape() {
  425. return this._shape;
  426. }
  427. toString() {
  428. return this.dataType + this._shape.toString();
  429. }
  430. };
  431. armnn.TensorShape = class {
  432. constructor(dimensions) {
  433. this._dimensions = Array.from(dimensions);
  434. }
  435. get dimensions() {
  436. return this._dimensions;
  437. }
  438. toString() {
  439. if (!this._dimensions || this._dimensions.length == 0) {
  440. return '';
  441. }
  442. return '[' + this._dimensions.map((dimension) => dimension.toString()).join(',') + ']';
  443. }
  444. };
  445. armnn.Utility = class {
  446. static enum(name, value) {
  447. const type = name && armnn.schema ? armnn.schema[name] : undefined;
  448. if (type) {
  449. armnn.Utility._enums = armnn.Utility._enums || new Map();
  450. if (!armnn.Utility._enums.has(name)) {
  451. const map = new Map(Object.keys(type).map((key) => [ type[key], key ]));
  452. armnn.Utility._enums.set(name, map);
  453. }
  454. const map = armnn.Utility._enums.get(name);
  455. if (map.has(value)) {
  456. return map.get(value);
  457. }
  458. }
  459. return value;
  460. }
  461. };
  462. armnn.Error = class extends Error {
  463. constructor(message) {
  464. super(message);
  465. this.name = 'Error loading Arm NN model.';
  466. }
  467. };
  468. if (typeof module !== 'undefined' && typeof module.exports === 'object') {
  469. module.exports.ModelFactory = armnn.ModelFactory;
  470. }