kann.js 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199
  // Namespace object collecting every KaNN loader class defined in this module.
  const kann = {};

  kann.ModelFactory = class {

      /**
       * Detects a KaNN model: a FlatBuffers binary whose file identifier is 'KaNN'.
       * @param {object} context - loader context; `peek('flatbuffers.binary')` yields a reader or a falsy value.
       * @returns {Promise<*>} the result of `context.set('kann.flatbuffers', reader)` on a match, otherwise null.
       */
      async match(context) {
          const reader = await context.peek('flatbuffers.binary');
          if (reader && reader.identifier === 'KaNN') {
              return context.set('kann.flatbuffers', reader);
          }
          return null;
      }

      /**
       * Decodes the model recognised by `match` and wraps it in a kann.Model.
       * @param {object} context - loader context (type, value, identifier, require(), metadata()).
       * @returns {Promise<kann.Model>}
       * @throws {kann.Error} when `context.type` is unsupported or the FlatBuffers decode fails.
       */
      async open(context) {
          // The generated schema module exposes its types under a `kann` namespace.
          kann.schema = await context.require('./kann-schema');
          kann.schema = kann.schema.kann;
          let model = null;
          switch (context.type) {
              case 'kann.flatbuffers': {
                  try {
                      const reader = context.value;
                      model = kann.schema.Model.create(reader);
                  } catch (error) {
                      // String(...) also copes with non-Error throw values (null, undefined, numbers)
                      // and non-string messages, so the descriptive kann.Error below is always raised.
                      const message = String(error && error.message ? error.message : error);
                      throw new kann.Error(`File format is not kann.Model (${message.replace(/\.$/, '')}).`);
                  }
                  break;
              }
              default: {
                  throw new kann.Error(`Unsupported KaNN format '${context.type}'.`);
              }
          }
          const metadata = await context.metadata('kann-metadata.json');
          return new kann.Model(metadata, model, context.identifier);
      }
  };
  kann.Model = class {

      /**
       * Top-level KaNN model: one module per graph in the decoded model.
       * @param {object} metadata - operator metadata, forwarded to every kann.Graph.
       * @param {object} model - decoded model record; each entry of `model.graph` becomes a module.
       * @param {string} identifier - stored as the model's `name`.
       */
      constructor(metadata, model, identifier) {
          const toGraph = (graph) => new kann.Graph(metadata, graph);
          this.format = 'KaNN';
          this.name = identifier;
          this.modules = model.graph.map(toGraph);
      }
  };
  kann.Graph = class {

      /**
       * Builds one graph. Every arc becomes a single shared kann.Value. Nodes, graph inputs and
       * graph outputs all reference those shared values by arc name.
       * @param {object} metadata - operator metadata, forwarded to each kann.Node.
       * @param {object} graph - decoded graph record (fields read: arcs, nodes, inputs, outputs).
       */
      constructor(metadata, graph) {
          const valuesByName = new Map();
          for (const arc of graph.arcs) {
              const { name, type } = arc;
              valuesByName.set(name, new kann.Value(name, type, null));
          }
          // Graph-level inputs/outputs are named after their arc and point at its shared value.
          const toArgument = (name) => new kann.Argument(name, [valuesByName.get(name)]);
          this.nodes = graph.nodes.map((node) => new kann.Node(metadata, node, valuesByName));
          this.inputs = graph.inputs.map(toArgument);
          this.outputs = graph.outputs.map(toArgument);
      }
  };
  kann.Node = class {

      /**
       * Builds the view-model for one KaNN graph node.
       * @param {object} metadata - operator metadata; `metadata.type(...)` is used to look up the node's type.
       * @param {object} node - decoded node record (fields read: name, type, attributes, inputs, outputs, params, relu).
       * @param {Map<string, kann.Value>} arcs - graph values keyed by arc name.
       * @throws {kann.Error} for an unsupported value type or an attribute that has no value.
       */
      constructor(metadata, node, arcs) {
          this.type = metadata.type(node.type);
          this.name = node.name;
          this.inputs = [];
          this.outputs = [];
          this.attributes = [];
          // Reads the payload field that matches the value's declared type name.
          const extractData = (value) => {
              switch (value.type) {
                  case 'int': case 'int8': case 'int16': case 'int32': case 'int64': return value.value_int;
                  case 'uint': case 'uint8': case 'uint16': case 'uint32': case 'uint64': return value.value_uint;
                  case 'float': case 'float16': case 'float32': case 'float64': return value.value_float;
                  case 'string': return value.value_string;
                  case 'int[]': case 'int8[]': case 'int16[]': case 'int32[]': case 'int64[]': return Array.from(value.list_int);
                  case 'uint[]': case 'uint8[]': case 'uint16[]': case 'uint32[]': case 'uint64[]': return Array.from(value.list_uint);
                  case 'float[]': case 'float16[]': case 'float32[]': case 'float64[]': return Array.from(value.list_float);
                  case 'string[]': return Array.from(value.list_string);
                  default: throw new kann.Error(`Unsupported data type '${value.type}'.`);
              }
          };
          // Resolves an attribute to a plain value; 'attributes'-typed entries recurse into an
          // object keyed by each child's name.
          const getAttributeValue = (attribute) => {
              if (attribute.type === 'attributes') {
                  const obj = {};
                  for (const attr of attribute.attributes) {
                      obj[attr.name] = getAttributeValue(attr);
                  }
                  return obj;
              }
              // Treat undefined like null so a missing value raises the descriptive kann.Error
              // below instead of a TypeError inside extractData.
              if (attribute.value !== null && attribute.value !== undefined) {
                  return extractData(attribute.value);
              }
              throw new kann.Error(`${attribute.name} doesn't have a value.`);
          };
          if (Array.isArray(node.attributes) && node.attributes.length > 0) {
              for (const attr of node.attributes) {
                  // Untyped entries are passed through as-is; attribute values are always stored as arrays.
                  let value = attr.type ? getAttributeValue(attr) : attr;
                  value = Array.isArray(value) ? value : [value];
                  // Nested groups decode to plain objects and carry no scalar type. This must test
                  // `attr.type`: `value` is an array at this point, so `value.type` is always undefined.
                  const type = attr.type === 'attributes' ? null : attr.type || null;
                  this.attributes.push(new kann.Argument(attr.name, value, type));
              }
          }
          if (Array.isArray(node.inputs) && node.inputs.length > 0) {
              const name = node.inputs.length > 1 ? 'inputs' : 'input';
              this.inputs.push(new kann.Argument(name, node.inputs.map((input) => arcs.get(input))));
          }
          if (Array.isArray(node.outputs) && node.outputs.length > 0) {
              const name = node.outputs.length > 1 ? 'outputs' : 'output';
              this.outputs.push(new kann.Argument(name, node.outputs.map((output) => arcs.get(output))));
          }
          // Each parameter becomes an extra input whose value is backed by a kann.Tensor.
          if (Array.isArray(node.params) && node.params.length > 0) {
              for (const param of node.params) {
                  const type = new kann.TensorType(param.type, param.shape);
                  const data = param.value ? extractData(param.value) : null;
                  // Linear quantization is recorded only when both scale and zero point are present.
                  const quantization = param.scale && param.zero_point ? {
                      type: 'linear',
                      scale: extractData(param.scale),
                      offset: extractData(param.zero_point)
                  } : null;
                  const tensor = new kann.Tensor(param.name, type, data, quantization);
                  const value = new kann.Value('', type, tensor);
                  this.inputs.push(new kann.Argument(param.name, [value]));
              }
          }
          // A truthy `relu` flag is represented as a chained 'ReLU' node.
          if (node.relu) {
              const relu = { type: 'ReLU', name: `${node.name}/relu`, params: [] };
              this.chain = [new kann.Node(metadata, relu, arcs)];
          }
      }
  };
  kann.Argument = class {

      /**
       * A named slot: a node/graph input or output, or a node attribute.
       * @param {string} name - slot name.
       * @param {Array} value - the slot's values.
       * @param {?string} [type=null] - optional type label (used for attributes).
       */
      constructor(name, value, type = null) {
          Object.assign(this, { name, value, type });
      }
  };
  kann.Value = class {

      /**
       * A named value flowing between nodes; constant parameters carry their tensor as `initializer`.
       * @param {string} name - value name (empty for parameter tensors).
       * @param {*} type - type descriptor (a kann.TensorType for parameter tensors).
       * @param {?kann.Tensor} initializer - backing tensor, or null when there is none.
       */
      constructor(name, type, initializer) {
          this.name = name;
          this.type = type;
          this.initializer = initializer;
          // Mirror the initializer's quantization (if any) onto the value itself.
          let quantization = null;
          if (initializer && initializer.quantization) {
              quantization = initializer.quantization;
          }
          this.quantization = quantization;
      }
  };
  kann.Tensor = class {

      /**
       * Constant tensor data attached to a node parameter.
       * @param {string} name - tensor name.
       * @param {kann.TensorType} type - element type and shape.
       * @param {*} values - decoded data; a plain Array selects the '|' encoding, anything else '<'.
       * @param {?object} quantization - e.g. { type: 'linear', scale, offset }; falsy means none.
       */
      constructor(name, type, values, quantization) {
          const isPlainList = Array.isArray(values);
          this.name = name;
          this.type = type;
          this.encoding = isPlainList ? '|' : '<';
          this.values = values;
          this.quantization = quantization || null;
      }
  };
  kann.TensorType = class {

      /**
       * @param {?string} dataType - element type name; a falsy value is shown as '?'.
       * @param {*} shape - dimension list handed to kann.TensorShape.
       */
      constructor(dataType, shape) {
          this.dataType = dataType ? dataType : '?';
          this.shape = new kann.TensorShape(shape);
      }

      /**
       * @returns {string} the data type followed by the shape suffix, e.g. "float32[1,3]".
       */
      toString() {
          const suffix = this.shape.toString();
          return this.dataType + suffix;
      }
  };
  kann.TensorShape = class {

      /**
       * @param {?(Iterable|ArrayLike)} dimensions - dimension sizes. null/undefined is accepted and
       *   treated as "no dimensions" (Array.from would otherwise throw and abort the model load).
       */
      constructor(dimensions) {
          this.dimensions = dimensions === null || dimensions === undefined ? [] : Array.from(dimensions);
      }

      /**
       * @returns {string} "[d0,d1,...]", or '' when there are no dimensions.
       */
      toString() {
          // Defensive: `dimensions` is a public field and may be replaced after construction.
          if (Array.isArray(this.dimensions) && this.dimensions.length > 0) {
              return `[${this.dimensions.map((dimension) => dimension.toString()).join(',')}]`;
          }
          return '';
      }
  };
  kann.Error = class extends Error {

      /**
       * Loader-specific error raised for unreadable or unsupported KaNN content.
       * @param {string} text - detail message.
       */
      constructor(text) {
          super(text);
          // A fixed, human-readable headline rather than a class identifier.
          this.name = 'Error loading KaNN model.';
      }
  };
  // Public entry point of this module: the KaNN model factory.
  export const { ModelFactory } = kann;