| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199 |
// Namespace object that collects every class of the KaNN loader.
const kann = {};

/**
 * Detects KaNN files and turns a matched context into a kann.Model.
 */
kann.ModelFactory = class {

    /**
     * Claims the context when its 'flatbuffers.binary' view carries the 'KaNN' identifier.
     * @param {object} context - Must provide `peek(kind)` and `set(type, value)`.
     * @returns {Promise<*>} Whatever `context.set('kann.flatbuffers', reader)` returns on a match, otherwise `null`.
     */
    async match(context) {
        const reader = await context.peek('flatbuffers.binary');
        if (reader && reader.identifier === 'KaNN') {
            return context.set('kann.flatbuffers', reader);
        }
        return null;
    }

    /**
     * Loads the schema, decodes the model held in `context.value` and wraps it in a kann.Model.
     * @param {object} context - Must provide `require`, `metadata`, `type`, `value` and `identifier`.
     * @returns {Promise<kann.Model>} The decoded model.
     * @throws {kann.Error} When `context.type` is not 'kann.flatbuffers' or decoding fails.
     */
    async open(context) {
        const schema = await context.require('./kann-schema');
        kann.schema = schema.kann;
        if (context.type !== 'kann.flatbuffers') {
            throw new kann.Error(`Unsupported KaNN format '${context.type}'.`);
        }
        let model = null;
        try {
            model = kann.schema.Model.create(context.value);
        } catch (error) {
            // String(...) keeps this path working when something other than an Error
            // (including null or undefined) was thrown; calling error.toString() would crash here.
            const message = String(error && error.message ? error.message : error);
            throw new kann.Error(`File format is not kann.Model (${message.replace(/\.$/, '')}).`);
        }
        const metadata = await context.metadata('kann-metadata.json');
        return new kann.Model(metadata, model, context.identifier);
    }
};
/**
 * Top-level model object: a format label, a display name and one kann.Graph per entry of `model.graph`.
 */
kann.Model = class {

    /**
     * @param {object} metadata - Passed through unchanged to every kann.Graph.
     * @param {object} model - Decoded model; its `graph` list is mapped to kann.Graph instances.
     * @param {string} identifier - Stored as the model name.
     */
    constructor(metadata, model, identifier) {
        const toGraph = (entry) => new kann.Graph(metadata, entry);
        this.format = 'KaNN';
        this.name = identifier;
        this.modules = model.graph.map(toGraph);
    }
};
/**
 * One graph of the model: its nodes plus the arguments describing graph inputs and outputs.
 */
kann.Graph = class {

    /**
     * @param {object} metadata - Passed through unchanged to every kann.Node.
     * @param {object} graph - Graph record; reads `arcs`, `nodes`, `inputs` and `outputs`.
     */
    constructor(metadata, graph) {
        // One kann.Value per arc, keyed by arc name, shared by the nodes and the graph arguments.
        const values = new Map(Array.from(graph.arcs, (arc) => [arc.name, new kann.Value(arc.name, arc.type, null)]));
        const toArgument = (name) => new kann.Argument(name, [values.get(name)]);
        this.nodes = graph.nodes.map((entry) => new kann.Node(metadata, entry, values));
        this.inputs = graph.inputs.map((name) => toArgument(name));
        this.outputs = graph.outputs.map((name) => toArgument(name));
    }
};
/**
 * One operator of a KaNN graph: resolves its type through the metadata and
 * collects its attributes, arc connections and parameter tensors.
 */
