dl4j.js 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585
  1. /* jshint esversion: 6 */
  2. // Experimental
  3. var dl4j = dl4j || {};
  4. var json = json || require('./json');
  5. dl4j.ModelFactory = class {
  6. match(context) {
  7. const entries = context.entries('zip');
  8. if (dl4j.ModelFactory._openContainer(entries)) {
  9. return true;
  10. }
  11. return false;
  12. }
  13. open(context) {
  14. return Promise.resolve().then(() => {
  15. return dl4j.Metadata.open(context).then((metadata) => {
  16. const entries = context.entries('zip');
  17. const container = dl4j.ModelFactory._openContainer(entries);
  18. return new dl4j.Model(metadata, container.configuration, container.coefficients);
  19. });
  20. });
  21. }
  22. static _openContainer(entries) {
  23. const stream = entries.get('configuration.json');
  24. const coefficients = entries.get('coefficients.bin');
  25. if (stream) {
  26. try {
  27. const reader = json.TextReader.open(stream);
  28. const configuration = reader.read();
  29. if (configuration && (configuration.confs || configuration.vertices)) {
  30. return {
  31. configuration: configuration,
  32. coefficients: coefficients ? coefficients.peek() : []
  33. };
  34. }
  35. }
  36. catch (error) {
  37. // continue regardless of error
  38. }
  39. }
  40. return null;
  41. }
  42. };
  43. dl4j.Model = class {
  44. constructor(metadata, configuration, coefficients) {
  45. this._graphs = [];
  46. this._graphs.push(new dl4j.Graph(metadata, configuration, coefficients));
  47. }
  48. get format() {
  49. return 'Deeplearning4j';
  50. }
  51. get graphs() {
  52. return this._graphs;
  53. }
  54. };
  55. dl4j.Graph = class {
  56. constructor(metadata, configuration, coefficients) {
  57. this._inputs = [];
  58. this._outputs =[];
  59. this._nodes = [];
  60. const reader = new dl4j.NDArrayReader(coefficients);
  61. const dataType = reader.dataType;
  62. if (configuration.networkInputs) {
  63. for (const input of configuration.networkInputs) {
  64. this._inputs.push(new dl4j.Parameter(input, true, [
  65. new dl4j.Argument(input, null, null)
  66. ]));
  67. }
  68. }
  69. if (configuration.networkOutputs) {
  70. for (const output of configuration.networkOutputs) {
  71. this._outputs.push(new dl4j.Parameter(output, true, [
  72. new dl4j.Argument(output, null, null)
  73. ]));
  74. }
  75. }
  76. let inputs = null;
  77. // Computation Graph
  78. if (configuration.vertices) {
  79. for (const name in configuration.vertices) {
  80. const vertex = dl4j.Node._object(configuration.vertices[name]);
  81. inputs = configuration.vertexInputs[name];
  82. let variables = [];
  83. let layer = null;
  84. switch (vertex.__type__) {
  85. case 'LayerVertex':
  86. layer = dl4j.Node._object(vertex.layerConf.layer);
  87. variables = vertex.layerConf.variables;
  88. break;
  89. case 'MergeVertex':
  90. layer = { __type__: 'Merge', layerName: name };
  91. break;
  92. case 'ElementWiseVertex':
  93. layer = { __type__: 'ElementWise', layerName: name, op: vertex.op };
  94. break;
  95. case 'PreprocessorVertex':
  96. layer = { __type__: 'Preprocessor', layerName: name };
  97. break;
  98. default:
  99. throw new dl4j.Error("Unsupported vertex class '" + vertex['@class'] + "'.");
  100. }
  101. this._nodes.push(new dl4j.Node(metadata, layer, inputs, dataType, variables));
  102. }
  103. }
  104. // Multi Layer Network
  105. if (configuration.confs) {
  106. inputs = [ 'input' ];
  107. this._inputs.push(new dl4j.Parameter('input', true, [
  108. new dl4j.Argument('input', null, null)
  109. ]));
  110. for (const conf of configuration.confs) {
  111. const layer = dl4j.Node._object(conf.layer);
  112. this._nodes.push(new dl4j.Node(metadata, layer, inputs, dataType, conf.variables));
  113. inputs = [ layer.layerName ];
  114. }
  115. this._outputs.push(new dl4j.Parameter('output', true, [
  116. new dl4j.Argument(inputs[0], null, null)
  117. ]));
  118. }
  119. }
  120. get inputs() {
  121. return this._inputs;
  122. }
  123. get outputs() {
  124. return this._outputs;
  125. }
  126. get nodes() {
  127. return this._nodes;
  128. }
  129. };
  130. dl4j.Parameter = class {
  131. constructor(name, visible, args) {
  132. this._name = name;
  133. this._visible = visible;
  134. this._arguments = args;
  135. }
  136. get name() {
  137. return this._name;
  138. }
  139. get visible() {
  140. return this._visible;
  141. }
  142. get arguments() {
  143. return this._arguments;
  144. }
  145. };
  146. dl4j.Argument = class {
  147. constructor(name, type, initializer) {
  148. if (typeof name !== 'string') {
  149. throw new dl4j.Error("Invalid argument identifier '" + JSON.stringify(name) + "'.");
  150. }
  151. this._name = name;
  152. this._type = type;
  153. this._initializer = initializer;
  154. }
  155. get name() {
  156. return this._name;
  157. }
  158. get type() {
  159. if (this._initializer) {
  160. return this._initializer.type;
  161. }
  162. return this._type;
  163. }
  164. get initializer() {
  165. return this._initializer;
  166. }
  167. };
  168. dl4j.Node = class {
  169. constructor(metadata, layer, inputs, dataType, variables) {
  170. this._metadata = metadata;
  171. this._type = layer.__type__;
  172. this._name = layer.layerName || '';
  173. this._inputs = [];
  174. this._outputs = [];
  175. this._attributes = [];
  176. if (inputs && inputs.length > 0) {
  177. const args = inputs.map((input) => new dl4j.Argument(input, null, null));
  178. this._inputs.push(new dl4j.Parameter(args.length < 2 ? 'input' : 'inputs', true, args));
  179. }
  180. if (variables) {
  181. for (const variable of variables) {
  182. let tensor = null;
  183. switch (this._type) {
  184. case 'Convolution':
  185. switch (variable) {
  186. case 'W':
  187. tensor = new dl4j.Tensor(dataType, layer.kernelSize.concat([ layer.nin, layer.nout ]));
  188. break;
  189. case 'b':
  190. tensor = new dl4j.Tensor(dataType, [ layer.nout ]);
  191. break;
  192. default:
  193. throw new dl4j.Error("Unknown '" + this._type + "' variable '" + variable + "'.");
  194. }
  195. break;
  196. case 'SeparableConvolution2D':
  197. switch (variable) {
  198. case 'W':
  199. tensor = new dl4j.Tensor(dataType, layer.kernelSize.concat([ layer.nin, layer.nout ]));
  200. break;
  201. case 'pW':
  202. tensor = new dl4j.Tensor(dataType, [ layer.nout ]);
  203. break;
  204. default:
  205. throw new dl4j.Error("Unknown '" + this._type + "' variable '" + variable + "'.");
  206. }
  207. break;
  208. case 'Output':
  209. case 'Dense':
  210. switch (variable) {
  211. case 'W':
  212. tensor = new dl4j.Tensor(dataType, [ layer.nout, layer.nin ]);
  213. break;
  214. case 'b':
  215. tensor = new dl4j.Tensor(dataType, [ layer.nout ]);
  216. break;
  217. default:
  218. throw new dl4j.Error("Unknown '" + this._type + "' variable '" + variable + "'.");
  219. }
  220. break;
  221. case 'BatchNormalization':
  222. tensor = new dl4j.Tensor(dataType, [ layer.nin ]);
  223. break;
  224. default:
  225. throw new dl4j.Error("Unknown '" + this._type + "' variable '" + variable + "'.");
  226. }
  227. this._inputs.push(new dl4j.Parameter(variable, true, [
  228. new dl4j.Argument(variable, null, tensor)
  229. ]));
  230. }
  231. }
  232. if (this._name) {
  233. this._outputs.push(new dl4j.Parameter('output', true, [
  234. new dl4j.Argument(this._name, null, null)
  235. ]));
  236. }
  237. let attributes = layer;
  238. if (layer.activationFn) {
  239. const activation = dl4j.Node._object(layer.activationFn);
  240. if (activation.__type__ !== 'ActivationIdentity' && activation.__type__ !== 'Identity') {
  241. if (activation.__type__.startsWith('Activation')) {
  242. activation.__type__ = activation.__type__.substring('Activation'.length);
  243. }
  244. if (this._type == 'Activation') {
  245. this._type = activation.__type__;
  246. attributes = activation;
  247. }
  248. else {
  249. this._chain = this._chain || [];
  250. this._chain.push(new dl4j.Node(metadata, activation, [], null, null));
  251. }
  252. }
  253. }
  254. for (const key in attributes) {
  255. switch (key) {
  256. case '__type__':
  257. case 'constraints':
  258. case 'layerName':
  259. case 'activationFn':
  260. case 'idropout':
  261. case 'hasBias':
  262. continue;
  263. }
  264. this._attributes.push(new dl4j.Attribute(metadata.attribute(this._type, key), key, attributes[key]));
  265. }
  266. if (layer.idropout) {
  267. const dropout = dl4j.Node._object(layer.idropout);
  268. if (dropout.p !== 1.0) {
  269. throw new dl4j.Error("Layer 'idropout' not implemented.");
  270. }
  271. }
  272. }
  273. get type() {
  274. return this._type;
  275. }
  276. get name() {
  277. return this._name;
  278. }
  279. get metadata() {
  280. return this._metadata.type(this._type);
  281. }
  282. get inputs() {
  283. return this._inputs;
  284. }
  285. get outputs() {
  286. return this._outputs;
  287. }
  288. get attributes() {
  289. return this._attributes;
  290. }
  291. get chain() {
  292. return this._chain;
  293. }
  294. static _object(value) {
  295. let result = {};
  296. if (value['@class']) {
  297. result = value;
  298. let type = value['@class'].split('.').pop();
  299. if (type.endsWith('Layer')) {
  300. type = type.substring(0, type.length - 5);
  301. }
  302. delete value['@class'];
  303. result.__type__ = type;
  304. }
  305. else {
  306. let key = Object.keys(value)[0];
  307. result = value[key];
  308. if (key.length > 0) {
  309. key = key[0].toUpperCase() + key.substring(1);
  310. }
  311. result.__type__ = key;
  312. }
  313. return result;
  314. }
  315. };
  316. dl4j.Attribute = class {
  317. constructor(schema, name, value) {
  318. this._name = name;
  319. this._value = value;
  320. this._visible = false;
  321. if (schema) {
  322. if (schema.visible) {
  323. this._visible = true;
  324. }
  325. }
  326. }
  327. get name() {
  328. return this._name;
  329. }
  330. get type() {
  331. return this._type;
  332. }
  333. get value() {
  334. return this._value;
  335. }
  336. get visible() {
  337. return this._visible;
  338. }
  339. };
  340. dl4j.Tensor = class {
  341. constructor(dataType, shape) {
  342. this._type = new dl4j.TensorType(dataType, new dl4j.TensorShape(shape));
  343. }
  344. get type() {
  345. return this._type;
  346. }
  347. get state() {
  348. return 'Not implemented.';
  349. }
  350. };
  351. dl4j.TensorType = class {
  352. constructor(dataType, shape) {
  353. this._dataType = dataType;
  354. this._shape = shape;
  355. }
  356. get dataType() {
  357. return this._dataType;
  358. }
  359. get shape() {
  360. return this._shape;
  361. }
  362. toString() {
  363. return (this.dataType || '?') + this._shape.toString();
  364. }
  365. };
  366. dl4j.TensorShape = class {
  367. constructor(dimensions) {
  368. this._dimensions = dimensions;
  369. }
  370. get dimensions() {
  371. return this._dimensions;
  372. }
  373. toString() {
  374. if (this._dimensions) {
  375. if (this._dimensions.length == 0) {
  376. return '';
  377. }
  378. return '[' + this._dimensions.map((dimension) => dimension.toString()).join(',') + ']';
  379. }
  380. return '';
  381. }
  382. };
  383. dl4j.Metadata = class {
  384. static open(context) {
  385. dl4j.Metadata.textDecoder = dl4j.Metadata.textDecoder || new TextDecoder('utf-8');
  386. if (dl4j.Metadata._metadata) {
  387. return Promise.resolve(dl4j.Metadata._metadata);
  388. }
  389. return context.request('dl4j-metadata.json', 'utf-8', null).then((data) => {
  390. dl4j.Metadata._metadata = new dl4j.Metadata(data);
  391. return dl4j.Metadata._metadata;
  392. }).catch(() => {
  393. dl4j.Metadata._metadata = new dl4j.Metadata(null);
  394. return dl4j.Metadata._metadata;
  395. });
  396. }
  397. constructor(data) {
  398. this._map = new Map();
  399. this._attributeCache = {};
  400. if (data) {
  401. const metadata = JSON.parse(data);
  402. this._map = new Map(metadata.map((item) => [ item.name, item ]));
  403. }
  404. }
  405. type(name) {
  406. return this._map.get(name);
  407. }
  408. attribute(type, name) {
  409. let map = this._attributeCache[type];
  410. if (!map) {
  411. map = {};
  412. const schema = this.type(type);
  413. if (schema && schema.attributes && schema.attributes.length > 0) {
  414. for (const attribute of schema.attributes) {
  415. map[attribute.name] = attribute;
  416. }
  417. }
  418. this._attributeCache[type] = map;
  419. }
  420. return map[name] || null;
  421. }
  422. };
  423. dl4j.NDArrayReader = class {
  424. constructor(buffer) {
  425. const reader = new dl4j.BinaryReader(buffer);
  426. /* let shape = */ dl4j.NDArrayReader._header(reader);
  427. const data = dl4j.NDArrayReader._header(reader);
  428. this._dataType = data.type;
  429. }
  430. get dataType() {
  431. return this._dataType;
  432. }
  433. static _header(reader) {
  434. const header = {};
  435. header.alloc = reader.string();
  436. header.length = 0;
  437. switch (header.alloc) {
  438. case 'DIRECT':
  439. case 'HEAP':
  440. case 'JAVACPP':
  441. header.length = reader.int32();
  442. break;
  443. case 'LONG_SHAPE':
  444. case 'MIXED_DATA_TYPES':
  445. header.length = reader.int64();
  446. break;
  447. }
  448. header.type = reader.string();
  449. switch (header.type) {
  450. case 'INT':
  451. header.type = 'int32';
  452. header.itemsize = 4;
  453. break;
  454. case 'FLOAT':
  455. header.type = 'float32';
  456. header.itemsize = 4;
  457. break;
  458. }
  459. header.data = reader.bytes(header.itemsize * header.length);
  460. return header;
  461. }
  462. };
  463. dl4j.BinaryReader = class {
  464. constructor(buffer) {
  465. this._buffer = buffer;
  466. this._position = 0;
  467. this._view = new DataView(buffer.buffer, buffer.byteOffset, buffer.byteLength);
  468. }
  469. bytes(size) {
  470. const data = this._buffer.subarray(this._position, this._position + size);
  471. this._position += size;
  472. return data;
  473. }
  474. string() {
  475. const size = this._buffer[this._position++] << 8 | this._buffer[this._position++];
  476. const buffer = this.bytes(size);
  477. return new TextDecoder('ascii').decode(buffer);
  478. }
  479. int32() {
  480. const position = this._position;
  481. this._position += 4;
  482. return this._view.getInt32(position, false);
  483. }
  484. int64() {
  485. const position = this._position;
  486. this._position += 4;
  487. return this._view.getInt64(position, false).toNumber();
  488. }
  489. };
  490. dl4j.Error = class extends Error {
  491. constructor(message) {
  492. super(message);
  493. this.name = 'Error loading Deeplearning4j model.';
  494. }
  495. };
  496. if (typeof module !== 'undefined' && typeof module.exports === 'object') {
  497. module.exports.ModelFactory = dl4j.ModelFactory;
  498. }