dnn.js 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508
  /* jshint esversion: 6 */
  // Experimental SnapML (dnn) model support.
  // Reuse an existing `dnn` namespace object if one is already defined (any
  // truthy value); otherwise start from an empty object.
  var dnn = dnn ? dnn : {};
  dnn.ModelFactory = class {

      /**
       * Cheap signature check on the protobuf tag table of the input.
       * @param {object} context - host context exposing `tags('pb')`.
       * @returns {boolean} true when `tags.get(4) == 0` and `tags.get(10) == 2`
       *     (values presumably protobuf wire types — confirm against context.tags).
       */
      match(context) {
          const tags = context.tags('pb');
          return tags.get(4) == 0 && tags.get(10) == 2;
      }

      /**
       * Loads the generated `dnn` protobuf schema, decodes the stream and wraps
       * the result in a dnn.Model together with the shared metadata.
       * @param {object} context - host context (`require`, `stream`, `request`).
       * @returns {Promise<dnn.Model>} rejects with dnn.Error when decoding fails.
       */
      open(context) {
          const decodeModel = () => {
              try {
                  dnn.proto = protobuf.get('dnn').dnn;
                  const stream = context.stream;
                  const reader = protobuf.BinaryReader.open(stream);
                  return dnn.proto.Model.decode(reader);
              }
              catch (error) {
                  // Wrap any decoding failure, dropping a trailing period from the cause.
                  const message = error && error.message ? error.message : error.toString();
                  throw new dnn.Error('File format is not dnn.Graph (' + message.replace(/\.$/, '') + ').');
              }
          };
          return context.require('./dnn-proto').then(() => {
              const decoded = decodeModel();
              return dnn.Metadata.open(context).then((metadata) => new dnn.Model(metadata, decoded));
          });
      }
  };
  dnn.Model = class {

      /**
       * Top-level view of a decoded SnapML model; it always holds exactly one graph.
       * @param {dnn.Metadata} metadata - operator schemas forwarded to the graph.
       * @param {object} model - decoded dnn.proto.Model.
       */
      constructor(metadata, model) {
          this._name = model.name || '';
          const versionSuffix = model.version ? ' v' + model.version.toString() : '';
          this._format = 'SnapML' + versionSuffix;
          this._graphs = [ new dnn.Graph(metadata, model) ];
      }

      /** Format label, e.g. `SnapML` optionally followed by ` v<version>`. */
      get format() { return this._format; }

      /** Model name, or '' when the file has none. */
      get name() { return this._name; }

      /** Single-element array containing the model graph. */
      get graphs() { return this._graphs; }
  };
  dnn.Graph = class {

      /**
       * Builds the single graph of a decoded dnn model.
       *
       * Graph inputs/outputs come from `model.input`/`model.output`. When no
       * inputs are declared there, they are recovered from the fallback
       * `input_name`/`input_shape` fields. All declared shapes are reported as
       * float32 tensors.
       *
       * @param {dnn.Metadata} metadata - operator schemas passed to each node.
       * @param {object} model - decoded dnn.proto.Model. NOTE: the `input` and
       *     `output` arrays of its `node` entries are rewritten in place.
       */
      constructor(metadata, model) {
          this._inputs = [];
          this._outputs = [];
          this._nodes = [];
          // Give every produced value a unique id. When a name is produced a
          // second time, that producer (and every later consumer) uses
          // `name + '\n' + nodeIndex` instead.
          // A Map is used rather than a plain object so that value names such
          // as 'toString', 'constructor' or '__proto__' are not mistaken for
          // inherited Object.prototype members. Those resolved to functions or
          // objects, and dnn.Argument rejected them.
          const scope = new Map();
          let index = 0;
          for (const node of model.node) {
              node.input = node.input.map((input) => scope.get(input) || input);
              node.output = node.output.map((output) => {
                  const id = scope.get(output) ? output + '\n' + index.toString() : output; // custom argument id
                  scope.set(output, id);
                  return id;
              });
              index++;
          }
          // Shared argument per value name; the type is fixed by the first caller.
          const args = new Map();
          const arg = (name, type) => {
              if (!args.has(name)) {
                  args.set(name, new dnn.Argument(name, type));
              }
              return args.get(name);
          };
          const float32Type = (dimensions) => new dnn.TensorType('float32', new dnn.TensorShape(dimensions));
          for (const input of model.input) {
              const shape = input.shape;
              const type = float32Type([ shape.dim0, shape.dim1, shape.dim2, shape.dim3 ]);
              this._inputs.push(new dnn.Parameter(input.name, [ arg(input.name, type) ]));
          }
          for (const output of model.output) {
              const shape = output.shape;
              const type = float32Type([ shape.dim0, shape.dim1, shape.dim2, shape.dim3 ]);
              this._outputs.push(new dnn.Parameter(output.name, [ arg(output.name, type) ]));
          }
          // Fallback 1: parallel `input_name` list with a flat `input_shape`
          // holding 4 values per input. The 4 values are displayed in the order
          // [1, 3, 2, 0]. NOTE(review): the stored axis layout is not visible
          // here — confirm.
          if (this._inputs.length === 0 && model.input_name && model.input_shape && model.input_shape.length === model.input_name.length * 4) {
              for (let i = 0; i < model.input_name.length; i++) {
                  const name = model.input_name[i];
                  const shape = model.input_shape.slice(i * 4, (i * 4 + 4));
                  const type = float32Type([ shape[1], shape[3], shape[2], shape[0] ]);
                  this._inputs.push(new dnn.Parameter(name, [ arg(name, type) ]));
              }
          }
          // Fallback 2: a single 4-value `input_shape` attached to the first
          // input of the first node.
          if (this._inputs.length === 0 && model.input_shape && model.input_shape.length === 4 &&
              model.node.length > 0 && model.node[0].input.length > 0) {
              const name = model.node[0].input[0];
              const shape = model.input_shape;
              const type = float32Type([ shape[1], shape[3], shape[2], shape[0] ]);
              this._inputs.push(new dnn.Parameter(name, [ arg(name, type) ]));
          }
          for (const node of model.node) {
              this._nodes.push(new dnn.Node(metadata, node, arg));
          }
      }

      get inputs() {
          return this._inputs;
      }

      get outputs() {
          return this._outputs;
      }

      get nodes() {
          return this._nodes;
      }
  };
  dnn.Parameter = class {

      /**
       * A named input/output slot that holds one or more arguments.
       * @param {string} name - slot name shown to the user.
       * @param {dnn.Argument[]} args - arguments bound to the slot.
       */
      constructor(name, args) {
          Object.assign(this, { _name: name, _arguments: args });
      }

      /** Slot name. */
      get name() { return this._name; }

      /** Parameters of this format are never hidden. */
      get visible() { return true; }

      /** Arguments bound to the slot. */
      get arguments() { return this._arguments; }
  };
  dnn.Argument = class {

      /**
       * A value flowing between nodes, or a constant initializer.
       * @param {string} name - value id; must be a string ('' for initializers).
       * @param {dnn.TensorType} [type] - stored as null when falsy.
       * @param {dnn.Tensor} [initializer] - stored as null when falsy.
       * @param {number[]} [quantization] - quantization table, stored as null when falsy.
       * @throws {dnn.Error} when `name` is not a string.
       */
      constructor(name, type, initializer, quantization) {
          if (typeof name === 'string') {
              this._name = name;
              this._type = type || null;
              this._initializer = initializer || null;
              this._quantization = quantization || null;
          }
          else {
              throw new dnn.Error("Invalid argument identifier '" + JSON.stringify(name) + "'.");
          }
      }

      get name() {
          return this._name;
      }

      get type() {
          return this._type;
      }

      /**
       * Quantization table rendered as `0 = v0; 1 = v1; ...`, or null when absent.
       */
      get quantization() {
          const table = this._quantization;
          if (!table) {
              return null;
          }
          const describe = (value, index) => index.toString() + ' = ' + value.toString();
          return table.map(describe).join('; ');
      }

      get initializer() {
          return this._initializer;
      }
  };
  dnn.Node = class {

      /**
       * Builds a node view from one decoded dnn graph node.
       *
       * Inputs are the node's value inputs followed by one initializer argument
       * per layer weight. They are grouped by the operator schema's `inputs`
       * when metadata is available; anything left over is listed positionally
       * ('input', '1', '2', ...).
       *
       * @param {dnn.Metadata} metadata - operator schemas, looked up by `layer.type`.
       * @param {object} node - decoded node with `input`/`output` value names and a `layer`.
       * @param {function(string): dnn.Argument} arg - resolves a value name to its argument.
       */
      constructor(metadata, node, arg) {
          const layer = node.layer;
          this._name = layer.name;
          const type = layer.type;
          this._type = metadata.type(type) || { name: type };
          this._attributes = [];
          this._inputs = [];
          this._outputs = [];
          const inputs = node.input.map((input) => arg(input));
          for (const weight of layer.weight) {
              let quantization = null;
              // Only the first weight of a quantized layer receives the table,
              // which is decoded as consecutive little-endian float32 values.
              if (layer.is_quantized && weight === layer.weight[0] && layer.quantization && layer.quantization.data) {
                  const data = layer.quantization.data;
                  quantization = new Array(data.length >> 2);
                  const view = new DataView(data.buffer, data.byteOffset, data.byteLength);
                  for (let i = 0; i < quantization.length; i++) {
                      quantization[i] = view.getFloat32(i << 2, true);
                  }
              }
              const initializer = new dnn.Tensor(weight, quantization);
              inputs.push(new dnn.Argument('', initializer.type, initializer, quantization));
          }
          const outputs = node.output.map((output) => arg(output));
          if (inputs.length > 0) {
              let inputIndex = 0;
              if (this._type && this._type.inputs) {
                  for (const inputSchema of this._type.inputs) {
                      // Optional schema entries are skipped once all inputs are consumed.
                      if (inputIndex < inputs.length || inputSchema.option !== 'optional') {
                          // A variadic entry takes the remaining value inputs
                          // (weights excluded). Clamp at zero: after weights have
                          // been consumed, inputIndex can exceed node.input.length.
                          // A negative count would move inputIndex backwards and
                          // list the same arguments twice.
                          const inputCount = (inputSchema.option === 'variadic') ? Math.max(0, node.input.length - inputIndex) : 1;
                          const inputArguments = inputs.slice(inputIndex, inputIndex + inputCount);
                          this._inputs.push(new dnn.Parameter(inputSchema.name, inputArguments));
                          inputIndex += inputCount;
                      }
                  }
              }
              // Inputs not claimed by the schema are listed positionally.
              this._inputs.push(...inputs.slice(inputIndex).map((input, index) => {
                  const inputName = ((inputIndex + index) === 0) ? 'input' : (inputIndex + index).toString();
                  return new dnn.Parameter(inputName, [ input ]);
              }));
          }
          if (outputs.length > 0) {
              this._outputs = outputs.map((output, index) => {
                  const outputName = (index === 0) ? 'output' : index.toString();
                  return new dnn.Parameter(outputName, [ output ]);
              });
          }
          // Every other own field of the layer is surfaced as an attribute.
          for (const key of Object.keys(layer)) {
              switch (key) {
                  case 'name':
                  case 'type':
                  case 'weight':
                  case 'is_quantized':
                  case 'quantization':
                      break;
                  default:
                      this._attributes.push(new dnn.Attribute(metadata.attribute(type, key), key, layer[key]));
                      break;
              }
          }
      }

      get name() {
          return this._name;
      }

      get type() {
          return this._type;
      }

      get inputs() {
          return this._inputs;
      }

      get outputs() {
          return this._outputs;
      }

      get attributes() {
          return this._attributes;
      }
  };
  dnn.Attribute = class {

      /**
       * A name/value pair copied verbatim from a layer field.
       * @param {object|null} schema - attribute schema from metadata (not used by this class).
       * @param {string} name - layer field name.
       * @param {*} value - raw field value.
       */
      constructor(schema, name, value) {
          this._name = name;
          this._value = value;
      }

      get name() { return this._name; }

      get value() { return this._value; }
  };
  dnn.Tensor = class {

      /**
       * Wraps one layer weight as an inspectable tensor.
       *
       * The element type is not read from the weight. It is inferred from the
       * number of bytes available per element of the 4-D shape (dim0..dim3):
       * 1 byte -> int8, 2 bytes -> float16, 4 bytes -> float32, anything else -> '?'.
       *
       * @param {object} weight - decoded weight with `dim0`..`dim3` and byte
       *     buffers `data` / `quantized_data` (read via buffer/byteOffset/byteLength).
       * @param {number[]|null} quantization - when truthy, the payload is taken
       *     from `quantized_data` instead of `data`.
       * @throws {dnn.Error} when the payload length does not fit the shape.
       */
      constructor(weight, quantization) {
          const shape = new dnn.TensorShape([ weight.dim0, weight.dim1, weight.dim2, weight.dim3 ]);
          this._data = quantization ? weight.quantized_data : weight.data;
          if (this._data == null) {
              // Nothing to size against; _context() reports the tensor as empty
              // instead of the constructor crashing on `.length`.
              this._type = new dnn.TensorType('?', shape);
              return;
          }
          const size = shape.dimensions.reduce((a, b) => a * b, 1);
          const itemSize = Math.floor(this._data.length / size);
          // Accept a few trailing bytes, but reject payloads whose leftover
          // bytes exceed one element.
          const remainder = this._data.length - (itemSize * size);
          if (remainder < 0 || remainder > itemSize) {
              throw new dnn.Error('Invalid tensor data size.');
          }
          switch (itemSize) {
              case 1:
                  this._type = new dnn.TensorType('int8', shape);
                  break;
              case 2:
                  this._type = new dnn.TensorType('float16', shape);
                  break;
              case 4:
                  // 4 bytes per element is float32. Labelling it float16 would
                  // decode each value from the wrong bytes.
                  this._type = new dnn.TensorType('float32', shape);
                  break;
              default:
                  this._type = new dnn.TensorType('?', shape);
                  break;
          }
      }

      get kind() {
          return 'Weight';
      }

      get type() {
          return this._type;
      }

      /** Reason the payload cannot be shown, or null when it can be decoded. */
      get state() {
          return this._context().state;
      }

      /** Fully decoded nested array, or null when the payload cannot be decoded. */
      get value() {
          const context = this._context();
          if (context.state) {
              return null;
          }
          context.limit = Number.MAX_SAFE_INTEGER;
          return this._decode(context, 0);
      }

      /** Pretty-printed JSON of the payload, truncated after roughly 10000 elements. */
      toString() {
          const context = this._context();
          if (context.state) {
              return '';
          }
          context.limit = 10000;
          const value = this._decode(context, 0);
          return JSON.stringify(value, null, 4);
      }

      // Builds the decoding cursor, or sets `state` when decoding is impossible.
      _context() {
          const context = {};
          context.state = null;
          context.index = 0;
          context.count = 0;
          if (this._data == null) {
              context.state = 'Tensor data is empty.';
              return context;
          }
          switch (this._type.dataType) {
              case 'int8':
              case 'float16':
              case 'float32':
                  break;
              default:
                  context.state = "Tensor data type '" + this._type.dataType + "' is not supported.";
                  return context;
          }
          context.dataType = this._type.dataType;
          context.shape = this._type.shape.dimensions;
          context.data = new DataView(this._data.buffer, this._data.byteOffset, this._data.byteLength);
          return context;
      }

      // Recursively reads one dimension as a nested array, stopping with '...'
      // once `context.limit` elements have been emitted.
      _decode(context, dimension) {
          const shape = (context.shape.length == 0) ? [ 1 ] : context.shape;
          const size = shape[dimension];
          const results = [];
          if (dimension == shape.length - 1) {
              for (let i = 0; i < size; i++) {
                  if (context.count > context.limit) {
                      results.push('...');
                      return results;
                  }
                  switch (context.dataType) {
                      case 'int8':
                          results.push(context.data.getInt8(context.index));
                          context.index++;
                          context.count++;
                          break;
                      case 'float16':
                          // NOTE(review): DataView.prototype.getFloat16 only exists in
                          // recent runtimes; assumes the host provides it or a polyfill.
                          results.push(context.data.getFloat16(context.index, true));
                          context.index += 2;
                          context.count++;
                          break;
                      case 'float32':
                          results.push(context.data.getFloat32(context.index, true));
                          context.index += 4;
                          context.count++;
                          break;
                      default:
                          break;
                  }
              }
          }
          else {
              for (let j = 0; j < size; j++) {
                  if (context.count > context.limit) {
                      results.push('...');
                      return results;
                  }
                  results.push(this._decode(context, dimension + 1));
              }
          }
          if (context.shape.length == 0) {
              return results[0];
          }
          return results;
      }
  };
  dnn.TensorType = class {

      /**
       * Element type plus shape of a tensor.
       * @param {string} dataType - e.g. 'int8', 'float16', 'float32' or '?'.
       * @param {dnn.TensorShape} shape
       */
      constructor(dataType, shape) {
          this._dataType = dataType;
          this._shape = shape;
      }

      get dataType() { return this._dataType; }

      get shape() { return this._shape; }

      /** Data type followed by the shape text, e.g. `float32[1,3,2,2]`. */
      toString() {
          return `${this.dataType}${this._shape.toString()}`;
      }
  };
  dnn.TensorShape = class {

      /**
       * @param {number[]} shape - dimension sizes.
       */
      constructor(shape) {
          this._dimensions = shape;
      }

      get dimensions() { return this._dimensions; }

      /** `[d0,d1,...]`, or '' when there are no dimensions. */
      toString() {
          const dimensions = this._dimensions;
          if (!dimensions || dimensions.length == 0) {
              return '';
          }
          return `[${dimensions.join(',')}]`;
      }
  };
  dnn.Metadata = class {

      /**
       * Loads `dnn-metadata.json` once and caches it on the class. Any failure
       * (request or JSON parse) falls back to an empty metadata instance.
       * @param {object} context - host context exposing `request`.
       * @returns {Promise<dnn.Metadata>}
       */
      static open(context) {
          const remember = (instance) => {
              dnn.Metadata._metadata = instance;
              return instance;
          };
          if (dnn.Metadata._metadata) {
              return Promise.resolve(dnn.Metadata._metadata);
          }
          return context.request('dnn-metadata.json', 'utf-8', null)
              .then((data) => remember(new dnn.Metadata(data)))
              .catch(() => remember(new dnn.Metadata(null)));
      }

      /**
       * @param {string|null} data - JSON array of operator schemas keyed by `name`.
       */
      constructor(data) {
          this._map = new Map();
          this._attributeCache = new Map();
          if (data) {
              const schemas = JSON.parse(data);
              this._map = new Map(schemas.map((schema) => [ schema.name, schema ]));
          }
      }

      /** Operator schema for `name`, or undefined. */
      type(name) {
          return this._map.get(name);
      }

      /**
       * Attribute schema `name` of operator `type`, or null when unknown.
       * The first lookup for an operator caches all of its attributes.
       */
      attribute(type, name) {
          const key = type + ':' + name;
          const cache = this._attributeCache;
          if (cache.has(key)) {
              return cache.get(key);
          }
          const schema = this.type(type);
          if (schema && schema.attributes && schema.attributes.length > 0) {
              for (const entry of schema.attributes) {
                  cache.set(type + ':' + entry.name, entry);
              }
          }
          if (!cache.has(key)) {
              cache.set(key, null);
          }
          return cache.get(key);
      }
  };
  dnn.Error = class extends Error {

      /**
       * Error raised while reading a SnapML (dnn) file.
       * @param {string} text - human-readable description.
       */
      constructor(text) {
          super(text);
          this.name = 'Error loading SnapML model.';
      }
  };
  // Expose the factory to CommonJS loaders; elsewhere this is a no-op.
  if (typeof module !== 'undefined' && typeof module.exports === 'object') {
      Object.assign(module.exports, { ModelFactory: dnn.ModelFactory });
  }