// dlc.js (17 KB)
  1. var dlc = dlc || {};
  2. var text = text || require('./text');
  3. dlc.ModelFactory = class {
  4. match(context) {
  5. return dlc.Container.open(context);
  6. }
  7. open(context, match) {
  8. return context.require('./dlc-schema').then(() => {
  9. dlc.schema = flatbuffers.get('dlc').dlc;
  10. const container = match;
  11. return context.metadata('dlc-metadata.json').then((metadata) => {
  12. let model = null;
  13. let params = null;
  14. const metadata_props = container.metadata;
  15. try {
  16. model = container.model;
  17. }
  18. catch (error) {
  19. const message = error && error.message ? error.message : error.toString();
  20. throw new dlc.Error('File format is not dlc.NetDef (' + message.replace(/\.$/, '') + ').');
  21. }
  22. try {
  23. params = container.params;
  24. }
  25. catch (error) {
  26. const message = error && error.message ? error.message : error.toString();
  27. throw new dlc.Error('File format is not dlc.NetParam (' + message.replace(/\.$/, '') + ').');
  28. }
  29. return new dlc.Model(metadata, model, params, metadata_props);
  30. });
  31. });
  32. }
  33. };
  34. dlc.Model = class {
  35. constructor(metadata, model, params, metadata_props) {
  36. this._format = model ? 'DLC' : 'DLC Weights';
  37. this._metadata = [];
  38. if (metadata_props.size > 0) {
  39. const version = metadata_props.get('model-version');
  40. if (version) {
  41. this._version = version;
  42. }
  43. const converter = metadata_props.get('converter-command');
  44. if (converter) {
  45. const source = converter.split(' ').shift().trim();
  46. if (source.length > 0) {
  47. const version = metadata_props.get('converter-version');
  48. this._metadata.push({ name: 'source', value: version ? source + ' v' + version : source });
  49. }
  50. }
  51. }
  52. this._graphs = [ new dlc.Graph(metadata, model, params) ];
  53. }
  54. get format() {
  55. return this._format;
  56. }
  57. get version() {
  58. return this._version;
  59. }
  60. get metadata() {
  61. return this._metadata;
  62. }
  63. get graphs() {
  64. return this._graphs;
  65. }
  66. };
  67. dlc.Graph = class {
  68. constructor(metadata, model, params) {
  69. this._inputs = [];
  70. this._outputs = [];
  71. const args = new Map();
  72. const arg = (name) => {
  73. if (!args.has(name)) {
  74. args.set(name, new dlc.Argument(name));
  75. }
  76. return args.get(name);
  77. };
  78. if (model) {
  79. for (const node of model.nodes) {
  80. for (const input of node.inputs) {
  81. if (!args.has(input)) {
  82. args.set(input, {});
  83. }
  84. }
  85. const shapes = new Array(node.outputs.length);
  86. for (const attr of node.attributes) {
  87. if (attr.name === 'OutputDims') {
  88. for (const entry of Object.entries(attr.attributes)) {
  89. const index = parseInt(entry[0], 10);
  90. shapes[index] = Array.from(entry[1].int32_list);
  91. }
  92. break;
  93. }
  94. }
  95. for (let i = 0; i < node.outputs.length; i++) {
  96. const output = node.outputs[i];
  97. if (!args.has(output)) {
  98. args.set(output, {});
  99. }
  100. const value = args.get(output);
  101. if (i < shapes.length) {
  102. value.shape = shapes[i];
  103. }
  104. }
  105. }
  106. for (const entry of args) {
  107. const value = entry[1];
  108. const type = value.shape ? new dlc.TensorType(null, value.shape) : null;
  109. const argument = new dlc.Argument(entry[0], type);
  110. args.set(entry[0], argument);
  111. }
  112. this._nodes = [];
  113. const weights = new Map(params ? params.weights.map((weights) => [ weights.name, weights ]) : []);
  114. for (const node of model.nodes) {
  115. if (node.type === 'Input') {
  116. this._inputs.push(new dlc.Parameter(node.name, node.inputs.map((input) => arg(input))));
  117. continue;
  118. }
  119. this._nodes.push(new dlc.Node(metadata, node, weights.get(node.name), arg));
  120. }
  121. }
  122. else {
  123. this._nodes = params.weights.map((weights) => new dlc.Node(metadata, null, weights, arg));
  124. }
  125. }
  126. get inputs() {
  127. return this._inputs;
  128. }
  129. get outputs() {
  130. return this._outputs;
  131. }
  132. get nodes() {
  133. return this._nodes;
  134. }
  135. };
  136. dlc.Parameter = class {
  137. constructor(name, args) {
  138. this._name = name;
  139. this._arguments = args;
  140. }
  141. get name() {
  142. return this._name;
  143. }
  144. get visible() {
  145. return true;
  146. }
  147. get arguments() {
  148. return this._arguments;
  149. }
  150. };
  151. dlc.Argument = class {
  152. constructor(name, type, initializer) {
  153. if (typeof name !== 'string') {
  154. throw new dlc.Error("Invalid argument identifier '" + JSON.stringify(name) + "'.");
  155. }
  156. this._name = name;
  157. this._type = type;
  158. this._initializer = initializer;
  159. }
  160. get name() {
  161. return this._name;
  162. }
  163. get type() {
  164. return this._type;
  165. }
  166. get initializer() {
  167. return this._initializer;
  168. }
  169. };
  170. dlc.Node = class {
  171. constructor(metadata, node, weights, arg) {
  172. if (node) {
  173. this._type = metadata.type(node.type);
  174. this._name = node.name;
  175. const inputs = Array.from(node.inputs).map((input) => arg(input));
  176. this._inputs = inputs.length === 0 ? [] : [ new dlc.Parameter(inputs.length === 1 ? 'input' : 'inputs', inputs) ];
  177. const outputs = Array.from(node.outputs).map((output) => arg(output));
  178. this._outputs = outputs.length === 0 ? [] : [ new dlc.Parameter(outputs.length === 1 ? 'output' : 'outputs', outputs) ];
  179. this._attributes = [];
  180. for (const attr of node.attributes) {
  181. if (attr.name === 'OutputDims') {
  182. continue;
  183. }
  184. const attribute = new dlc.Attribute(metadata.attribute(node.type, attr.name), attr);
  185. this._attributes.push(attribute);
  186. }
  187. if (weights) {
  188. for (const tensor of weights.tensors) {
  189. const type = new dlc.TensorType(tensor.data.data_type, tensor.shape);
  190. const argument = new dlc.Argument('', type, new dlc.Tensor(type, tensor.data));
  191. this._inputs.push(new dlc.Parameter(tensor.name, [ argument ]));
  192. }
  193. }
  194. }
  195. else {
  196. this._type = { name: 'Weights' };
  197. this._name = weights.name;
  198. this._inputs = weights.tensors.map((tensor) => {
  199. const type = new dlc.TensorType(tensor.data.data_type, tensor.shape);
  200. const argument = new dlc.Argument('', type, new dlc.Tensor(type, tensor.data));
  201. return new dlc.Parameter(tensor.name, [ argument ]);
  202. });
  203. this._outputs = [];
  204. this._attributes = [];
  205. }
  206. }
  207. get type() {
  208. return this._type;
  209. }
  210. get name() {
  211. return this._name;
  212. }
  213. get inputs() {
  214. return this._inputs;
  215. }
  216. get outputs() {
  217. return this._outputs;
  218. }
  219. get attributes() {
  220. return this._attributes;
  221. }
  222. };
  223. dlc.Attribute = class {
  224. constructor(metadata, attr) {
  225. this._name = attr.name;
  226. const read = (attr) => {
  227. switch (attr.type) {
  228. case 1: return [ 'boolean', attr.bool_value ];
  229. case 2: return [ 'int32', attr.int32_value ];
  230. case 3: return [ 'uint32', attr.uint32_value ];
  231. case 4: return [ 'float32', attr.float32_value ];
  232. case 5: return [ 'string', attr.string_value ];
  233. case 7: return [ 'byte[]', Array.from(attr.byte_list) ];
  234. case 8: return [ 'int32[]', Array.from(attr.int32_list) ];
  235. case 9: return [ 'float32[]', Array.from(attr.float32_list) ];
  236. case 11: {
  237. const obj = {};
  238. for (const attribute of attr.attributes) {
  239. const entry = read(attribute);
  240. obj[attribute.name] = entry[1];
  241. }
  242. return [ '', obj ];
  243. }
  244. default:
  245. throw new dlc.Error("Unsupported attribute type '" + attr.type + "'.");
  246. }
  247. };
  248. const entry = read(attr);
  249. if (entry) {
  250. this._type = entry[0];
  251. this._value = entry[1];
  252. }
  253. if (metadata && metadata.type) {
  254. this._type = metadata.type;
  255. this._value = dlc.Utility.enum(this._type, this._value);
  256. }
  257. }
  258. get name() {
  259. return this._name;
  260. }
  261. get type() {
  262. return this._type;
  263. }
  264. get value() {
  265. return this._value;
  266. }
  267. };
  268. dlc.TensorType = class {
  269. constructor(dataType, shape) {
  270. switch (dataType) {
  271. case null: this._dataType = '?'; break;
  272. case 6: this._dataType = 'uint8'; break;
  273. case 9: this._dataType = 'float32'; break;
  274. default:
  275. throw new dlc.Error("Unsupported data type '" + JSON.stringify(dataType) + "'.");
  276. }
  277. this._shape = new dlc.TensorShape(shape);
  278. }
  279. get dataType() {
  280. return this._dataType;
  281. }
  282. get shape() {
  283. return this._shape;
  284. }
  285. toString() {
  286. return this.dataType + this._shape.toString();
  287. }
  288. };
  289. dlc.TensorShape = class {
  290. constructor(dimensions) {
  291. this._dimensions = Array.from(dimensions);
  292. }
  293. get dimensions() {
  294. return this._dimensions;
  295. }
  296. toString() {
  297. if (!this._dimensions || this._dimensions.length == 0) {
  298. return '';
  299. }
  300. return '[' + this._dimensions.map((dimension) => dimension.toString()).join(',') + ']';
  301. }
  302. };
  303. dlc.Tensor = class {
  304. constructor(type, data) {
  305. this._type = type;
  306. switch (type.dataType) {
  307. case 'uint8': this._data = data.bytes; break;
  308. case 'float32': this._data = data.floats; break;
  309. default: throw new dlc.Error("Unsupported tensor data type '" + type.dataType + "'.");
  310. }
  311. }
  312. get type() {
  313. return this._type;
  314. }
  315. get state() {
  316. return this._context().state || null;
  317. }
  318. get value() {
  319. const context = this._context();
  320. if (context.state) {
  321. return null;
  322. }
  323. context.limit = Number.MAX_SAFE_INTEGER;
  324. return this._decode(context, 0);
  325. }
  326. toString() {
  327. const context = this._context();
  328. if (context.state) {
  329. return '';
  330. }
  331. context.limit = 10000;
  332. const value = this._decode(context, 0);
  333. return JSON.stringify(value, null, 4);
  334. }
  335. _context() {
  336. const context = {};
  337. context.state = null;
  338. context.index = 0;
  339. context.count = 0;
  340. context.shape = this._type.shape.dimensions;
  341. context.data = this._data;
  342. return context;
  343. }
  344. _decode(context, dimension) {
  345. const results = [];
  346. const size = context.shape[dimension];
  347. if (dimension == context.shape.length - 1) {
  348. for (let i = 0; i < size; i++) {
  349. if (context.count > context.limit) {
  350. results.push('...');
  351. return results;
  352. }
  353. results.push(context.data[context.index]);
  354. context.index++;
  355. context.count++;
  356. }
  357. }
  358. else {
  359. for (let j = 0; j < size; j++) {
  360. if (context.count > context.limit) {
  361. results.push('...');
  362. return results;
  363. }
  364. results.push(this._decode(context, dimension + 1));
  365. }
  366. }
  367. return results;
  368. }
  369. };
  370. dlc.Container = class {
  371. static open(context) {
  372. const entries = context.entries('zip');
  373. if (entries.size > 0) {
  374. const model = entries.get('model');
  375. const params = entries.get('model.params');
  376. if (model || params) {
  377. return new dlc.Container(model, params, entries.get('dlc.metadata'));
  378. }
  379. }
  380. const stream = context.stream;
  381. switch (dlc.Container._idenfitier(stream)) {
  382. case 'NETD':
  383. return new dlc.Container(stream, null, null);
  384. case 'NETP':
  385. return new dlc.Container(null, stream, null);
  386. default:
  387. break;
  388. }
  389. return null;
  390. }
  391. constructor(model, params, metadata) {
  392. this._model = model || null;
  393. this._params = params || null;
  394. this._metadata = metadata || new Uint8Array(0);
  395. }
  396. get model() {
  397. if (this._model && this._model.peek) {
  398. const stream = this._model;
  399. const reader = this._open(stream, 'NETD');
  400. stream.seek(0);
  401. this._model = dlc.schema.NetDef.decode(reader, reader.root);
  402. }
  403. return this._model;
  404. }
  405. get params() {
  406. if (this._params && this._params.peek) {
  407. const stream = this._params;
  408. const reader = this._open(stream, 'NETP');
  409. stream.seek(0);
  410. this._params = dlc.schema.NetParam.decode(reader, reader.root);
  411. }
  412. return this._params;
  413. }
  414. get metadata() {
  415. if (this._metadata && this._metadata.peek) {
  416. const reader = text.Reader.open(this._metadata);
  417. const metadata = new Map();
  418. for (;;) {
  419. const line = reader.read();
  420. if (line === undefined) {
  421. break;
  422. }
  423. const index = line.indexOf('=');
  424. if (index === -1) {
  425. break;
  426. }
  427. const key = line.substring(0, index);
  428. const value = line.substring(index + 1);
  429. metadata.set(key, value);
  430. }
  431. this._metadata = metadata;
  432. }
  433. return this._metadata;
  434. }
  435. _open(stream, identifier) {
  436. if (dlc.Container._signature(stream, [ 0xD5, 0x0A, 0x02, 0x00 ])) {
  437. throw new dlc.Error("Unsupported DLC format '0x00020AD5'.");
  438. }
  439. if (dlc.Container._signature(stream, [ 0xD5, 0x0A, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00 ])) {
  440. stream.read(8);
  441. }
  442. const buffer = new Uint8Array(stream.read());
  443. const reader = flatbuffers.BinaryReader.open(buffer);
  444. if (identifier != reader.identifier) {
  445. throw new dlc.Error("File contains undocumented '" + reader.identifier + "' data.");
  446. }
  447. return reader;
  448. }
  449. static _idenfitier(stream) {
  450. if (dlc.Container._signature(stream, [ 0xD5, 0x0A, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00 ])) {
  451. const buffer = stream.peek(16).slice(8, 16);
  452. const reader = flatbuffers.BinaryReader.open(buffer);
  453. return reader.identifier;
  454. }
  455. else if (stream && stream.length > 8) {
  456. const buffer = stream.peek(8);
  457. const reader = flatbuffers.BinaryReader.open(buffer);
  458. return reader.identifier;
  459. }
  460. return '';
  461. }
  462. static _signature(stream, signature) {
  463. return stream && stream.length > 16 && stream.peek(signature.length).every((value, index) => value === signature[index]);
  464. }
  465. };
  466. dlc.Utility = class {
  467. static enum(name, value) {
  468. const type = name && dlc.schema ? dlc.schema[name] : undefined;
  469. if (type) {
  470. dlc.Utility._enums = dlc.Utility._enums || new Map();
  471. if (!dlc.Utility._enums.has(name)) {
  472. const map = new Map(Object.keys(type).map((key) => [ type[key], key ]));
  473. dlc.Utility._enums.set(name, map);
  474. }
  475. const map = dlc.Utility._enums.get(name);
  476. if (map.has(value)) {
  477. return map.get(value);
  478. }
  479. }
  480. return value;
  481. }
  482. };
  483. dlc.Error = class extends Error {
  484. constructor(message) {
  485. super(message);
  486. this.name = 'Error loading DLC model.';
  487. this.stack = undefined;
  488. }
  489. };
  490. if (typeof module !== 'undefined' && typeof module.exports === 'object') {
  491. module.exports.ModelFactory = dlc.ModelFactory;
  492. }