kann.Node = class {

    /**
     * @param {object} metadata - Must expose `type(name)`; its result is stored as `this.type`.
     * @param {object} node - Node record; reads `type`, `name`, `attributes`, `inputs`, `outputs`, `params` and `relu`.
     * @param {Map} arcs - Lookup from arc name to the value object shared across the graph.
     * @throws {kann.Error} On an unsupported data type or an attribute that carries no value.
     */
    constructor(metadata, node, arcs) {
        this.type = metadata.type(node.type);
        this.name = node.name;
        this.inputs = [];
        this.outputs = [];
        this.attributes = [];
        // Returns the payload field that matches the declared type of a typed value record.
        const extractData = (value) => {
            switch (value.type) {
                case 'int': case 'int8': case 'int16': case 'int32': case 'int64': return value.value_int;
                case 'uint': case 'uint8': case 'uint16': case 'uint32': case 'uint64': return value.value_uint;
                case 'float': case 'float16': case 'float32': case 'float64': return value.value_float;
                case 'string': return value.value_string;
                case 'int[]': case 'int8[]': case 'int16[]': case 'int32[]': case 'int64[]': return Array.from(value.list_int);
                case 'uint[]': case 'uint8[]': case 'uint16[]': case 'uint32[]': case 'uint64[]': return Array.from(value.list_uint);
                case 'float[]': case 'float16[]': case 'float32[]': case 'float64[]': return Array.from(value.list_float);
                case 'string[]': return Array.from(value.list_string);
                default: throw new kann.Error(`Unsupported data type '${value.type}'.`);
            }
        };
        // An 'attributes' entry is a group: it becomes a plain object keyed by child attribute name.
        const getAttributeValue = (attribute) => {
            if (attribute.type === 'attributes') {
                const obj = {};
                for (const attr of attribute.attributes) {
                    obj[attr.name] = getAttributeValue(attr);
                }
                return obj;
            }
            // Treat undefined like null so a missing value reports the error below
            // instead of failing on a property access inside extractData.
            if (attribute.value !== null && attribute.value !== undefined) {
                return extractData(attribute.value);
            }
            throw new kann.Error(`${attribute.name} doesn't have a value.`);
        };
        if (Array.isArray(node.attributes) && node.attributes.length > 0) {
            for (const attr of node.attributes) {
                let value = attr.type ? getAttributeValue(attr) : attr;
                value = Array.isArray(value) ? value : [value];
                // Decide on the declared attribute type: `value` is always an array at this
                // point, so testing `value.type` could never detect an attribute group.
                const type = attr.type === 'attributes' ? null : attr.type || null;
                const attribute = new kann.Argument(attr.name, value, type);
                this.attributes.push(attribute);
            }
        }
        if (Array.isArray(node.inputs) && node.inputs.length > 0) {
            const name = node.inputs.length > 1 ? 'inputs' : 'input';
            const argument = new kann.Argument(name, node.inputs.map((input) => arcs.get(input)));
            this.inputs.push(argument);
        }
        if (Array.isArray(node.outputs) && node.outputs.length > 0) {
            const name = node.outputs.length > 1 ? 'outputs' : 'output';
            const argument = new kann.Argument(name, node.outputs.map((output) => arcs.get(output)));
            this.outputs.push(argument);
        }
        if (Array.isArray(node.params) && node.params.length > 0) {
            // Each parameter is appended as an extra input whose value carries the tensor as initializer.
            for (const param of node.params) {
                const type = new kann.TensorType(param.type, param.shape);
                const data = param.value ? extractData(param.value) : null;
                const quantization = param.scale && param.zero_point ? {
                    type: 'linear',
                    scale: extractData(param.scale),
                    offset: extractData(param.zero_point)
                } : null;
                const tensor = new kann.Tensor(param.name, type, data, quantization);
                const value = new kann.Value('', type, tensor);
                const argument = new kann.Argument(param.name, [value]);
                this.inputs.push(argument);
            }
        }
        if (node.relu) {
            // A truthy `relu` flag is exposed as a chained ReLU node named after this node.
            const relu = { type: 'ReLU', name: `${node.name}/relu`, params: [] };
            this.chain = [new kann.Node(metadata, relu, arcs)];
        }
    }
};
/**
 * Named slot of a node or graph: a name, its value and an optional type.
 */
kann.Argument = class {

    /**
     * @param {string} name - Argument name.
     * @param {*} value - Stored as given.
     * @param {*} [type=null] - Optional type; defaults to `null`.
     */
    constructor(name, value, type = null) {
        Object.assign(this, { name, value, type });
    }
};
/**
 * A value flowing through the graph, optionally backed by an initializer.
 */
kann.Value = class {

    /**
     * @param {string} name - Value name.
     * @param {*} type - Type of the value, stored as given.
     * @param {?object} initializer - Backing object or `null`; its `quantization`, when truthy, is copied onto the value.
     */
    constructor(name, type, initializer) {
        const quantization = (initializer && initializer.quantization) || null;
        Object.assign(this, { name, type, initializer, quantization });
    }
};
/**
 * Constant tensor data together with its type and optional quantization.
 */
kann.Tensor = class {

    /**
     * @param {string} name - Tensor name.
     * @param {*} type - Tensor type, stored as given.
     * @param {*} values - Tensor data; an array selects the '|' encoding, anything else '<'.
     * @param {?object} quantization - Quantization description; any falsy value is stored as `null`.
     */
    constructor(name, type, values, quantization) {
        let encoding = '<';
        if (Array.isArray(values)) {
            encoding = '|';
        }
        this.name = name;
        this.type = type;
        this.encoding = encoding;
        this.values = values;
        this.quantization = quantization || null;
    }
};
/**
 * Element type plus shape of a tensor.
 */
kann.TensorType = class {

    /**
     * @param {string} dataType - Element type name; a falsy value is replaced by '?'.
     * @param {*} shape - Dimensions, handed to kann.TensorShape.
     */
    constructor(dataType, shape) {
        this.dataType = dataType ? dataType : '?';
        this.shape = new kann.TensorShape(shape);
    }

    /**
     * @returns {string} The data type immediately followed by the shape's string form.
     */
    toString() {
        const suffix = this.shape.toString();
        return this.dataType + suffix;
    }
};
/**
 * Dimension list of a tensor.
 */
kann.TensorShape = class {

    /**
     * @param {?Iterable|ArrayLike} dimensions - Dimension sizes, copied into a fresh array.
     *     `null` or `undefined` is stored as an empty list instead of throwing.
     */
    constructor(dimensions) {
        // Array.from(null) and Array.from(undefined) throw a TypeError, so map a missing shape to [].
        const source = dimensions === null || dimensions === undefined ? [] : dimensions;
        this.dimensions = Array.from(source);
    }

    /**
     * @returns {string} Comma-separated dimensions in square brackets, or '' when there are none.
     */
    toString() {
        if (Array.isArray(this.dimensions) && this.dimensions.length > 0) {
            return `[${this.dimensions.map((dimension) => dimension.toString()).join(',')}]`;
        }
        return '';
    }
};
/**
 * Error type raised by the KaNN loader; `name` carries a fixed loader-specific label.
 */
kann.Error = class extends Error {

    /**
     * @param {string} message - Description of the failure.
     */
    constructor(message) {
        super(message);
        const label = 'Error loading KaNN model.';
        this.name = label;
    }
};
// Module entry point: expose the KaNN model factory under its plain name.
const { ModelFactory } = kann;
export { ModelFactory };
|