// armnn.js
  const armnn = {};

  /**
   * Factory that recognizes Arm NN serialized graph files and opens them as
   * armnn.Model instances.
   */
  armnn.ModelFactory = class {

      /**
       * Checks whether the context looks like an Arm NN file.
       *
       * A '.armnn' identifier matches when a 'flatbuffers.binary' reader can be
       * peeked. A '.json' identifier matches when the peeked object has truthy
       * 'layers', 'inputIds' and 'outputIds' properties.
       *
       * @param {object} context - Exposes identifier, peek(type) and set(type, value).
       * @returns {Promise<*>} Whatever context.set(...) returns on a match, otherwise null.
       */
      async match(context) {
          const identifier = context.identifier;
          // A leading dot (index 0) is not treated as an extension separator.
          const extension = identifier.lastIndexOf('.') > 0 ? identifier.split('.').pop().toLowerCase() : '';
          if (extension === 'armnn') {
              const reader = await context.peek('flatbuffers.binary');
              if (reader) {
                  return context.set('armnn.flatbuffers', reader);
              }
          }
          if (extension === 'json') {
              const obj = await context.peek('json');
              if (obj && obj.layers && obj.inputIds && obj.outputIds) {
                  return context.set('armnn.flatbuffers.json', obj);
              }
          }
          return null;
      }

      /**
       * Loads the schema, deserializes the graph selected by context.type and
       * wraps it in an armnn.Model together with 'armnn-metadata.json'.
       *
       * @param {object} context - Exposes type, require(id), read(type) and metadata(name).
       * @returns {Promise<armnn.Model>} The opened model.
       * @throws {armnn.Error} When deserialization fails (the original error is
       *     attached as 'cause') or when context.type is not a known Arm NN type.
       */
      async open(context) {
          const schema = await context.require('./armnn-schema');
          armnn.schema = schema.armnnSerializer;
          // Builds the wrapping error for a failed deserialization and keeps the
          // underlying error reachable through 'cause'.
          const wrap = (description, error) => {
              // String(error) also copes with a thrown null/undefined, where
              // error.toString() would itself throw and hide the real failure.
              const message = error && error.message ? error.message : String(error);
              const wrapped = new armnn.Error(`${description} is not armnn.SerializedGraph (${message.replace(/\.$/, '')}).`);
              wrapped.cause = error;
              return wrapped;
          };
          let model = null;
          switch (context.type) {
              case 'armnn.flatbuffers': {
                  try {
                      const reader = await context.read('flatbuffers.binary');
                      model = armnn.schema.SerializedGraph.create(reader);
                  } catch (error) {
                      throw wrap('File format', error);
                  }
                  break;
              }
              case 'armnn.flatbuffers.json': {
                  try {
                      const reader = await context.read('flatbuffers.text');
                      model = armnn.schema.SerializedGraph.createText(reader);
                  } catch (error) {
                      throw wrap('File text format', error);
                  }
                  break;
              }
              default: {
                  throw new armnn.Error(`Unsupported Arm NN format '${context.type}'.`);
              }
          }
          const metadata = await context.metadata('armnn-metadata.json');
          return new armnn.Model(metadata, model);
      }
  };
  /**
   * Top-level model object: reports the format name and holds a single
   * armnn.Graph built from the deserialized graph.
   */
  armnn.Model = class {

      /**
       * @param {object} metadata - Metadata handed through to armnn.Graph.
       * @param {object} model - Deserialized graph handed through to armnn.Graph.
       */
      constructor(metadata, model) {
          this.format = 'Arm NN';
          const graph = new armnn.Graph(metadata, model);
          this.modules = [graph];
      }
  };
  /**
   * Graph built from a deserialized Arm NN graph: Input layers become graph
   * inputs, Output layers become graph outputs, and every other layer that is
   * not folded away as a constant becomes an armnn.Node.
   */
  armnn.Graph = class {

      /**
       * @param {object} metadata - Metadata handed through to each armnn.Node.
       * @param {object} graph - Deserialized graph; its 'layers' array is read.
       */
      constructor(metadata, graph) {
          this.name = '';
          this.nodes = [];
          this.inputs = [];
          this.outputs = [];
          // Values are identified by "<source layer index>:<output slot index>".
          const keyOf = (layerIndex, slotIndex) => `${layerIndex}:${slotIndex}`;
          // Count how many input slots consume each output slot.
          const consumers = new Map();
          for (const layer of graph.layers) {
              for (const slot of armnn.Node.getBase(layer).inputSlots) {
                  const key = keyOf(slot.connection.sourceLayerIndex, slot.connection.outputSlotIndex);
                  consumers.set(key, (consumers.get(key) || 0) + 1);
              }
          }
          // Returns the shared armnn.Value for an output slot, creating it on first
          // use. 'tensor' is only used when the value is created by this call.
          const values = new Map();
          const value = (layerIndex, slotIndex, tensor) => {
              const key = keyOf(layerIndex, slotIndex);
              if (values.has(key)) {
                  return values.get(key);
              }
              let tensorInfo = null;
              if (layerIndex < graph.layers.length) {
                  const base = armnn.Node.getBase(graph.layers[layerIndex]);
                  if (base && slotIndex < base.outputSlots.length) {
                      tensorInfo = base.outputSlots[slotIndex].tensorInfo;
                  }
              }
              const created = new armnn.Value(key, tensorInfo, tensor);
              values.set(key, created);
              return created;
          };
          // A Constant layer with one output slot that has exactly one consumer is
          // folded into an initializer on its value instead of being kept as a layer.
          const foldConstant = (layer) => {
              const base = armnn.Node.getBase(layer);
              const candidate = base.layerType === armnn.schema.LayerType.Constant &&
                  base.outputSlots.length === 1 &&
                  layer.layer.input;
              if (!candidate) {
                  return false;
              }
              const [slot] = base.outputSlots;
              if (consumers.get(keyOf(base.index, slot.index)) !== 1) {
                  return false;
              }
              value(base.index, slot.index, new armnn.Tensor(layer.layer.input, 'Constant'));
              return true;
          };
          const layers = graph.layers.filter((layer) => !foldConstant(layer));
          // Make sure every connection consumed by a remaining layer has a value.
          for (const layer of layers) {
              for (const slot of armnn.Node.getBase(layer).inputSlots) {
                  value(slot.connection.sourceLayerIndex, slot.connection.outputSlotIndex);
              }
          }
          for (const layer of layers) {
              const base = armnn.Node.getBase(layer);
              const layerType = base.layerType;
              if (layerType === armnn.schema.LayerType.Input) {
                  const label = base.layerName;
                  for (const slot of base.outputSlots) {
                      this.inputs.push(new armnn.Argument(label, [value(base.index, slot.index)]));
                  }
              } else if (layerType === armnn.schema.LayerType.Output) {
                  const label = base.layerName;
                  for (const slot of base.inputSlots) {
                      const source = value(slot.connection.sourceLayerIndex, slot.connection.outputSlotIndex);
                      this.outputs.push(new armnn.Argument(label, [source]));
                  }
              } else {
                  this.nodes.push(new armnn.Node(metadata, layer, value));
              }
          }
      }
  };
  /**
   * One graph node built from a serialized layer: its type (from metadata when
   * available), named input/output arguments, descriptor attributes and any
   * constant tensors stored directly on the layer.
   */
  armnn.Node = class {

      /**
       * @param {object} metadata - Exposes type(name) and attribute(name, key).
       * @param {object} layer - Serialized layer wrapper; 'layer.layer' is the concrete layer.
       * @param {Function} value - (layerIndex, slotIndex) => armnn.Value lookup.
       */
      constructor(metadata, layer, value) {
          const typeName = layer.layer.constructor.name;
          const known = metadata.type(typeName);
          this.type = known ? { ...known } : { name: typeName };
          // Drop a trailing 'Layer' from the displayed type name.
          this.type.name = this.type.name.replace(/Layer$/, '');
          this.name = '';
          this.outputs = [];
          this.inputs = [];
          this.attributes = [];
          const inputSchemas = this.type.inputs ? [...this.type.inputs] : [{ name: 'input' }];
          const outputSchemas = this.type.outputs ? [...this.type.outputs] : [{ name: 'output' }];
          // Pairs slots with schemas in order. A schema flagged 'list' takes all
          // remaining slots; slots left over once schemas run out are named '?'.
          const group = (slots, schemas, resolve) => {
              const pending = [...slots];
              const grouped = [];
              while (pending.length > 0) {
                  const schema = schemas.length > 0 ? schemas.shift() : { name: '?' };
                  const take = schema.list ? pending.length : 1;
                  const resolved = pending.splice(0, take).map((slot) => resolve(slot));
                  grouped.push(new armnn.Argument(schema.name, resolved));
              }
              return grouped;
          };
          const base = armnn.Node.getBase(layer);
          if (base) {
              this.name = base.layerName;
              const fromInput = (slot) => value(slot.connection.sourceLayerIndex, slot.connection.outputSlotIndex);
              this.inputs.push(...group(base.inputSlots, inputSchemas, fromInput));
              const fromOutput = (slot) => value(base.index, slot.index);
              this.outputs.push(...group(base.outputSlots, outputSchemas, fromOutput));
          }
          const body = layer.layer;
          if (!body) {
              return;
          }
          // Descriptor fields are only listed when the type metadata declares attributes.
          if (body.descriptor && this.type.attributes) {
              for (const [key, raw] of Object.entries(body.descriptor)) {
                  const schema = metadata.attribute(typeName, key);
                  const kind = schema ? schema.type : null;
                  let converted = ArrayBuffer.isView(raw) ? Array.from(raw) : raw;
                  // When the attribute type names an entry of the schema, show the enum key.
                  if (armnn.schema[kind]) {
                      converted = armnn.Utility.enum(kind, converted);
                  }
                  this.attributes.push(new armnn.Argument(key, converted, kind));
              }
          }
          // Constant tensors stored as layer fields are appended as extra inputs.
          const constants = Object.entries(body).filter(([, field]) => field instanceof armnn.schema.ConstTensor);
          for (const [field, tensor] of constants) {
              const initializer = new armnn.Value('', tensor.info, new armnn.Tensor(tensor));
              this.inputs.push(new armnn.Argument(field, [initializer]));
          }
      }

      /**
       * Returns the innermost base record of a layer: 'layer.layer.base.base'
       * when it is truthy, otherwise 'layer.layer.base'.
       */
      static getBase(layer) {
          const { base } = layer.layer;
          return base.base ? base.base : base;
      }

      /** Builds the string "<layer_id>_<index>". */
      static makeKey(layer_id, index) {
          return `${layer_id}_${index}`;
      }
  };
  /**
   * Named argument of a graph or node; also used for node attributes.
   */
  armnn.Argument = class {

      /**
       * @param {string} name - Argument or attribute name.
       * @param {*} value - List of armnn.Value for inputs/outputs, or a raw attribute value.
       * @param {?string} [type=null] - Optional type name.
       */
      constructor(name, value, type = null) {
          Object.assign(this, { name, value, type });
      }
  };
  /**
   * A value flowing between layers, identified by name and optionally carrying
   * a tensor type, an initializer and linear quantization parameters.
   */
  armnn.Value = class {

      /**
       * @param {string} name - Value identifier.
       * @param {?object} tensorInfo - Serialized tensor info. May be null:
       *     armnn.Graph passes null when a connection refers to a layer or output
       *     slot index that is out of range.
       * @param {armnn.Tensor} [initializer] - Constant data backing the value, if any.
       * @throws {armnn.Error} When name is not a string.
       */
      constructor(name, tensorInfo, initializer) {
          if (typeof name !== 'string') {
              throw new armnn.Error(`Invalid value identifier '${JSON.stringify(name)}'.`);
          }
          this.name = name;
          // Without tensor info the type is unknown; report null rather than
          // failing on a property read of null.
          this.type = tensorInfo ? new armnn.TensorType(tensorInfo) : null;
          this.initializer = initializer;
          // Quantization is reported only when some parameter differs from its zero/empty default.
          // NOTE(review): only the scalar quantizationScale is reported, even when
          // quantizationScales is non-empty - confirm whether per-axis scales should be surfaced.
          if (tensorInfo && (
              tensorInfo.quantizationScale !== 0 ||
              tensorInfo.quantizationOffset !== 0 ||
              tensorInfo.quantizationScales.length > 0 ||
              tensorInfo.quantizationDim !== 0)) {
              this.quantization = {
                  type: 'linear',
                  dimension: tensorInfo.quantizationDim,
                  scale: [tensorInfo.quantizationScale],
                  offset: [tensorInfo.quantizationOffset]
              };
          }
      }
  };
  /**
   * Constant tensor: a type plus a byte view of the tensor data.
   */
  armnn.Tensor = class {

      /**
       * @param {object} tensor - Serialized constant tensor; reads 'info' and 'data.data'.
       * @param {string} [category=''] - Category label stored on the tensor.
       */
      constructor(tensor, category = '') {
          this.type = new armnn.TensorType(tensor.info);
          this.category = category;
          // slice(0) is taken before building the byte view; presumably
          // 'data.data' is a typed array, so this is a copy - confirm against the schema.
          const copy = tensor.data.data.slice(0);
          const { buffer, byteOffset, byteLength } = copy;
          this.values = new Uint8Array(buffer, byteOffset, byteLength);
      }
  };
  /**
   * Tensor type: a data type name plus an armnn.TensorShape.
   */
  armnn.TensorType = class {

      /**
       * @param {object} tensorInfo - Serialized tensor info; reads 'dataType' and 'dimensions'.
       * @throws {armnn.Error} When the numeric data type is not recognized.
       */
      constructor(tensorInfo) {
          const dataType = tensorInfo.dataType;
          // Numeric codes of the serialized data type enum.
          switch (dataType) {
              case 0: this.dataType = 'float16'; break;
              case 1: this.dataType = 'float32'; break;
              case 2: this.dataType = 'quint8'; break; // QuantisedAsymm8
              case 3: this.dataType = 'int32'; break;
              case 4: this.dataType = 'boolean'; break;
              case 5: this.dataType = 'qint16'; break; // QuantisedSymm16
              case 6: this.dataType = 'quint8'; break; // QAsymmU8
              case 7: this.dataType = 'qint16'; break; // QSymmS16
              case 8: this.dataType = 'qint8'; break; // QAsymmS8
              case 9: this.dataType = 'qint8'; break; // QSymmS8
              case 10: this.dataType = 'int64'; break; // Signed64
              default:
                  throw new armnn.Error(`Unsupported data type '${JSON.stringify(dataType)}'.`);
          }
          this.shape = new armnn.TensorShape(tensorInfo.dimensions);
      }

      /** Returns the data type name immediately followed by the shape string. */
      toString() {
          return this.dataType + this.shape.toString();
      }
  };
  /**
   * Tensor shape stored as a plain array of dimensions.
   */
  armnn.TensorShape = class {

      /**
       * @param {Iterable|ArrayLike} dimensions - Dimension sizes; copied with Array.from.
       */
      constructor(dimensions) {
          this.dimensions = Array.from(dimensions);
      }

      /**
       * Returns "[d0,d1,...]", or an empty string when there are no dimensions.
       */
      toString() {
          const dims = this.dimensions;
          if (dims && dims.length !== 0) {
              const parts = dims.map((dimension) => dimension.toString());
              return `[${parts.join(',')}]`;
          }
          return '';
      }
  };
  /**
   * Shared helpers.
   */
  armnn.Utility = class {

      /**
       * Maps an enum value to its key name using the schema entry called 'name'.
       * The reversed lookup table is built once per name and cached on
       * armnn.Utility._enums.
       *
       * @param {string} name - Name of the enum in armnn.schema.
       * @param {*} value - Value to translate.
       * @returns {*} The matching key, or 'value' unchanged when the schema,
       *     the enum or the entry is not found.
       */
      static enum(name, value) {
          if (!name || !armnn.schema) {
              return value;
          }
          const type = armnn.schema[name];
          if (!type) {
              return value;
          }
          if (!armnn.Utility._enums) {
              armnn.Utility._enums = new Map();
          }
          const cache = armnn.Utility._enums;
          if (!cache.has(name)) {
              const reversed = new Map();
              for (const [key, number] of Object.entries(type)) {
                  reversed.set(number, key);
              }
              cache.set(name, reversed);
          }
          const lookup = cache.get(name);
          return lookup.has(value) ? lookup.get(value) : value;
      }
  };
  /**
   * Error type raised for any failure while loading an Arm NN model.
   */
  armnn.Error = class extends Error {

      /**
       * @param {string} message - Human-readable failure description.
       */
      constructor(message) {
          super(message);
          Object.assign(this, { name: 'Error loading Arm NN model.' });
      }
  };

  // Module entry point: the factory is the only exported binding.
  export const ModelFactory = armnn.ModelFactory;