dnn.js 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257
// Experimental
// Namespace object for the SnapML ("dnn") model reader. Every class in this
// module (ModelFactory, Model, Graph, Node, Tensor, ...) is attached to it as a
// property, and the decoded protobuf schema is stored on it as `dnn.proto`.
const dnn = {};
  3. dnn.ModelFactory = class {
  4. match(context) {
  5. const tags = context.tags('pb');
  6. if (tags.get(4) == 0 && tags.get(10) == 2) {
  7. context.type = 'dnn';
  8. }
  9. }
  10. async open(context) {
  11. dnn.proto = await context.require('./dnn-proto');
  12. dnn.proto = dnn.proto.dnn;
  13. let model = null;
  14. try {
  15. const reader = context.read('protobuf.binary');
  16. model = dnn.proto.Model.decode(reader);
  17. } catch (error) {
  18. const message = error && error.message ? error.message : error.toString();
  19. throw new dnn.Error(`File format is not dnn.Graph (${message.replace(/\.$/, '')}).`);
  20. }
  21. const metadata = await context.metadata('dnn-metadata.json');
  22. return new dnn.Model(metadata, model);
  23. }
  24. };
  25. dnn.Model = class {
  26. constructor(metadata, model) {
  27. this.name = model.name || '';
  28. this.format = `SnapML${model.version ? ` v${model.version}` : ''}`;
  29. this.graphs = [ new dnn.Graph(metadata, model) ];
  30. }
  31. };
  32. dnn.Graph = class {
  33. constructor(metadata, model) {
  34. this.inputs = [];
  35. this.outputs = [];
  36. this.nodes = [];
  37. const scope = {};
  38. let index = 0;
  39. for (const node of model.node) {
  40. node.input = node.input.map((input) => scope[input] ? scope[input] : input);
  41. node.output = node.output.map((output) => {
  42. scope[output] = scope[output] ? `${output}\n${index}` : output; // custom argument id
  43. return scope[output];
  44. });
  45. index++;
  46. }
  47. const values = new Map();
  48. values.map = (name, type) => {
  49. if (!values.has(name)) {
  50. values.set(name, new dnn.Value(name, type));
  51. }
  52. return values.get(name);
  53. };
  54. for (const input of model.input) {
  55. const shape = input.shape;
  56. const type = new dnn.TensorType('float32', new dnn.TensorShape([ shape.dim0, shape.dim1, shape.dim2, shape.dim3 ]));
  57. const argument = new dnn.Argument(input.name, [ values.map(input.name, type) ]);
  58. this.inputs.push(argument);
  59. }
  60. for (const output of model.output) {
  61. const shape = output.shape;
  62. const type = new dnn.TensorType('float32', new dnn.TensorShape([ shape.dim0, shape.dim1, shape.dim2, shape.dim3 ]));
  63. const argument = new dnn.Argument(output.name, [ values.map(output.name, type) ]);
  64. this.outputs.push(argument);
  65. }
  66. if (this.inputs.length === 0 && model.input_name && model.input_shape && model.input_shape.length === model.input_name.length * 4) {
  67. for (let i = 0; i < model.input_name.length; i++) {
  68. const name = model.input_name[i];
  69. const shape = model.input_shape.slice(i * 4, (i * 4 + 4));
  70. const type = new dnn.TensorType('float32', new dnn.TensorShape([ shape[1], shape[3], shape[2], shape[0] ]));
  71. const argument = new dnn.Argument(name, [ values.map(name, type) ]);
  72. this.inputs.push(argument);
  73. }
  74. }
  75. if (this.inputs.length === 0 && model.input_shape && model.input_shape.length === 4 && model.node.length > 0 && model.node[0].input.length > 0) {
  76. /* eslint-disable prefer-destructuring */
  77. const name = model.node[0].input[0];
  78. /* eslint-enable prefer-destructuring */
  79. const shape = model.input_shape;
  80. const type = new dnn.TensorType('float32', new dnn.TensorShape([ shape[1], shape[3], shape[2], shape[0] ]));
  81. const argument = new dnn.Argument(name, [ values.map(name, type) ]);
  82. this.inputs.push(argument);
  83. }
  84. for (const node of model.node) {
  85. this.nodes.push(new dnn.Node(metadata, node, values));
  86. }
  87. }
  88. };
  89. dnn.Argument = class {
  90. constructor(name, value) {
  91. this.name = name;
  92. this.value = value;
  93. }
  94. };
  95. dnn.Value = class {
  96. constructor(name, type, initializer, quantization) {
  97. if (typeof name !== 'string') {
  98. throw new dnn.Error(`Invalid value identifier '${JSON.stringify(name)}'.`);
  99. }
  100. this.name = name;
  101. this.type = type || null;
  102. this.initializer = initializer || null;
  103. if (quantization) {
  104. this.quantization = {
  105. type: 'lookup',
  106. value: new Map(quantization.map((value, index) => [ index, value ]))
  107. };
  108. }
  109. }
  110. };
  111. dnn.Node = class {
  112. constructor(metadata, node, values) {
  113. const layer = node.layer;
  114. this.name = layer.name;
  115. const type = layer.type;
  116. this.type = metadata.type(type) || { name: type };
  117. this.attributes = [];
  118. this.inputs = [];
  119. this.outputs = [];
  120. const inputs = node.input.map((input) => values.map(input));
  121. for (const weight of layer.weight) {
  122. let quantization = null;
  123. if (layer.is_quantized && weight === layer.weight[0] && layer.quantization && layer.quantization.data) {
  124. const data = layer.quantization.data;
  125. quantization = new Array(data.length >> 2);
  126. const view = new DataView(data.buffer, data.byteOffset, data.byteLength);
  127. for (let i = 0; i < quantization.length; i++) {
  128. quantization[i] = view.getFloat32(i << 2, true);
  129. }
  130. }
  131. const initializer = new dnn.Tensor(weight, quantization);
  132. inputs.push(new dnn.Value('', initializer.type, initializer, quantization));
  133. }
  134. const outputs = node.output.map((output) => values.map(output));
  135. if (inputs && inputs.length > 0) {
  136. let inputIndex = 0;
  137. if (this.type && this.type.inputs) {
  138. for (const inputSchema of this.type.inputs) {
  139. if (inputIndex < inputs.length || inputSchema.option != 'optional') {
  140. const inputCount = (inputSchema.option == 'variadic') ? (node.input.length - inputIndex) : 1;
  141. const inputArguments = inputs.slice(inputIndex, inputIndex + inputCount);
  142. this.inputs.push(new dnn.Argument(inputSchema.name, inputArguments));
  143. inputIndex += inputCount;
  144. }
  145. }
  146. }
  147. this.inputs.push(...inputs.slice(inputIndex).map((input, index) => {
  148. const inputName = ((inputIndex + index) == 0) ? 'input' : (inputIndex + index).toString();
  149. return new dnn.Argument(inputName, [ input ]);
  150. }));
  151. }
  152. if (outputs.length > 0) {
  153. this.outputs = outputs.map((output, index) => {
  154. const inputName = (index == 0) ? 'output' : index.toString();
  155. return new dnn.Argument(inputName, [ output ]);
  156. });
  157. }
  158. for (const key of Object.keys(layer)) {
  159. switch (key) {
  160. case 'name':
  161. case 'type':
  162. case 'weight':
  163. case 'is_quantized':
  164. case 'quantization':
  165. break;
  166. default: {
  167. const attribute = new dnn.Attribute(metadata.attribute(type, key), key, layer[key]);
  168. this.attributes.push(attribute);
  169. break;
  170. }
  171. }
  172. }
  173. }
  174. };
  175. dnn.Attribute = class {
  176. constructor(metadata, name, value) {
  177. this.name = name;
  178. this.value = value;
  179. }
  180. };
  181. dnn.Tensor = class {
  182. constructor(weight, quantization) {
  183. const shape = new dnn.TensorShape([ weight.dim0, weight.dim1, weight.dim2, weight.dim3 ]);
  184. this.values = quantization ? weight.quantized_data : weight.data;
  185. const size = shape.dimensions.reduce((a, b) => a * b, 1);
  186. const itemSize = Math.floor(this.values.length / size);
  187. const remainder = this.values.length - (itemSize * size);
  188. if (remainder < 0 || remainder > itemSize) {
  189. throw new dnn.Error('Invalid tensor data size.');
  190. }
  191. let dataType = '?';
  192. switch (itemSize) {
  193. case 1: dataType = 'int8'; break;
  194. case 2: dataType = 'float16'; break;
  195. case 4: dataType = 'float32'; break;
  196. default: dataType = '?'; break;
  197. }
  198. this.type = new dnn.TensorType(dataType, shape);
  199. }
  200. };
  201. dnn.TensorType = class {
  202. constructor(dataType, shape) {
  203. this.dataType = dataType;
  204. this.shape = shape;
  205. }
  206. toString() {
  207. return this.dataType + this.shape.toString();
  208. }
  209. };
  210. dnn.TensorShape = class {
  211. constructor(shape) {
  212. this.dimensions = shape;
  213. }
  214. toString() {
  215. if (!this.dimensions || this.dimensions.length == 0) {
  216. return '';
  217. }
  218. return `[${this.dimensions.join(',')}]`;
  219. }
  220. };
// Error type thrown by this module for unreadable or malformed SnapML files.
dnn.Error = class extends Error {

    constructor(message) {
        super(message);
        // `name` holds a descriptive sentence rather than a class name.
        // NOTE(review): presumably the host displays it as the error title — confirm.
        this.name = 'Error loading SnapML model.';
    }
};
// Public entry point of this module: the factory that matches and opens SnapML models.
export const ModelFactory = dnn.ModelFactory;