dl4j.js 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613
  /* jshint esversion: 6 */
  // Experimental Deeplearning4j model reader.
  // Reuse an existing `dl4j` namespace object when one is already defined; otherwise start empty.
  var dl4j = dl4j ? dl4j : {};
  dl4j.ModelFactory = class {

      /**
       * Returns true when the context is a non-empty ZIP archive whose entries pass the
       * checks in _openContainer (exactly one configuration.json describing either
       * 'vertices' or 'confs', and at most one coefficients.bin).
       * @param {object} context - host file context exposing `identifier` and `entries()`.
       * @returns {boolean}
       */
      match(context) {
          const identifier = context.identifier.toLowerCase();
          const extension = identifier.split('.').pop();
          if (extension === 'zip' && context.entries('zip').length > 0) {
              return dl4j.ModelFactory._openContainer(context) !== null;
          }
          return false;
      }

      /**
       * Loads the model described by the archive.
       * @param {object} context - host file context exposing `identifier` and `entries()`.
       * @param {object} host - host services; host.exception() is called for every failure.
       * @returns {Promise<dl4j.Model>} rejects with a dl4j.Error naming the file on failure.
       */
      open(context, host) {
          const identifier = context.identifier;
          // Report the original error to the host, then convert it into a dl4j.Error.
          const wrap = (error) => {
              host.exception(error, false);
              return dl4j.ModelFactory._error(error, identifier);
          };
          try {
              const container = dl4j.ModelFactory._openContainer(context);
              if (!container) {
                  throw new dl4j.Error('Invalid Deeplearning4j archive.');
              }
              const configuration = JSON.parse(container.configuration);
              return dl4j.Metadata.open(host).then((metadata) => {
                  try {
                      return new dl4j.Model(metadata, configuration, container.coefficients);
                  }
                  catch (error) {
                      throw wrap(error);
                  }
              });
          }
          catch (error) {
              return Promise.reject(wrap(error));
          }
      }

      /**
       * Builds a dl4j.Error whose message names the offending file.
       * String() is used instead of error.toString() so null/undefined errors do not throw.
       */
      static _error(error, identifier) {
          const message = error && error.message ? error.message : String(error);
          return new dl4j.Error(message.replace(/\.$/, '') + " in '" + identifier + "'.");
      }

      /**
       * Extracts configuration.json (as text) and coefficients.bin from the archive.
       * @returns {{configuration: string, coefficients: Uint8Array}|null} null when the archive
       *     does not look like a DL4J model.
       */
      static _openContainer(context) {
          const entries = context.entries('zip');
          const configurationEntries = entries.filter((entry) => entry.name === 'configuration.json');
          if (configurationEntries.length !== 1) {
              return null;
          }
          let configuration = null;
          try {
              configuration = new TextDecoder('utf-8').decode(configurationEntries[0].data);
          }
          catch (error) {
              // Undecodable configuration: treat the archive as not being a DL4J model.
              return null;
          }
          // Cheap content sniff: a ComputationGraph has "vertices", a MultiLayerNetwork has "confs".
          if (configuration.indexOf('"vertices"') === -1 && configuration.indexOf('"confs"') === -1) {
              return null;
          }
          const coefficientsEntries = entries.filter((entry) => entry.name === 'coefficients.bin');
          if (coefficientsEntries.length > 1) {
              return null;
          }
          // Fall back to an empty typed array (not a plain Array): dl4j.BinaryReader reads
          // .buffer/.byteOffset/.byteLength/.subarray from whatever it is given.
          const coefficients = coefficientsEntries.length === 1 ? coefficientsEntries[0].data : new Uint8Array(0);
          return {
              configuration: configuration,
              coefficients: coefficients
          };
      }
  };
  dl4j.Model = class {

      /**
       * Wraps a parsed DL4J configuration; the model always exposes a single graph.
       * @param {dl4j.Metadata} metadata - schema registry passed through to the graph.
       * @param {object} configuration - parsed configuration.json.
       * @param {Uint8Array} coefficients - raw coefficients.bin contents.
       */
      constructor(metadata, configuration, coefficients) {
          const graph = new dl4j.Graph(metadata, configuration, coefficients);
          this._graphs = [ graph ];
      }

      /** Display name of the model format. */
      get format() {
          return 'Deeplearning4j';
      }

      /** The graphs of this model (always exactly one). */
      get graphs() {
          return this._graphs;
      }
  };
  dl4j.Graph = class {

      /**
       * Builds the node list and graph inputs/outputs from a parsed DL4J configuration.
       * Two layouts are handled:
       *  - ComputationGraph: 'vertices' (name -> vertex) with 'vertexInputs' (name -> input
       *    names), plus 'networkInputs'/'networkOutputs';
       *  - MultiLayerNetwork: 'confs', a linear list of layers fed from a synthetic 'input'
       *    and ending in a synthetic 'output'.
       * @param {dl4j.Metadata} metadata - schema registry passed to every node.
       * @param {object} configuration - parsed configuration.json.
       * @param {Uint8Array} coefficients - coefficients.bin contents (empty when absent).
       * @throws {dl4j.Error} for vertex types this reader does not support.
       */
      constructor(metadata, configuration, coefficients) {
          this._inputs = [];
          this._outputs = [];
          this._nodes = [];
          // Only the element type of the stored coefficients is used; values are not decoded.
          const reader = new dl4j.NDArrayReader(coefficients);
          const dataType = reader.dataType;
          for (const input of configuration.networkInputs || []) {
              this._inputs.push(new dl4j.Parameter(input, true, [
                  new dl4j.Argument(input, null, null)
              ]));
          }
          for (const output of configuration.networkOutputs || []) {
              this._outputs.push(new dl4j.Parameter(output, true, [
                  new dl4j.Argument(output, null, null)
              ]));
          }
          // Computation Graph
          if (configuration.vertices) {
              // Tolerate a missing 'vertexInputs' map: such vertices simply get no inputs.
              const vertexInputs = configuration.vertexInputs || {};
              for (const name of Object.keys(configuration.vertices)) {
                  const vertex = dl4j.Node._object(configuration.vertices[name]);
                  let layer = null;
                  let variables = [];
                  switch (vertex.__type__) {
                      case 'LayerVertex':
                          layer = dl4j.Node._object(vertex.layerConf.layer);
                          variables = vertex.layerConf.variables;
                          break;
                      case 'MergeVertex':
                          layer = { __type__: 'Merge', layerName: name };
                          break;
                      case 'ElementWiseVertex':
                          layer = { __type__: 'ElementWise', layerName: name, op: vertex.op };
                          break;
                      case 'PreprocessorVertex':
                          layer = { __type__: 'Preprocessor', layerName: name };
                          break;
                      default:
                          // _object() has already removed '@class' from the vertex, so report
                          // the type name it derived instead of the (now undefined) class.
                          throw new dl4j.Error("Unsupported vertex class '" + vertex.__type__ + "'.");
                  }
                  this._nodes.push(new dl4j.Node(metadata, layer, vertexInputs[name], dataType, variables));
              }
          }
          // Multi Layer Network
          if (configuration.confs) {
              let inputs = [ 'input' ];
              this._inputs.push(new dl4j.Parameter('input', true, [
                  new dl4j.Argument('input', null, null)
              ]));
              for (const conf of configuration.confs) {
                  const layer = dl4j.Node._object(conf.layer);
                  this._nodes.push(new dl4j.Node(metadata, layer, inputs, dataType, conf.variables));
                  // Each layer feeds the next one by name.
                  inputs = [ layer.layerName ];
              }
              this._outputs.push(new dl4j.Parameter('output', true, [
                  new dl4j.Argument(inputs[0], null, null)
              ]));
          }
      }

      get inputs() {
          return this._inputs;
      }

      get outputs() {
          return this._outputs;
      }

      get nodes() {
          return this._nodes;
      }
  };
  dl4j.Parameter = class {

      /**
       * A named slot on a node or graph that holds a list of arguments.
       * @param {string} name - slot name.
       * @param {boolean} visible - visibility flag, stored and returned unchanged.
       * @param {dl4j.Argument[]} args - arguments bound to this slot.
       */
      constructor(name, visible, args) {
          Object.assign(this, {
              _name: name,
              _visible: visible,
              _arguments: args
          });
      }

      get name() {
          return this._name;
      }

      get visible() {
          return this._visible;
      }

      get arguments() {
          return this._arguments;
      }
  };
  dl4j.Argument = class {

      /**
       * A named value flowing between nodes, optionally backed by an initializer tensor.
       * @param {string} name - value identifier; must be a string.
       * @param {dl4j.TensorType|null} type - explicit type, used when there is no initializer.
       * @param {dl4j.Tensor|null} initializer - constant tensor supplying the value, if any.
       * @throws {dl4j.Error} when `name` is not a string.
       */
      constructor(name, type, initializer) {
          if (typeof name === 'string') {
              this._name = name;
              this._type = type;
              this._initializer = initializer;
          }
          else {
              throw new dl4j.Error("Invalid argument identifier '" + JSON.stringify(name) + "'.");
          }
      }

      get name() {
          return this._name;
      }

      /** The initializer's type takes precedence over the explicitly supplied one. */
      get type() {
          return this._initializer ? this._initializer.type : this._type;
      }

      get initializer() {
          return this._initializer;
      }
  };
  dl4j.Node = class {

      /**
       * Builds a graph node from a DL4J layer description.
       * @param {dl4j.Metadata} metadata - schema registry used to look up attribute schemas.
       * @param {object} layer - layer object already normalized by dl4j.Node._object
       *     (has __type__; may have layerName, activationFn, idropout and layer fields).
       * @param {string[]} inputs - names of the values feeding this node (may be empty).
       * @param {string} dataType - element type assigned to the variable tensors.
       * @param {string[]} variables - variable names listed for the layer (e.g. 'W', 'b').
       * @throws {dl4j.Error} when a variable is unknown for the layer type, or when the layer
       *     has an input dropout whose 'p' is not 1.0.
       */
      constructor(metadata, layer, inputs, dataType, variables) {
          this._metadata = metadata;
          this._type = layer.__type__;
          this._name = layer.layerName || '';
          this._inputs = [];
          this._outputs = [];
          this._attributes = [];
          if (inputs && inputs.length > 0) {
              const args = inputs.map((input) => new dl4j.Argument(input, null, null));
              this._inputs.push(new dl4j.Parameter(args.length < 2 ? 'input' : 'inputs', true, args));
          }
          if (variables) {
              for (const variable of variables) {
                  // Uses the layer type as given, before any activation rewrite below.
                  const tensor = dl4j.Node._variable(this._type, layer, variable, dataType);
                  this._inputs.push(new dl4j.Parameter(variable, true, [
                      new dl4j.Argument(variable, null, tensor)
                  ]));
              }
          }
          if (this._name) {
              this._outputs.push(new dl4j.Parameter('output', true, [
                  new dl4j.Argument(this._name, null, null)
              ]));
          }
          let attributes = layer;
          if (layer.activationFn) {
              const activation = dl4j.Node._object(layer.activationFn);
              if (activation.__type__ !== 'ActivationIdentity' && activation.__type__ !== 'Identity') {
                  if (activation.__type__.startsWith('Activation')) {
                      activation.__type__ = activation.__type__.substring('Activation'.length);
                  }
                  if (this._type === 'Activation') {
                      // A standalone activation layer takes the activation's type and attributes.
                      this._type = activation.__type__;
                      attributes = activation;
                  }
                  else {
                      // Any other layer keeps its own type and gets the activation as a chained node.
                      this._chain = this._chain || [];
                      this._chain.push(new dl4j.Node(metadata, activation, [], null, null));
                  }
              }
          }
          for (const key in attributes) {
              switch (key) {
                  case '__type__':
                  case 'constraints':
                  case 'layerName':
                  case 'activationFn':
                  case 'idropout':
                  case 'hasBias':
                      continue;
              }
              this._attributes.push(new dl4j.Attribute(metadata.attribute(this._type, key), key, attributes[key]));
          }
          if (layer.idropout) {
              const dropout = dl4j.Node._object(layer.idropout);
              if (dropout.p !== 1.0) {
                  throw new dl4j.Error("Layer 'idropout' not implemented.");
              }
          }
      }

      /**
       * Creates the placeholder tensor for one layer variable; shapes come from layer fields.
       * @throws {dl4j.Error} when the layer type / variable combination is not known.
       */
      static _variable(type, layer, variable, dataType) {
          switch (type) {
              case 'Convolution':
                  switch (variable) {
                      case 'W':
                          return new dl4j.Tensor(dataType, layer.kernelSize.concat([ layer.nin, layer.nout ]));
                      case 'b':
                          return new dl4j.Tensor(dataType, [ layer.nout ]);
                  }
                  break;
              case 'SeparableConvolution2D':
                  switch (variable) {
                      case 'W':
                          return new dl4j.Tensor(dataType, layer.kernelSize.concat([ layer.nin, layer.nout ]));
                      case 'pW':
                          // NOTE(review): pointwise weights are presumably 4-D; confirm this shape.
                          return new dl4j.Tensor(dataType, [ layer.nout ]);
                      case 'b':
                          // Bias variable of a separable convolution; previously rejected.
                          return new dl4j.Tensor(dataType, [ layer.nout ]);
                  }
                  break;
              case 'Output':
              case 'Dense':
                  switch (variable) {
                      case 'W':
                          return new dl4j.Tensor(dataType, [ layer.nout, layer.nin ]);
                      case 'b':
                          return new dl4j.Tensor(dataType, [ layer.nout ]);
                  }
                  break;
              case 'BatchNormalization':
                  // Every batch-normalization variable is given the same 1-D shape.
                  return new dl4j.Tensor(dataType, [ layer.nin ]);
          }
          throw new dl4j.Error("Unknown '" + type + "' variable '" + variable + "'.");
      }

      get type() {
          return this._type;
      }

      get name() {
          return this._name;
      }

      get metadata() {
          return this._metadata.type(this._type);
      }

      get inputs() {
          return this._inputs;
      }

      get outputs() {
          return this._outputs;
      }

      get attributes() {
          return this._attributes;
      }

      get chain() {
          return this._chain;
      }

      /**
       * Normalizes a DL4J JSON object into an object carrying a `__type__` field.
       * Two encodings are handled:
       *  - `{ "@class": "a.b.SomethingLayer", ... }`: the simple class name, minus a trailing
       *    'Layer', becomes __type__; '@class' is deleted from the object in place.
       *  - `{ "key": { ... } }`: the first key, with its first letter upper-cased, becomes
       *    __type__ of the wrapped object, which is returned.
       * @throws {dl4j.Error} when the wrapper form has no key at all.
       */
      static _object(value) {
          if (value['@class']) {
              let type = value['@class'].split('.').pop();
              if (type.endsWith('Layer')) {
                  type = type.substring(0, type.length - 5);
              }
              delete value['@class'];
              value.__type__ = type;
              return value;
          }
          const keys = Object.keys(value);
          if (keys.length === 0) {
              throw new dl4j.Error("Invalid object '" + JSON.stringify(value) + "'.");
          }
          let key = keys[0];
          const result = value[key];
          if (key.length > 0) {
              key = key[0].toUpperCase() + key.substring(1);
          }
          result.__type__ = key;
          return result;
      }
  };
  dl4j.Attribute = class {

      /**
       * A layer attribute (name/value pair) annotated from its metadata schema.
       * @param {object|null} schema - attribute schema from dl4j.Metadata.attribute(), or null.
       * @param {string} name - attribute name.
       * @param {*} value - raw value from the layer configuration.
       */
      constructor(schema, name, value) {
          this._name = name;
          this._value = value;
          // Hidden unless the schema explicitly marks the attribute as visible.
          this._visible = false;
          if (schema) {
              // Previously `_type` was never assigned, so the `type` getter always returned undefined.
              if (schema.type) {
                  this._type = schema.type;
              }
              if (schema.visible) {
                  this._visible = true;
              }
          }
      }

      get name() {
          return this._name;
      }

      get type() {
          return this._type;
      }

      get value() {
          return this._value;
      }

      get visible() {
          return this._visible;
      }
  };
  dl4j.Tensor = class {

      /**
       * Placeholder weight tensor: it carries a type (data type + shape) but no decoded values.
       * @param {string} dataType - element type name.
       * @param {Array} shape - list of dimensions.
       */
      constructor(dataType, shape) {
          const tensorShape = new dl4j.TensorShape(shape);
          this._type = new dl4j.TensorType(dataType, tensorShape);
      }

      get type() {
          return this._type;
      }

      /** Tensor contents are never decoded, so the state always reports that. */
      get state() {
          return 'Not implemented.';
      }
  };
  dl4j.TensorType = class {

      /**
       * Pair of element data type and shape.
       * @param {string} dataType - element type name (may be null/empty when unknown).
       * @param {dl4j.TensorShape} shape - tensor shape.
       */
      constructor(dataType, shape) {
          this._dataType = dataType;
          this._shape = shape;
      }

      get dataType() {
          return this._dataType;
      }

      get shape() {
          return this._shape;
      }

      /** Renders e.g. "float32[3,4]"; an unknown data type is shown as '?'. */
      toString() {
          const prefix = this.dataType ? this.dataType : '?';
          return prefix + this._shape.toString();
      }
  };
  dl4j.TensorShape = class {

      /**
       * @param {Array} dimensions - list of dimension sizes; entries may be undefined when the
       *     layer configuration lacks the field they were read from (e.g. nin/nout).
       */
      constructor(dimensions) {
          this._dimensions = dimensions;
      }

      get dimensions() {
          return this._dimensions;
      }

      /**
       * Renders the shape as "[d0,d1,...]", or '' for a missing or empty dimension list.
       * Missing dimensions render as '?' instead of throwing on undefined.toString().
       */
      toString() {
          if (!this._dimensions || this._dimensions.length === 0) {
              return '';
          }
          const parts = this._dimensions.map((dimension) =>
              dimension === undefined || dimension === null ? '?' : dimension.toString());
          return '[' + parts.join(',') + ']';
      }
  };
  dl4j.Metadata = class {

      /**
       * Loads dl4j-metadata.json through the host, once; later calls reuse the cached instance.
       * @param {object} host - provides request(base, file, encoding) returning a Promise.
       * @returns {Promise<dl4j.Metadata>} never rejects: load or parse failures yield an
       *     empty registry.
       */
      static open(host) {
          // NOTE(review): textDecoder is not read anywhere in this file; kept in case other code uses it.
          dl4j.Metadata.textDecoder = dl4j.Metadata.textDecoder || new TextDecoder('utf-8');
          if (dl4j.Metadata._metadata) {
              return Promise.resolve(dl4j.Metadata._metadata);
          }
          return host.request(null, 'dl4j-metadata.json', 'utf-8').then((data) => {
              dl4j.Metadata._metadata = new dl4j.Metadata(data);
              return dl4j.Metadata._metadata;
          }).catch(() => {
              // Metadata is optional: fall back to an empty registry.
              dl4j.Metadata._metadata = new dl4j.Metadata(null);
              return dl4j.Metadata._metadata;
          });
      }

      /**
       * @param {string|null} data - JSON text: an array of { name, schema } items.
       * @throws {SyntaxError} when `data` is not valid JSON.
       */
      constructor(data) {
          // Prototype-less maps so type names such as 'constructor' cannot hit Object.prototype.
          this._map = Object.create(null);
          this._attributeCache = Object.create(null);
          if (data) {
              const items = JSON.parse(data);
              for (const item of items || []) {
                  // Skip malformed entries instead of discarding the whole registry.
                  if (item && item.name && item.schema) {
                      item.schema.name = item.name;
                      this._map[item.name] = item.schema;
                  }
              }
          }
      }

      /** Returns the schema registered for a layer type, or undefined. */
      type(name) {
          return this._map[name];
      }

      /**
       * Returns the schema of attribute `name` on layer type `type`, or null.
       * A per-type name -> schema lookup is built on first use and cached.
       */
      attribute(type, name) {
          let map = this._attributeCache[type];
          if (!map) {
              map = Object.create(null);
              const schema = this.type(type);
              if (schema && schema.attributes && schema.attributes.length > 0) {
                  for (const attribute of schema.attributes) {
                      map[attribute.name] = attribute;
                  }
              }
              this._attributeCache[type] = map;
          }
          return map[name] || null;
      }
  };
  dl4j.NDArrayReader = class {

      /**
       * Reads the headers of a serialized ND4J array to find the element type of its data.
       * The buffer holds a shape buffer followed by a data buffer; only the second header's
       * type is kept.
       * @param {Uint8Array|Array} buffer - coefficients.bin contents; empty when absent.
       */
      constructor(buffer) {
          this._dataType = null;
          // No coefficients stored: there is nothing to read, so the data type stays unknown.
          if (!buffer || buffer.length === 0) {
              return;
          }
          const reader = new dl4j.BinaryReader(buffer);
          dl4j.NDArrayReader._header(reader); // shape buffer, read only to advance past it
          const data = dl4j.NDArrayReader._header(reader);
          this._dataType = data.type;
      }

      get dataType() {
          return this._dataType;
      }

      /**
       * Reads one buffer header (allocation mode, element count, element type) and its payload.
       * @returns {{alloc: string, length: number, type: string, itemsize: number, data: Uint8Array}}
       */
      static _header(reader) {
          const header = {};
          header.alloc = reader.string();
          header.length = 0;
          switch (header.alloc) {
              case 'DIRECT':
              case 'HEAP':
              case 'JAVACPP':
                  header.length = reader.int32();
                  break;
              case 'LONG_SHAPE':
              case 'MIXED_DATA_TYPES':
                  header.length = reader.int64();
                  break;
          }
          const name = reader.string();
          const known = dl4j.NDArrayReader._dataType(name);
          // Unknown types keep their raw name; itemsize 0 keeps the read position a number
          // (the previous undefined itemsize turned it into NaN).
          header.type = known ? known.type : name;
          header.itemsize = known ? known.itemsize : 0;
          header.data = reader.bytes(header.itemsize * header.length);
          return header;
      }

      /**
       * Maps a serialized element type name to a type name and byte size; null when unknown.
       * LONG is needed for the 64-bit shape buffers written in LONG_SHAPE mode.
       */
      static _dataType(name) {
          switch (name) {
              case 'DOUBLE': return { type: 'float64', itemsize: 8 };
              case 'FLOAT': return { type: 'float32', itemsize: 4 };
              case 'HALF': return { type: 'float16', itemsize: 2 };
              case 'BFLOAT16': return { type: 'bfloat16', itemsize: 2 };
              case 'LONG': return { type: 'int64', itemsize: 8 };
              case 'INT': return { type: 'int32', itemsize: 4 };
              case 'SHORT': return { type: 'int16', itemsize: 2 };
              case 'BYTE': return { type: 'int8', itemsize: 1 };
              case 'UBYTE': return { type: 'uint8', itemsize: 1 };
              case 'BOOL': return { type: 'boolean', itemsize: 1 };
              case 'UINT16': return { type: 'uint16', itemsize: 2 };
              case 'UINT32': return { type: 'uint32', itemsize: 4 };
              case 'UINT64': return { type: 'uint64', itemsize: 8 };
              default: return null;
          }
      }
  };
  dl4j.BinaryReader = class {

      /**
       * Sequential big-endian reader over a byte buffer.
       * @param {Uint8Array} buffer - source bytes (must expose buffer/byteOffset/byteLength/subarray).
       */
      constructor(buffer) {
          this._buffer = buffer;
          this._position = 0;
          this._view = new DataView(buffer.buffer, buffer.byteOffset, buffer.byteLength);
      }

      /** Returns the next `size` bytes as a view (no copy) and advances past them. */
      bytes(size) {
          const data = this._buffer.subarray(this._position, this._position + size);
          this._position += size;
          return data;
      }

      /** Reads a 16-bit big-endian length followed by that many bytes, decoded with the 'ascii' label. */
      string() {
          const size = this._buffer[this._position++] << 8 | this._buffer[this._position++];
          const buffer = this.bytes(size);
          return new TextDecoder('ascii').decode(buffer);
      }

      /** Reads a signed 32-bit big-endian integer. */
      int32() {
          const position = this._position;
          this._position += 4;
          return this._view.getInt32(position, false);
      }

      /**
       * Reads a signed 64-bit big-endian integer as a Number (exact up to 2^53).
       * The read now advances 8 bytes (it used to advance 4). It is built from two 32-bit
       * halves, so it no longer depends on a non-standard DataView.getInt64.
       */
      int64() {
          const position = this._position;
          this._position += 8;
          const high = this._view.getInt32(position, false);
          const low = this._view.getUint32(position + 4, false);
          return high * 4294967296 + low;
      }
  };
  dl4j.Error = class extends Error {

      /**
       * Error raised by this reader for invalid or unsupported Deeplearning4j content.
       * @param {string} message - description of the problem.
       */
      constructor(message) {
          super(message);
          Object.assign(this, { name: 'Error loading Deeplearning4j model.' });
      }
  };
  // CommonJS export of the factory; skipped when no `module` object exists (e.g. browser scripts).
  if (typeof module !== 'undefined' && typeof module.exports === 'object') {
      Object.assign(module.exports, { ModelFactory: dl4j.ModelFactory });
  }