dnn.js 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247
  // Experimental reader for SnapML DNN models stored as binary protobuf.
  // All classes of this module are attached to the `dnn` namespace object below.
  const dnn = {};
  dnn.ModelFactory = class {
    // Claims the file when its protobuf tag signature matches: field 4 -> 0 and field 10 -> 2.
    // NOTE(review): the map values are presumably protobuf wire types (0 = varint,
    // 2 = length-delimited) as reported by context.tags('pb') — confirm against the context API.
    async match(context) {
      const tags = await context.tags('pb');
      if (tags.get(4) === 0 && tags.get(10) === 2) {
        return context.set('dnn');
      }
      return null;
    }

    // Loads the generated protobuf bindings, decodes the file as dnn.proto.Model and
    // wraps it in a dnn.Model together with 'dnn-metadata.json'.
    // Throws dnn.Error (with the original failure as `cause`) when decoding fails.
    async open(context) {
      const bindings = await context.require('./dnn-proto');
      dnn.proto = bindings.dnn;
      let model = null;
      try {
        const reader = await context.read('protobuf.binary');
        model = dnn.proto.Model.decode(reader);
      } catch (error) {
        // Any value can be thrown (including null/undefined), so never call methods on it directly.
        const message = error && error.message ? error.message : String(error);
        throw new dnn.Error(`File format is not dnn.Graph (${message.replace(/\.$/, '')}).`, { cause: error });
      }
      const metadata = await context.metadata('dnn-metadata.json');
      return new dnn.Model(metadata, model);
    }
  };
  dnn.Model = class {
    // Top-level model: exposes the model name, a format label ("SnapML", plus
    // " v<version>" when a version is present) and a single graph module.
    constructor(metadata, model) {
      const { name, version } = model;
      this.name = name || '';
      const versionSuffix = version ? ` v${version}` : '';
      this.format = `SnapML${versionSuffix}`;
      const graph = new dnn.Graph(metadata, model);
      this.modules = [graph];
    }
  };
  dnn.Graph = class {
    // Builds the graph view of a decoded model: graph inputs/outputs (all typed as
    // float32 with four dimensions) and one dnn.Node per entry of model.node.
    constructor(metadata, model) {
      this.inputs = [];
      this.outputs = [];
      this.nodes = [];

      // Output names may be defined by more than one node. Each re-definition gets a
      // unique id ("<name>\n<node index>"), and later inputs are rewritten to refer to
      // the most recent definition. This rewrites model.node[*].input/output in place.
      // A Map (not a plain object) keeps names such as 'constructor', 'toString' or
      // '__proto__' from being resolved through Object.prototype.
      const scope = new Map();
      for (let i = 0; i < model.node.length; i++) {
        const node = model.node[i];
        node.input = node.input.map((input) => scope.get(input) || input);
        node.output = node.output.map((output) => {
          const id = scope.get(output) ? `${output}\n${i}` : output; // custom argument id
          scope.set(output, id);
          return id;
        });
      }

      // Shared dnn.Value registry keyed by name; dnn.Node receives this map and calls values.map().
      const values = new Map();
      values.map = (name, type) => {
        if (!values.has(name)) {
          values.set(name, new dnn.Value(name, type));
        }
        return values.get(name);
      };

      // Wraps a graph-level name in an Argument holding a single float32 value.
      const argument = (name, dimensions) => {
        const type = new dnn.TensorType('float32', new dnn.TensorShape(dimensions));
        return new dnn.Argument(name, [values.map(name, type)]);
      };
      // The flat four-value shapes below are reordered as [1, 3, 2, 0].
      // NOTE(review): the stored dimension order of input_shape is not visible here — confirm.
      const permute = (shape) => [shape[1], shape[3], shape[2], shape[0]];

      for (const input of model.input) {
        const { shape } = input;
        this.inputs.push(argument(input.name, [shape.dim0, shape.dim1, shape.dim2, shape.dim3]));
      }
      for (const output of model.output) {
        const { shape } = output;
        this.outputs.push(argument(output.name, [shape.dim0, shape.dim1, shape.dim2, shape.dim3]));
      }
      // Fallback 1: no model.input entries, but parallel input_name / input_shape
      // fields with four shape values per name.
      if (this.inputs.length === 0 && model.input_name && model.input_shape && model.input_shape.length === model.input_name.length * 4) {
        for (let i = 0; i < model.input_name.length; i++) {
          const shape = model.input_shape.slice(i * 4, i * 4 + 4);
          this.inputs.push(argument(model.input_name[i], permute(shape)));
        }
      }
      // Fallback 2: a single four-value input_shape, bound to the first input of the first node.
      if (this.inputs.length === 0 && model.input_shape && model.input_shape.length === 4 && model.node.length > 0 && model.node[0].input.length > 0) {
        const [name] = model.node[0].input;
        this.inputs.push(argument(name, permute(model.input_shape)));
      }

      for (const node of model.node) {
        this.nodes.push(new dnn.Node(metadata, node, values));
      }
    }
  };
  dnn.Argument = class {
    // A named slot and what is bound to it: an array of dnn.Value for node and graph
    // inputs/outputs, or the raw layer field value for node attributes.
    constructor(name, value) {
      Object.assign(this, { name, value });
    }
  };
  dnn.Value = class {
    // A named value flowing through the graph; `name` must be a string (weight
    // initializers use ''). When a quantization table is supplied, it is exposed as a
    // 'lookup' Map from table index to table entry.
    constructor(name, type = null, initializer = null, quantization = null) {
      const isValidName = typeof name === 'string';
      if (!isValidName) {
        throw new dnn.Error(`Invalid value identifier '${JSON.stringify(name)}'.`);
      }
      this.name = name;
      this.type = type;
      this.initializer = initializer;
      if (quantization) {
        const entries = quantization.map((entry, position) => [position, entry]);
        this.quantization = { type: 'lookup', value: new Map(entries) };
      }
    }
  };
  dnn.Node = class {
    // Builds a node from a decoded `node` record:
    //  - graph inputs/outputs are resolved through values.map();
    //  - each layer weight becomes an anonymous initializer value appended after the inputs;
    //  - inputs are grouped by the metadata schema when one is available;
    //  - every remaining layer field is exposed as an attribute.
    constructor(metadata, node, values) {
      const { layer } = node;
      this.name = layer.name;
      this.type = metadata.type(layer.type) || { name: layer.type };
      this.attributes = [];
      this.inputs = [];
      this.outputs = [];

      const inputs = node.input.map((input) => values.map(input));
      for (const weight of layer.weight) {
        let quantization = null;
        // The quantization lookup table is attached to the first weight only.
        if (layer.is_quantized && weight === layer.weight[0] && layer.quantization && layer.quantization.data) {
          const { data } = layer.quantization;
          // Table entries are decoded as little-endian float32 values.
          const view = new DataView(data.buffer, data.byteOffset, data.byteLength);
          quantization = Array.from({ length: data.length >> 2 }, (_, i) => view.getFloat32(i << 2, true));
        }
        const initializer = new dnn.Tensor(weight, quantization);
        inputs.push(new dnn.Value('', initializer.type, initializer, quantization));
      }
      const outputs = node.output.map((output) => values.map(output));

      if (inputs.length > 0) {
        let inputIndex = 0;
        if (this.type && this.type.inputs) {
          for (const inputSchema of this.type.inputs) {
            if (inputIndex < inputs.length || inputSchema.option !== 'optional') {
              // A variadic slot takes the remaining graph inputs (not the appended weights).
              // Clamp at zero: once weights have filled earlier slots, inputIndex can exceed
              // node.input.length, and a negative count would rewind inputIndex and
              // duplicate the trailing inputs below.
              const inputCount = inputSchema.option === 'variadic' ? Math.max(0, node.input.length - inputIndex) : 1;
              this.inputs.push(new dnn.Argument(inputSchema.name, inputs.slice(inputIndex, inputIndex + inputCount)));
              inputIndex += inputCount;
            }
          }
        }
        // Inputs not covered by the schema get positional names ('input', '1', '2', ...).
        this.inputs.push(...inputs.slice(inputIndex).map((input, index) => {
          const position = inputIndex + index;
          const inputName = position === 0 ? 'input' : position.toString();
          return new dnn.Argument(inputName, [input]);
        }));
      }

      // Outputs get positional names ('output', '1', '2', ...).
      this.outputs = outputs.map((output, index) => {
        const outputName = index === 0 ? 'output' : index.toString();
        return new dnn.Argument(outputName, [output]);
      });

      // Fields already represented by name/type/weights/quantization are not repeated as attributes.
      const skipped = new Set(['name', 'type', 'weight', 'is_quantized', 'quantization']);
      for (const [key, value] of Object.entries(layer)) {
        if (!skipped.has(key)) {
          this.attributes.push(new dnn.Argument(key, value));
        }
      }
    }
  };
  dnn.Tensor = class {
    // Describes one layer weight: `values` references the selected payload and `type`
    // combines an inferred data type with the 4-D shape (dim0..dim3).
    // Throws dnn.Error when the payload length does not fit the shape.
    constructor(weight, quantization) {
      const shape = new dnn.TensorShape([weight.dim0, weight.dim1, weight.dim2, weight.dim3]);
      // Quantized weights are read from `quantized_data`, all others from `data`.
      this.values = quantization ? weight.quantized_data : weight.data;
      const { length } = this.values;
      const size = shape.dimensions.reduce((a, b) => a * b, 1);
      // The element width is not stored, so it is inferred as payload length / element count.
      const itemsize = Math.floor(length / size);
      const remainder = length - (itemsize * size);
      if (remainder < 0 || remainder > itemsize) {
        throw new dnn.Error(`Invalid tensor data size '${length}' for tensor shape '[${shape.dimensions}]'.`);
      }
      const dataTypes = { 1: 'int8', 2: 'float16', 4: 'float32' };
      this.type = new dnn.TensorType(dataTypes[itemsize] || '?', shape);
    }
  };
  dnn.TensorType = class {
    // Pairs a data type name with a shape object.
    constructor(dataType, shape) {
      Object.assign(this, { dataType, shape });
    }

    // The data type name immediately followed by the shape's string form.
    toString() {
      const shapeText = this.shape.toString();
      return `${this.dataType}${shapeText}`;
    }
  };
  dnn.TensorShape = class {
    // Wraps a list of dimensions.
    constructor(shape) {
      this.dimensions = shape;
    }

    // Renders "[d0,d1,...]", or '' when there are no dimensions.
    toString() {
      const dims = this.dimensions;
      return dims && dims.length !== 0 ? `[${dims.join(',')}]` : '';
    }
  };
  dnn.Error = class extends Error {
    // Error raised while reading a SnapML model. `options` is optional and is forwarded
    // to the built-in Error constructor (e.g. `{ cause }`), so wrapped failures can keep
    // the original error; callers passing only `message` behave as before.
    constructor(message, options) {
      super(message, options);
      this.name = 'Error loading SnapML model.';
    }
  };
  // Public entry point of this module.
  export const { ModelFactory } = dnn;