onednn.js 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289
  // Namespace object that every class of the oneDNN Graph loader is attached to.
  const onednn = {};

  /**
   * Entry point used by the host application to detect and open
   * oneDNN Graph JSON files.
   */
  onednn.ModelFactory = class {

      /**
       * Checks whether the context holds a oneDNN Graph JSON document.
       * @param {object} context - Host context; must provide `peek(kind)` and `set(type, value)`.
       * @returns {Promise<*>} Whatever `context.set('onednn', root)` returns when the
       * document carries all required root keys, otherwise `null`.
       */
      async match(context) {
          const root = await context.peek('json');
          // All four root keys must be present and truthy for the file to qualify.
          const required = ['version', 'engine_kind', 'fpmath_mode', 'graph'];
          if (!root || !required.every((key) => root[key])) {
              return null;
          }
          return context.set('onednn', root);
      }

      /**
       * Loads the operator metadata and wraps the matched document in a model.
       * @param {object} context - Host context; must provide `metadata(name)` and `value`.
       * @returns {Promise<onednn.Model>} The model built from `context.value`.
       */
      async open(context) {
          const catalog = await context.metadata('onednn-metadata.json');
          const model = new onednn.Model(catalog, context.value);
          return model;
      }
  };
  /**
   * Top-level model: format and runtime labels plus a single graph.
   */
  onednn.Model = class {

      /**
       * @param {object} metadata - Operator metadata, forwarded to onednn.Graph.
       * @param {object} symbol - Parsed JSON root; `version`, `engine_kind` and
       * `fpmath_mode` are read here, the rest is consumed by onednn.Graph.
       */
      constructor(metadata, symbol) {
          const { version, engine_kind: engineKind, fpmath_mode: mathMode } = symbol;
          // The version suffix is only appended when a version is present.
          const suffix = version ? ` v${version}` : '';
          this.format = `oneDNN${suffix}`;
          this.runtime = `${engineKind} ${mathMode}`;
          const graph = new onednn.Graph(metadata, symbol);
          this.graphs = [graph];
      }
  };
  /**
   * Graph built from the `graph` array of a oneDNN Graph JSON document.
   */
  onednn.Graph = class {

      /**
       * @param {object} metadata - Operator metadata, forwarded unchanged to every onednn.Node.
       * @param {object} symbol - Parsed JSON root; reads `graph`, `engine_kind`,
       * `input_ports` and `output_ports`.
       * @throws {onednn.Error} When the same id is described twice with a different
       * type or initializer.
       */
      constructor(metadata, symbol) {
          this.nodes = [];
          this.inputs = [];
          this.outputs = [];
          const layers = [];
          // Ids produced by input-less 'Wildcard' layers. Those layers are not
          // turned into nodes; values with these ids get an initializer instead.
          const initializers = new Set();
          for (const layer of symbol.graph) {
              // Missing arrays are treated as empty, the same way onednn.Node does.
              const layerInputs = layer.inputs || [];
              if (layer.kind === 'Wildcard' && layerInputs.length === 0) {
                  for (const output of layer.outputs || []) {
                      initializers.add(output.id);
                  }
              } else {
                  layers.push(layer);
              }
          }
          const values = new Map();
          // Returns the shared onednn.Value for a tensor description, creating it
          // on first use and verifying consistency on every later use.
          const value = (obj) => {
              const id = obj.id;
              // A missing shape or the single-element shape [-1] is stored as null.
              const shapeless = !obj.shape || (obj.shape.length === 1 && obj.shape[0] === -1);
              const shape = shapeless ? null : new onednn.TensorShape(obj.shape);
              const type = new onednn.TensorType(obj.dtype, shape);
              const initializer = initializers.has(id) ? new onednn.Tensor(type, obj.property_type) : null;
              if (!values.has(id)) {
                  values.set(id, new onednn.Value(id.toString(), type, initializer));
                  return values.get(id);
              }
              const existing = values.get(id);
              const typeConflict = !type.equals(existing.type);
              const initializerConflict = initializer && !initializer.equals(existing.initializer);
              if (typeConflict || initializerConflict) {
                  throw new onednn.Error(`Duplicate value '${id}'.`);
              }
              return existing;
          };
          // Register every tensor up front so conflicting descriptions are
          // reported before any node is constructed.
          for (const layer of layers) {
              for (const input of layer.inputs || []) {
                  value(input);
              }
              for (const output of layer.outputs || []) {
                  value(output);
              }
          }
          const engine = symbol.engine_kind;
          for (const layer of layers) {
              this.nodes.push(new onednn.Node(metadata, layer, engine, value));
          }
          // Graph ports reference tensor ids; ids that no node uses are skipped.
          for (const id of symbol.input_ports || []) {
              const port = values.get(id);
              if (port) {
                  this.inputs.push(new onednn.Argument(id.toString(), [port]));
              }
          }
          for (const id of symbol.output_ports || []) {
              const port = values.get(id);
              if (port) {
                  this.outputs.push(new onednn.Argument(id.toString(), [port]));
              }
          }
      }
  };
  /**
   * One operation of the graph: its type, attributes and input/output arguments.
   */
  onednn.Node = class {

      /**
       * @param {object} metadata - Operator metadata; `metadata.type(kind)` is looked up
       * and replaced by `{ name: kind }` when it yields nothing.
       * @param {object} node - Serialized op; reads `name`, `kind`, `id`, `attrs`,
       * `inputs` and `outputs`.
       * @param {*} device - Stored as-is in `this.device`.
       * @param {function} value - Maps a serialized tensor description to its onednn.Value.
       * @throws {onednn.Error} On an attribute type that is not handled, or a 'bool'
       * attribute whose value is neither 0 nor 1.
       */
      constructor(metadata, node, device, value) {
          this.name = node.name;
          this.attributes = [];
          this.inputs = [];
          this.outputs = [];
          this.type = metadata.type(node.kind) || { name: node.kind };
          this.device = device;
          this.identifier = node.id;

          // Converts a scalar to a number; the original is kept when the whole
          // text is not numeric (the subtraction is NaN in that case).
          const parseScalar = (text, parse) => {
              const number = parse(text);
              return Number.isNaN(text - number) ? text : number;
          };
          // Converts a '[a, b, c]' string to an array of numbers. Anything that is
          // not such a string, or that contains a non-numeric item, is returned
          // unchanged. A trailing 'L' on an item is dropped before parsing.
          const parseList = (text, parse) => {
              if (typeof text !== 'string' || text.length <= 2 || !text.startsWith('[') || !text.endsWith(']')) {
                  return text;
              }
              const numbers = [];
              for (const token of text.substring(1, text.length - 1).split(',')) {
                  const trimmed = token.trim();
                  const item = trimmed.endsWith('L') ? trimmed.substring(0, trimmed.length - 1) : trimmed;
                  const number = parse(item);
                  if (Number.isNaN(item - number)) {
                      return text;
                  }
                  numbers.push(number);
              }
              return numbers;
          };
          const parseInteger = (text) => Number.parseInt(text, 10);
          const parseReal = (text) => Number.parseFloat(text);

          for (const [name, obj] of Object.entries(node.attrs || {})) {
              let type = obj.type;
              let content = obj.value;
              switch (obj.type) {
                  case 'bool':
                      type = 'boolean';
                      switch (content) {
                          case 1: content = true; break;
                          case 0: content = false; break;
                          default: throw new onednn.Error(`Unsupported attribute boolean value '${content}'.`);
                      }
                      break;
                  case 's64':
                      type = 'int64';
                      content = parseScalar(content, parseInteger);
                      break;
                  case 's64[]':
                      type = 'int64[]';
                      content = parseList(content, parseInteger);
                      break;
                  case 'f32':
                      type = 'float32';
                      content = parseScalar(content, parseReal);
                      break;
                  case 'f32[]':
                      type = 'float32[]';
                      content = parseList(content, parseReal);
                      break;
                  case 'string':
                      type = 'string';
                      break;
                  default:
                      throw new onednn.Error(`Unsupported attribute array data type '${type}'.`);
              }
              this.attributes.push(new onednn.Argument(name, content, type));
          }

          // Picks the name declared by the metadata for this position. When the
          // node carries more arguments than the metadata declares, the positional
          // default is used instead of reading past the end of the declared list.
          const label = (declared, index, count, single) => {
              if (declared && index < declared.length) {
                  return declared[index].name;
              }
              return count === 1 ? single : index.toString();
          };
          const inputs = node.inputs || [];
          for (let i = 0; i < inputs.length; i++) {
              const name = label(this.type.inputs, i, inputs.length, 'input');
              this.inputs.push(new onednn.Argument(name, [value(inputs[i])]));
          }
          const outputs = node.outputs || [];
          for (let i = 0; i < outputs.length; i++) {
              const name = label(this.type.outputs, i, outputs.length, 'output');
              this.outputs.push(new onednn.Argument(name, [value(outputs[i])]));
          }
      }
  };
  /**
   * Named slot of a node or graph: an attribute, an input or an output.
   */
  onednn.Argument = class {

      /**
       * @param {string} name - Slot name.
       * @param {*} value - Attribute value, or the list of values bound to the slot.
       * @param {string} [type] - Attribute type name; stored as `null` when falsy.
       */
      constructor(name, value, type) {
          Object.assign(this, { name, value, type: type || null });
      }
  };
  /**
   * A tensor flowing between nodes, identified by a string name.
   */
  onednn.Value = class {

      /**
       * @param {string} name - Identifier; anything but a string is rejected.
       * @param {*} [type] - Tensor type; stored as `null` when falsy.
       * @param {*} [initializer] - Initializer tensor; stored as `null` when falsy.
       * @throws {onednn.Error} When `name` is not a string.
       */
      constructor(name, type, initializer) {
          const isString = typeof name === 'string';
          if (!isString) {
              const shown = JSON.stringify(name);
              throw new onednn.Error(`Invalid value identifier '${shown}'.`);
          }
          this.name = name;
          this.type = type ? type : null;
          this.initializer = initializer ? initializer : null;
      }
  };
  /**
   * Element type plus optional shape of a tensor.
   */
  onednn.TensorType = class {

      /**
       * @param {string} dataType - Serialized data type name.
       * @param {*} shape - Shape object or `null`; stored as given.
       * @throws {onednn.Error} When `dataType` is not one of the known names.
       */
      constructor(dataType, shape) {
          // Serialized name -> display name.
          const names = new Map([
              ['f8_e4m3', 'float8e4m3'],
              ['f8_e5m2', 'float8e5m2'],
              ['f16', 'float16'],
              ['f32', 'float32'],
              ['s4', 'int4'],
              ['s8', 'int8'],
              ['s32', 'int32'],
              ['u4', 'uint4'],
              ['u8', 'uint8'],
              ['bf16', 'bfloat16'],
              ['boolean', 'boolean'],
              ['undef', '?']
          ]);
          if (!names.has(dataType)) {
              throw new onednn.Error(`Unsupported tensor data type '${dataType}'.`);
          }
          this.dataType = names.get(dataType);
          this.shape = shape;
      }

      /**
       * Compares data type and shape with another tensor type.
       * @param {*} obj - The other type.
       * @returns {*} `obj` itself when it is falsy, otherwise a boolean.
       */
      equals(obj) {
          if (!obj) {
              return obj;
          }
          if (this.dataType !== obj.dataType) {
              return false;
          }
          if (this.shape) {
              return this.shape.equals(obj.shape) || false;
          }
          // Without a shape on this side, both sides must be exactly null.
          return this.shape === null && obj.shape === null;
      }

      /**
       * @returns {string} Data type followed by the shape text, or by '[?]' when
       * there is no shape.
       */
      toString() {
          const suffix = this.shape ? this.shape.toString() : '[?]';
          return `${this.dataType}${suffix}`;
      }
  };
  /**
   * Shape of a tensor as a list of dimensions.
   */
  onednn.TensorShape = class {

      /**
       * @param {Array} dimensions - Dimension list; stored as given.
       */
      constructor(dimensions) {
          this.dimensions = dimensions;
      }

      /**
       * Compares the dimension lists element by element with strict equality.
       * @param {*} obj - The other shape.
       * @returns {*} `obj` itself when it is falsy, otherwise a boolean.
       */
      equals(obj) {
          if (!obj) {
              return obj;
          }
          const theirs = obj.dimensions;
          const mine = this.dimensions;
          if (!Array.isArray(theirs) || !Array.isArray(mine)) {
              return false;
          }
          if (mine.length !== theirs.length) {
              return false;
          }
          const differs = theirs.some((dimension, index) => mine[index] !== dimension);
          return !differs;
      }

      /**
       * @returns {string} '[d0,d1,...]' with '?' in place of falsy dimensions, or
       * an empty string when there is no dimension list.
       */
      toString() {
          if (!this.dimensions) {
              return '';
          }
          const parts = this.dimensions.map((dimension) => dimension ? dimension.toString() : '?');
          return `[${parts.join(',')}]`;
      }
  };
  /**
   * Initializer attached to a value: a tensor type and a category label.
   */
  onednn.Tensor = class {

      /**
       * @param {*} type - Tensor type; its `equals` method is used by `equals` below.
       * @param {*} property_type - Stored as `this.category`.
       */
      constructor(type, property_type) {
          this.type = type;
          this.category = property_type;
      }

      /**
       * Compares type and category with another tensor.
       * @param {*} obj - The other tensor.
       * @returns {*} `obj` itself when it is falsy, the falsy result of the type
       * comparison when the types differ, otherwise a boolean.
       */
      equals(obj) {
          if (!obj) {
              return obj;
          }
          const sameType = this.type.equals(obj.type);
          if (!sameType) {
              return sameType;
          }
          return this.category === obj.category;
      }
  };
  /**
   * Error type thrown by the oneDNN Graph loader.
   */
  onednn.Error = class extends Error {

      /**
       * @param {string} message - Description of the failure.
       */
      constructor(message) {
          super(message);
          // The name doubles as the heading shown for the error.
          const heading = 'Error loading oneDNN Graph model.';
          this.name = heading;
      }
  };
  // Public entry point of the module: the factory used to detect and open files.
  const { ModelFactory } = onednn;
  export { ModelFactory };