dl4j.js 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541
// Experimental
// `dl4j` is the namespace object that every class in this file is attached to.
var dl4j = {};
// JSON text reader; dl4j.ModelFactory.open uses it to parse a sibling configuration.json.
var json = require('./json');
  4. dl4j.ModelFactory = class {
  5. match(context) {
  6. const identifier = context.identifier;
  7. if (identifier === 'configuration.json') {
  8. const obj = context.open('json');
  9. if (obj && (obj.confs || obj.vertices)) {
  10. return 'dl4j.configuration';
  11. }
  12. }
  13. if (identifier === 'coefficients.bin') {
  14. const signature = [ 0x00, 0x07, 0x4A, 0x41, 0x56, 0x41, 0x43, 0x50, 0x50 ];
  15. const stream = context.stream;
  16. if (signature.length <= stream.length && stream.peek(signature.length).every((value, index) => value === signature[index])) {
  17. return 'dl4j.coefficients';
  18. }
  19. }
  20. return undefined;
  21. }
  22. open(context, match) {
  23. return context.metadata('dl4j-metadata.json').then((metadata) => {
  24. switch (match) {
  25. case 'dl4j.configuration': {
  26. const obj = context.open('json');
  27. return context.request('coefficients.bin', null).then((stream) => {
  28. return new dl4j.Model(metadata, obj, stream.peek());
  29. }).catch(() => {
  30. return new dl4j.Model(metadata, obj, null);
  31. });
  32. }
  33. case 'dl4j.coefficients': {
  34. return context.request('configuration.json', null).then((stream) => {
  35. const reader = json.TextReader.open(stream);
  36. const obj = reader.read();
  37. return new dl4j.Model(metadata, obj, context.stream.peek());
  38. });
  39. }
  40. default: {
  41. throw new dl4j.Error("Unsupported Deeplearning4j format '" + match + "'.");
  42. }
  43. }
  44. });
  45. }
  46. };
  47. dl4j.Model = class {
  48. constructor(metadata, configuration, coefficients) {
  49. this._graphs = [];
  50. this._graphs.push(new dl4j.Graph(metadata, configuration, coefficients));
  51. }
  52. get format() {
  53. return 'Deeplearning4j';
  54. }
  55. get graphs() {
  56. return this._graphs;
  57. }
  58. };
  59. dl4j.Graph = class {
  60. constructor(metadata, configuration, coefficients) {
  61. this._inputs = [];
  62. this._outputs =[];
  63. this._nodes = [];
  64. const dataType = coefficients ? new dl4j.NDArrayReader(coefficients).dataType : '?';
  65. if (configuration.networkInputs) {
  66. for (const input of configuration.networkInputs) {
  67. this._inputs.push(new dl4j.Parameter(input, true, [
  68. new dl4j.Argument(input, null, null)
  69. ]));
  70. }
  71. }
  72. if (configuration.networkOutputs) {
  73. for (const output of configuration.networkOutputs) {
  74. this._outputs.push(new dl4j.Parameter(output, true, [
  75. new dl4j.Argument(output, null, null)
  76. ]));
  77. }
  78. }
  79. let inputs = null;
  80. // Computation Graph
  81. if (configuration.vertices) {
  82. for (const name in configuration.vertices) {
  83. const vertex = dl4j.Node._object(configuration.vertices[name]);
  84. inputs = configuration.vertexInputs[name];
  85. let variables = [];
  86. let layer = null;
  87. switch (vertex.__type__) {
  88. case 'LayerVertex':
  89. layer = dl4j.Node._object(vertex.layerConf.layer);
  90. variables = vertex.layerConf.variables;
  91. break;
  92. case 'MergeVertex':
  93. layer = { __type__: 'Merge', layerName: name };
  94. break;
  95. case 'ElementWiseVertex':
  96. layer = { __type__: 'ElementWise', layerName: name, op: vertex.op };
  97. break;
  98. case 'PreprocessorVertex':
  99. layer = { __type__: 'Preprocessor', layerName: name };
  100. break;
  101. default:
  102. throw new dl4j.Error("Unsupported vertex class '" + vertex['@class'] + "'.");
  103. }
  104. this._nodes.push(new dl4j.Node(metadata, layer, inputs, dataType, variables));
  105. }
  106. }
  107. // Multi Layer Network
  108. if (configuration.confs) {
  109. inputs = [ 'input' ];
  110. this._inputs.push(new dl4j.Parameter('input', true, [
  111. new dl4j.Argument('input', null, null)
  112. ]));
  113. for (const conf of configuration.confs) {
  114. const layer = dl4j.Node._object(conf.layer);
  115. this._nodes.push(new dl4j.Node(metadata, layer, inputs, dataType, conf.variables));
  116. inputs = [ layer.layerName ];
  117. }
  118. this._outputs.push(new dl4j.Parameter('output', true, [
  119. new dl4j.Argument(inputs[0], null, null)
  120. ]));
  121. }
  122. }
  123. get inputs() {
  124. return this._inputs;
  125. }
  126. get outputs() {
  127. return this._outputs;
  128. }
  129. get nodes() {
  130. return this._nodes;
  131. }
  132. };
  133. dl4j.Parameter = class {
  134. constructor(name, visible, args) {
  135. this._name = name;
  136. this._visible = visible;
  137. this._arguments = args;
  138. }
  139. get name() {
  140. return this._name;
  141. }
  142. get visible() {
  143. return this._visible;
  144. }
  145. get arguments() {
  146. return this._arguments;
  147. }
  148. };
  149. dl4j.Argument = class {
  150. constructor(name, type, initializer) {
  151. if (typeof name !== 'string') {
  152. throw new dl4j.Error("Invalid argument identifier '" + JSON.stringify(name) + "'.");
  153. }
  154. this._name = name;
  155. this._type = type;
  156. this._initializer = initializer;
  157. }
  158. get name() {
  159. return this._name;
  160. }
  161. get type() {
  162. if (this._initializer) {
  163. return this._initializer.type;
  164. }
  165. return this._type;
  166. }
  167. get initializer() {
  168. return this._initializer;
  169. }
  170. };
  171. dl4j.Node = class {
  172. constructor(metadata, layer, inputs, dataType, variables) {
  173. this._name = layer.layerName || '';
  174. this._inputs = [];
  175. this._outputs = [];
  176. this._attributes = [];
  177. const type = layer.__type__;
  178. this._type = metadata.type(type) || { name: type };
  179. if (inputs && inputs.length > 0) {
  180. const args = inputs.map((input) => new dl4j.Argument(input, null, null));
  181. this._inputs.push(new dl4j.Parameter(args.length < 2 ? 'input' : 'inputs', true, args));
  182. }
  183. if (variables) {
  184. for (const variable of variables) {
  185. let tensor = null;
  186. switch (type) {
  187. case 'Convolution':
  188. switch (variable) {
  189. case 'W':
  190. tensor = new dl4j.Tensor(dataType, layer.kernelSize.concat([ layer.nin, layer.nout ]));
  191. break;
  192. case 'b':
  193. tensor = new dl4j.Tensor(dataType, [ layer.nout ]);
  194. break;
  195. default:
  196. throw new dl4j.Error("Unsupported '" + this._type + "' variable '" + variable + "'.");
  197. }
  198. break;
  199. case 'SeparableConvolution2D':
  200. switch (variable) {
  201. case 'W':
  202. tensor = new dl4j.Tensor(dataType, layer.kernelSize.concat([ layer.nin, layer.nout ]));
  203. break;
  204. case 'pW':
  205. tensor = new dl4j.Tensor(dataType, [ layer.nout ]);
  206. break;
  207. default:
  208. throw new dl4j.Error("Unsupported '" + this._type + "' variable '" + variable + "'.");
  209. }
  210. break;
  211. case 'Output':
  212. case 'Dense':
  213. switch (variable) {
  214. case 'W':
  215. tensor = new dl4j.Tensor(dataType, [ layer.nout, layer.nin ]);
  216. break;
  217. case 'b':
  218. tensor = new dl4j.Tensor(dataType, [ layer.nout ]);
  219. break;
  220. default:
  221. throw new dl4j.Error("Unsupported '" + this._type + "' variable '" + variable + "'.");
  222. }
  223. break;
  224. case 'BatchNormalization':
  225. tensor = new dl4j.Tensor(dataType, [ layer.nin ]);
  226. break;
  227. default:
  228. throw new dl4j.Error("Unsupported '" + this._type + "' variable '" + variable + "'.");
  229. }
  230. this._inputs.push(new dl4j.Parameter(variable, true, [
  231. new dl4j.Argument(variable, null, tensor)
  232. ]));
  233. }
  234. }
  235. if (this._name) {
  236. this._outputs.push(new dl4j.Parameter('output', true, [
  237. new dl4j.Argument(this._name, null, null)
  238. ]));
  239. }
  240. let attributes = layer;
  241. if (layer.activationFn) {
  242. const activation = dl4j.Node._object(layer.activationFn);
  243. if (activation.__type__ !== 'ActivationIdentity' && activation.__type__ !== 'Identity') {
  244. if (activation.__type__.startsWith('Activation')) {
  245. activation.__type__ = activation.__type__.substring('Activation'.length);
  246. }
  247. if (this._type == 'Activation') {
  248. this._type = activation.__type__;
  249. attributes = activation;
  250. }
  251. else {
  252. this._chain = this._chain || [];
  253. this._chain.push(new dl4j.Node(metadata, activation, [], null, null));
  254. }
  255. }
  256. }
  257. for (const key in attributes) {
  258. switch (key) {
  259. case '__type__':
  260. case 'constraints':
  261. case 'layerName':
  262. case 'activationFn':
  263. case 'idropout':
  264. case 'hasBias':
  265. continue;
  266. default:
  267. break;
  268. }
  269. this._attributes.push(new dl4j.Attribute(metadata.attribute(type, key), key, attributes[key]));
  270. }
  271. if (layer.idropout) {
  272. const dropout = dl4j.Node._object(layer.idropout);
  273. if (dropout.p !== 1.0) {
  274. throw new dl4j.Error("Layer 'idropout' not implemented.");
  275. }
  276. }
  277. }
  278. get type() {
  279. return this._type;
  280. }
  281. get name() {
  282. return this._name;
  283. }
  284. get inputs() {
  285. return this._inputs;
  286. }
  287. get outputs() {
  288. return this._outputs;
  289. }
  290. get attributes() {
  291. return this._attributes;
  292. }
  293. get chain() {
  294. return this._chain;
  295. }
  296. static _object(value) {
  297. let result = {};
  298. if (value['@class']) {
  299. result = value;
  300. let type = value['@class'].split('.').pop();
  301. if (type.endsWith('Layer')) {
  302. type = type.substring(0, type.length - 5);
  303. }
  304. delete value['@class'];
  305. result.__type__ = type;
  306. }
  307. else {
  308. let key = Object.keys(value)[0];
  309. result = value[key];
  310. if (key.length > 0) {
  311. key = key[0].toUpperCase() + key.substring(1);
  312. }
  313. result.__type__ = key;
  314. }
  315. return result;
  316. }
  317. };
  318. dl4j.Attribute = class {
  319. constructor(schema, name, value) {
  320. this._name = name;
  321. this._value = value;
  322. this._visible = false;
  323. if (schema) {
  324. if (schema.visible) {
  325. this._visible = true;
  326. }
  327. }
  328. }
  329. get name() {
  330. return this._name;
  331. }
  332. get type() {
  333. return this._type;
  334. }
  335. get value() {
  336. return this._value;
  337. }
  338. get visible() {
  339. return this._visible;
  340. }
  341. };
  342. dl4j.Tensor = class {
  343. constructor(dataType, shape) {
  344. this._type = new dl4j.TensorType(dataType, new dl4j.TensorShape(shape));
  345. }
  346. get type() {
  347. return this._type;
  348. }
  349. };
  350. dl4j.TensorType = class {
  351. constructor(dataType, shape) {
  352. this._dataType = dataType;
  353. this._shape = shape;
  354. }
  355. get dataType() {
  356. return this._dataType;
  357. }
  358. get shape() {
  359. return this._shape;
  360. }
  361. toString() {
  362. return (this.dataType || '?') + this._shape.toString();
  363. }
  364. };
  365. dl4j.TensorShape = class {
  366. constructor(dimensions) {
  367. this._dimensions = dimensions;
  368. }
  369. get dimensions() {
  370. return this._dimensions;
  371. }
  372. toString() {
  373. if (this._dimensions) {
  374. if (this._dimensions.length == 0) {
  375. return '';
  376. }
  377. return '[' + this._dimensions.map((dimension) => dimension.toString()).join(',') + ']';
  378. }
  379. return '';
  380. }
  381. };
  382. dl4j.NDArrayReader = class {
  383. constructor(buffer) {
  384. const reader = new dl4j.BinaryReader(buffer);
  385. /* let shape = */ dl4j.NDArrayReader._header(reader);
  386. const data = dl4j.NDArrayReader._header(reader);
  387. this._dataType = data.type;
  388. }
  389. get dataType() {
  390. return this._dataType;
  391. }
  392. static _header(reader) {
  393. const header = {};
  394. header.alloc = reader.string();
  395. header.length = 0;
  396. switch (header.alloc) {
  397. case 'DIRECT':
  398. case 'HEAP':
  399. case 'JAVACPP':
  400. header.length = reader.int32();
  401. break;
  402. case 'LONG_SHAPE':
  403. case 'MIXED_DATA_TYPES':
  404. header.length = reader.int64();
  405. break;
  406. default:
  407. throw new dl4j.Error("Unsupported header alloc '" + header.alloc + "'.");
  408. }
  409. header.type = reader.string();
  410. switch (header.type) {
  411. case 'INT':
  412. header.type = 'int32';
  413. header.itemsize = 4;
  414. break;
  415. case 'FLOAT':
  416. header.type = 'float32';
  417. header.itemsize = 4;
  418. break;
  419. default:
  420. throw new dl4j.Error("Unsupported header type '" + header.type + "'.");
  421. }
  422. header.data = reader.read(header.itemsize * header.length);
  423. return header;
  424. }
  425. };
  426. dl4j.BinaryReader = class {
  427. constructor(buffer) {
  428. this._buffer = buffer;
  429. this._position = 0;
  430. this._view = new DataView(buffer.buffer, buffer.byteOffset, buffer.byteLength);
  431. }
  432. read(size) {
  433. const data = this._buffer.subarray(this._position, this._position + size);
  434. this._position += size;
  435. return data;
  436. }
  437. string() {
  438. const size = this._buffer[this._position++] << 8 | this._buffer[this._position++];
  439. const buffer = this.read(size);
  440. this._decoder = this._decoder || new TextDecoder('ascii');
  441. return this._decoder.decode(buffer);
  442. }
  443. int32() {
  444. const position = this._position;
  445. this._position += 4;
  446. return this._view.getInt32(position, false);
  447. }
  448. int64() {
  449. const position = this._position;
  450. this._position += 4;
  451. return this._view.getInt64(position, false).toNumber();
  452. }
  453. };
  454. dl4j.Error = class extends Error {
  455. constructor(message) {
  456. super(message);
  457. this.name = 'Error loading Deeplearning4j model.';
  458. }
  459. };
  460. if (typeof module !== 'undefined' && typeof module.exports === 'object') {
  461. module.exports.ModelFactory = dl4j.ModelFactory;
  462. }