qnn.js 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303
  // Experimental

const qnn = {};

// Model factory for Qualcomm QNN artifacts. `match` recognizes three inputs:
//   'qnn.json'       - a JSON object with a 'model.cpp' key and a 'graph' object; its optional
//                      'model.bin' entry names a tar of weights that is fetched next to it;
//   'qnn.weights'    - a non-empty tar whose entries all end in '.raw'; the graph description is
//                      then read from '<base>_net.json' next to it;
//   'qnn.serialized' - a '.bin'/'.serialized' file starting with one of the 16-byte signatures
//                      below; it is only recognized so that a clear error can be reported.
qnn.ModelFactory = class {

    // Probes `context` and returns the result of `context.set(type, value)` for the first
    // matching format, or null when nothing matches.
    async match(context) {
        const obj = await context.peek('json');
        if (obj && obj['model.cpp'] !== undefined && obj.graph) {
            return context.set('qnn.json', obj);
        }
        const entries = await context.peek('tar');
        if (entries && entries.size > 0 && Array.from(entries).every(([name]) => name.endsWith('.raw'))) {
            return context.set('qnn.weights', entries);
        }
        const identifier = context.identifier.toLowerCase();
        if (identifier.endsWith('.bin') || identifier.endsWith('.serialized')) {
            const signatures = [
                [0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00],
                [0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x0C, 0x00, 0x00, 0x00],
                [0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00],
                [0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01],
            ];
            const stream = context.stream;
            // A context without a stream simply has nothing to probe; do not fail with a TypeError.
            if (stream && stream.length >= 16) {
                // Every signature is 16 bytes long, so the header is peeked once and shared.
                const header = stream.peek(16);
                const matches = (signature) => signature.every((value, index) => header[index] === value);
                if (signatures.some(matches)) {
                    return context.set('qnn.serialized');
                }
            }
        }
        return null;
    }

    // Builds a qnn.Model for a context previously tagged by `match`.
    // Throws qnn.Error for serialized contexts and for unknown context types.
    async open(context) {
        const metadata = await context.metadata('qnn-metadata.json');
        switch (context.type) {
            case 'qnn.json': {
                const obj = context.value;
                let weights = new Map();
                try {
                    if (obj['model.bin']) {
                        // The value looks like a path; only its last component (either separator)
                        // is fetched, relative to the JSON file.
                        const name = obj['model.bin'].split(/[\\/]/).pop();
                        const content = await context.fetch(name);
                        const entries = await content.read('tar');
                        if (entries) {
                            weights = entries;
                        }
                    }
                } catch {
                    // Best effort: on any failure the model is opened with an empty weights map.
                }
                return new qnn.Model(metadata, obj, weights);
            }
            case 'qnn.weights': {
                const weights = context.value;
                // Drop only the last extension ('model.bin' -> 'model'); a name without an
                // extension is used unchanged rather than collapsing to an empty base.
                const identifier = context.identifier;
                const index = identifier.lastIndexOf('.');
                const base = index >= 0 ? identifier.substring(0, index) : identifier;
                const content = await context.fetch(`${base}_net.json`);
                const obj = await content.read('json');
                return new qnn.Model(metadata, obj, weights);
            }
            case 'qnn.serialized': {
                throw new qnn.Error("File contains undocumented QNN serialized context.");
            }
            default: {
                throw new qnn.Error(`Unsupported QNN format '${context.type}'.`);
            }
        }
    }
};
  67. qnn.Model = class {
  68. constructor(metadata, obj, weights) {
  69. this.format = 'QNN';
  70. if (obj.converter_command) {
  71. this.producer = obj.converter_command.split(' ').shift();
  72. }
  73. this.metadata = [];
  74. if (obj.copyright_str) {
  75. this.metadata.push(new qnn.Argument('License', obj.copyright_str));
  76. }
  77. this.modules = [new qnn.Graph(metadata, obj.graph, weights)];
  78. }
  79. };
  80. qnn.Graph = class {
  81. constructor(metadata, obj, weights) {
  82. this.inputs = [];
  83. this.outputs = [];
  84. this.nodes = [];
  85. const values = new Map();
  86. values.map = (name, type, tensor, quantization) => {
  87. type = type || null;
  88. tensor = tensor || null;
  89. if (!values.has(name)) {
  90. const value = new qnn.Value(name, type, tensor, quantization);
  91. values.set(name, value);
  92. } else if ((type && !type.equals(values.get(name).type)) || tensor) {
  93. throw new qnn.Error(`Duplicate value '${name}'.`);
  94. }
  95. return values.get(name);
  96. };
  97. const tensors = Object.entries(obj.tensors);
  98. for (const [name, obj] of tensors) {
  99. const type = new qnn.TensorType(obj);
  100. switch (obj.type) {
  101. case 0: {
  102. const value = values.map(name, type, null, obj.quant_params);
  103. const argument = new qnn.Argument(name, [value]);
  104. this.inputs.push(argument);
  105. break;
  106. }
  107. case 1: {
  108. const value = values.map(name, type, null, obj.quant_params);
  109. const argument = new qnn.Argument(name, [value]);
  110. this.outputs.push(argument);
  111. break;
  112. }
  113. case 3: {
  114. values.map(name, type);
  115. break;
  116. }
  117. case 4: {
  118. const reader = weights.get(`${name}.raw`);
  119. const tensor = new qnn.Tensor(name, type, obj, reader);
  120. values.map(name, type, tensor, obj.quant_params);
  121. break;
  122. }
  123. default: {
  124. throw new qnn.Error(`Unsupported tensor type '${obj.type}'.`);
  125. }
  126. }
  127. }
  128. const nodes = Object.entries(obj.nodes);
  129. for (const [name, obj] of nodes) {
  130. const node = new qnn.Node(metadata, name, obj, values, weights);
  131. this.nodes.push(node);
  132. }
  133. }
  134. };
  135. qnn.Argument = class {
  136. constructor(name, value, type = null, visible = true) {
  137. this.name = name;
  138. this.value = value;
  139. this.type = type;
  140. this.visible = visible;
  141. }
  142. };
  143. qnn.Value = class {
  144. constructor(name, type, initializer, quantization) {
  145. if (typeof name !== 'string') {
  146. throw new qnn.Error(`Invalid value identifier '${JSON.stringify(name)}'.`);
  147. }
  148. this.name = name;
  149. this.type = type;
  150. this.initializer = initializer;
  151. if (quantization && quantization.definition === 1 && quantization.scale_offset) {
  152. this.quantization = {
  153. type: 'linear',
  154. scale: [quantization.scale_offset.scale],
  155. offset: [quantization.scale_offset.offset]
  156. };
  157. }
  158. }
  159. };
  160. qnn.Node = class {
  161. constructor(metadata, name, obj, values) {
  162. this.name = name;
  163. this.type = { name: obj.type, ...metadata.type(obj.type) };
  164. this.type.module = obj.package;
  165. this.inputs = [];
  166. this.outputs = [];
  167. this.attributes = [];
  168. const inputs = Array.isArray(obj.input_names) ? Array.from(obj.input_names).map((name) => values.map(name)) : [];
  169. if (Array.isArray(this.type.inputs) && inputs.length === this.type.inputs.length) {
  170. for (let i = 0; i < inputs.length; i++) {
  171. const argument = new qnn.Argument(this.type.inputs[i].name, [inputs[i]]);
  172. this.inputs.push(argument);
  173. }
  174. } else if (inputs.length > 0) {
  175. const argument = new qnn.Argument(inputs.length === 1 ? 'input' : 'inputs', inputs);
  176. this.inputs.push(argument);
  177. }
  178. const outputs = Array.isArray(obj.output_names) ? Array.from(obj.output_names).map((name) => values.map(name)) : [];
  179. if (Array.isArray(this.type.outputs) && outputs.length === this.type.outputs.length) {
  180. for (let i = 0; i < outputs.length; i++) {
  181. const argument = new qnn.Argument(this.type.outputs[i].name, [outputs[i]]);
  182. this.outputs.push(argument);
  183. }
  184. } else if (outputs.length > 0) {
  185. const argument = new qnn.Argument(outputs.length === 1 ? 'output' : 'outputs', outputs);
  186. this.outputs.push(argument);
  187. }
  188. for (const [name, value] of Object.entries(obj.scalar_params)) {
  189. const entries = Object.entries(value);
  190. if (entries.length === 1 && name !== 'packageName') {
  191. const dataType = qnn.Utility.dataType(parseInt(entries[0][0], 10));
  192. const argument = new qnn.Argument(name, entries[0][1], dataType);
  193. this.attributes.push(argument);
  194. }
  195. }
  196. for (const [name, value] of Object.entries(obj.tensor_params)) {
  197. const entries = Object.entries(value);
  198. if (entries.length === 1 && name !== 'packageName') {
  199. const tensor = new qnn.Tensor(name, null, entries[0][1]);
  200. const argument = new qnn.Argument(name, tensor, 'tensor');
  201. this.attributes.push(argument);
  202. }
  203. }
  204. }
  205. };
  206. qnn.Tensor = class {
  207. constructor(name, type, obj, data) {
  208. this.type = type || new qnn.TensorType(obj);
  209. this.data = obj.data ? obj.data.flat() : data;
  210. this.encoding = Array.isArray(this.data) ? '|' : '<';
  211. }
  212. get values() {
  213. if (this.data && this.data.peek) {
  214. return this.data.peek();
  215. }
  216. return this.data;
  217. }
  218. };
  219. qnn.TensorType = class {
  220. constructor(obj) {
  221. this.dataType = qnn.Utility.dataType(obj.data_type);
  222. this.shape = new qnn.TensorShape(obj.dims);
  223. this.denotation = obj.axis_format && obj.axis_format !== 'ANY' ? obj.axis_format : '';
  224. }
  225. toString() {
  226. return this.dataType + this.shape.toString();
  227. }
  228. };
  229. qnn.TensorShape = class {
  230. constructor(dimensions) {
  231. this.dimensions = dimensions;
  232. }
  233. toString() {
  234. if (Array.isArray(this.dimensions) && this.dimensions.length > 0) {
  235. return `[${this.dimensions.map((dimension) => dimension.toString()).join(',')}]`;
  236. }
  237. return '';
  238. }
  239. };
  240. qnn.Utility = class {
  241. static dataType(value) {
  242. switch (value) {
  243. case 0x0008: return 'int8';
  244. case 0x0016: return 'int16';
  245. case 0x0032: return 'int32';
  246. case 0x0064: return 'int64';
  247. case 0x0108: return 'uint8';
  248. case 0x0116: return 'uint16';
  249. case 0x0132: return 'uint32';
  250. case 0x0164: return 'uint64';
  251. case 0x0216: return 'float16';
  252. case 0x0232: return 'float32';
  253. case 0x0304: return 'qint4';
  254. case 0x0308: return 'qint8';
  255. case 0x0316: return 'qint16';
  256. case 0x0332: return 'qint32';
  257. case 0x0404: return 'quint4';
  258. case 0x0408: return 'quint8';
  259. case 0x0416: return 'quint16';
  260. case 0x0432: return 'quint32';
  261. case 0x0508: return 'boolean';
  262. case 0x0608: return 'string';
  263. case 0x7fffffff: return 'undefined';
  264. default: throw new qnn.Error(`Unsupported data type '${JSON.stringify(value)}'.`);
  265. }
  266. }
  267. };
  268. qnn.Error = class extends Error {
  269. constructor(message) {
  270. super(message);
  271. this.name = 'Error loading QNN model.';
  272. }
  273. };
  274. export const ModelFactory = qnn.ModelFactory;