chainer.js 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596
  1. /* jshint esversion: 6 */
  2. /* eslint "indent": [ "error", 4, { "SwitchCase": 1 } ] */
  3. // Experimental
  4. var chainer = chainer || {};
  5. var long = long || { Long: require('long') };
  6. chainer.ModelFactory = class {
  7. match(context) {
  8. const identifier = context.identifier;
  9. const extension = identifier.split('.').pop().toLowerCase();
  10. if (extension === 'npz') {
  11. const entries = context.entries('zip');
  12. return entries.length > 0 && entries.every((entry) => entry.name.indexOf('/') !== -1);
  13. }
  14. if (extension === 'h5' || extension === 'hd5' || extension === 'hdf5' || extension === 'keras' || extension === 'model') {
  15. const buffer = context.buffer;
  16. const signature = [ 0x89, 0x48, 0x44, 0x46, 0x0D, 0x0A, 0x1A, 0x0A ];
  17. if (buffer && buffer.length > signature.length && signature.every((v, i) => v === buffer[i])) {
  18. return true;
  19. }
  20. }
  21. return false;
  22. }
  23. open(context, host) {
  24. const extension = context.identifier.split('.').pop().toLowerCase();
  25. switch (extension) {
  26. case 'npz':
  27. return this._openNumPy(context, host);
  28. case 'h5':
  29. case 'hd5':
  30. case 'hdf5':
  31. return this._openHdf5(context, host);
  32. }
  33. }
  34. _openNumPy(context, host) {
  35. const identifier = context.identifier;
  36. return host.require('./numpy').then((numpy) => {
  37. return host.require('./pickle').then((pickle) => {
  38. try {
  39. let modules = [];
  40. let modulesMap = new Map();
  41. let functionTable = new Map();
  42. let constructorTable = new Map();
  43. functionTable.set('_codecs.encode', function(obj /*, econding */) {
  44. return obj;
  45. });
  46. constructorTable.set('numpy.core.multiarray._reconstruct', function(subtype, shape, dtype) {
  47. this.subtype = subtype;
  48. this.shape = shape;
  49. this.dtype = dtype;
  50. this.__setstate__ = function(state) {
  51. this.version = state[0];
  52. this.shape = state[1];
  53. this.typecode = state[2];
  54. this.is_f_order = state[3];
  55. this.rawdata = state[4];
  56. };
  57. this.__read__ = function(unpickler) {
  58. let array = {};
  59. array.__type__ = this.subtype;
  60. array.dtype = this.typecode;
  61. array.shape = this.shape;
  62. let size = array.dtype.itemsize;
  63. for (let i = 0; i < array.shape.length; i++) {
  64. size = size * array.shape[i];
  65. }
  66. if (typeof this.rawdata == 'string') {
  67. array.data = unpickler.unescape(this.rawdata, size);
  68. if (array.data.length != size) {
  69. throw new chainer.Error('Invalid string array data size.');
  70. }
  71. }
  72. else {
  73. array.data = this.rawdata;
  74. if (array.data.length != size) {
  75. // TODO
  76. // throw new chainer.Error('Invalid array data size.');
  77. }
  78. }
  79. return array;
  80. };
  81. });
  82. constructorTable.set('numpy.dtype', function(obj, align, copy) {
  83. switch (obj) {
  84. case 'i1': this.name = 'int8'; this.itemsize = 1; break;
  85. case 'i2': this.name = 'int16'; this.itemsize = 2; break;
  86. case 'i4': this.name = 'int32'; this.itemsize = 4; break;
  87. case 'i8': this.name = 'int64'; this.itemsize = 8; break;
  88. case 'u1': this.name = 'uint8'; this.itemsize = 1; break;
  89. case 'u2': this.name = 'uint16'; this.itemsize = 2; break;
  90. case 'u4': this.name = 'uint32'; this.itemsize = 4; break;
  91. case 'u8': this.name = 'uint64'; this.itemsize = 8; break;
  92. case 'f4': this.name = 'float32'; this.itemsize = 4; break;
  93. case 'f8': this.name = 'float64'; this.itemsize = 8; break;
  94. default:
  95. if (obj.startsWith('V')) {
  96. this.itemsize = Number(obj.substring(1));
  97. this.name = 'void' + (this.itemsize * 8).toString();
  98. }
  99. else if (obj.startsWith('O')) {
  100. this.itemsize = Number(obj.substring(1));
  101. this.name = 'object';
  102. }
  103. else if (obj.startsWith('S')) {
  104. this.itemsize = Number(obj.substring(1));
  105. this.name = 'string';
  106. }
  107. else if (obj.startsWith('U')) {
  108. this.itemsize = Number(obj.substring(1));
  109. this.name = 'string';
  110. }
  111. else if (obj.startsWith('M')) {
  112. this.itemsize = Number(obj.substring(1));
  113. this.name = 'datetime';
  114. }
  115. else {
  116. throw new chainer.Error("Unknown dtype '" + obj.toString() + "'.");
  117. }
  118. break;
  119. }
  120. this.align = align;
  121. this.copy = copy;
  122. this.__setstate__ = function(state) {
  123. switch (state.length) {
  124. case 8:
  125. this.version = state[0];
  126. this.byteorder = state[1];
  127. this.subarray = state[2];
  128. this.names = state[3];
  129. this.fields = state[4];
  130. this.elsize = state[5];
  131. this.alignment = state[6];
  132. this.int_dtypeflags = state[7];
  133. break;
  134. default:
  135. throw new chainer.Error("Unknown numpy.dtype setstate length '" + state.length.toString() + "'.");
  136. }
  137. };
  138. });
  139. let function_call = (name, args) => {
  140. if (functionTable.has(name)) {
  141. const func = functionTable.get(name);
  142. return func.apply(null, args);
  143. }
  144. let obj = { __type__: name };
  145. if (constructorTable.has(name)) {
  146. const constructor = constructorTable.get(name);
  147. constructor.apply(obj, args);
  148. }
  149. else {
  150. throw new chainer.Error("Unknown function '" + name + "'.");
  151. }
  152. return obj;
  153. };
  154. const dataTypeMap = new Map([
  155. [ 'i1', 'int8'], [ 'i2', 'int16' ], [ 'i4', 'int32'], [ 'i8', 'int64' ],
  156. [ 'u1', 'uint8'], [ 'u2', 'uint16' ], [ 'u4', 'uint32'], [ 'u8', 'uint64' ],
  157. [ 'f2', 'float16'], [ 'f4', 'float32' ], [ 'f8', 'float64']
  158. ]);
  159. for (const entry of context.entries('zip')) {
  160. if (!entry.name.endsWith('.npy')) {
  161. throw new chainer.Error("Invalid file name '" + entry.name + "'.");
  162. }
  163. const id = entry.name.replace(/\.npy$/, '');
  164. const parts = id.split('/');
  165. if (parts.length < 2) {
  166. throw new chainer.Error("Invalid parameter name '" + entry.name + "'.");
  167. }
  168. const parameterName = parts.pop();
  169. const moduleName = parts.join('/');
  170. if (!modulesMap.has(moduleName)) {
  171. const newModule = { name: moduleName, parameters: [] };
  172. modules.push(newModule);
  173. modulesMap.set(moduleName, newModule);
  174. }
  175. const module = modulesMap.get(moduleName);
  176. let array = new numpy.Array(entry.data);
  177. if (array.byteOrder === '|') {
  178. if (array.dataType !== 'O') {
  179. throw new chainer.Error("Invalid data type '" + array.dataType + "'.");
  180. }
  181. const unpickler = new pickle.Unpickler(array.data);
  182. const root = unpickler.load(function_call);
  183. array = { dataType: root.dtype.name, shape: null, data: null, byteOrder: '|' };
  184. }
  185. module.parameters.push({
  186. name: parameterName,
  187. dataType: dataTypeMap.has(array.dataType) ? dataTypeMap.get(array.dataType) : array.dataType,
  188. shape: array.shape,
  189. data: array.data,
  190. byteOrder: array.byteOrder
  191. });
  192. }
  193. return new chainer.Model(modules, 'Chainer NumPy');
  194. }
  195. catch (error) {
  196. const message = error && error.message ? error.message : error.toString();
  197. throw new chainer.Error(message.replace(/\.$/, '') + " in '" + identifier + "'.");
  198. }
  199. });
  200. });
  201. }
  202. _openHdf5(context, host) {
  203. const identifier = context.identifier;
  204. return host.require('./hdf5').then((hdf5) => {
  205. try {
  206. const file = new hdf5.File(context.buffer);
  207. const rootGroup = file.rootGroup;
  208. if (Object.keys(rootGroup.attributes).length !== 0 || rootGroup.value !== null) {
  209. throw new chainer.Error('File format is not Chainer HDF5');
  210. }
  211. let modules = [];
  212. let modulesMap = new Map();
  213. for (const moduleGroup of rootGroup.groups) {
  214. if (Object.keys(moduleGroup.attributes).length !== 0 || moduleGroup.value !== null) {
  215. throw new chainer.Error('Module group format is not Chainer HDF5');
  216. }
  217. const moduleName = moduleGroup.name;
  218. if (!modulesMap.has(moduleName)) {
  219. const newModule = { name: moduleName, parameters: [] };
  220. modulesMap.set(moduleName, newModule);
  221. modules.push(newModule);
  222. }
  223. const module = modulesMap.get(moduleName);
  224. for (const variableGroup of moduleGroup.groups) {
  225. if (Object.keys(variableGroup.attributes).length !== 0 || variableGroup.groups.length !== 0) {
  226. throw new chainer.Error('Variable format is not Chainer HDF5');
  227. }
  228. const variable = variableGroup.value;
  229. if (!variable) {
  230. throw new chainer.Error('Variable value is not Chainer HDF5');
  231. }
  232. module.parameters.push({
  233. name: variableGroup.name,
  234. dataType: variable.type,
  235. byteOrder: variable.littleEndian ? '<' : '>',
  236. shape: variable.shape,
  237. data: variable.data
  238. });
  239. }
  240. }
  241. return new chainer.Model(modules, 'Chainer HDF5');
  242. }
  243. catch (error) {
  244. const message = error && error.message ? error.message : error.toString();
  245. throw new chainer.Error(message.replace(/\.$/, '') + " in '" + identifier + "'.");
  246. }
  247. });
  248. }
  249. }
  250. chainer.Model = class {
  251. constructor(modules, format) {
  252. this._format = format;
  253. this._graphs = [];
  254. this._graphs.push(new chainer.Graph(modules));
  255. }
  256. get format() {
  257. return this._format;
  258. }
  259. get graphs() {
  260. return this._graphs;
  261. }
  262. }
  263. chainer.Graph = class {
  264. constructor(modules) {
  265. this._nodes = [];
  266. for (const module of modules) {
  267. this._nodes.push(new chainer.Node(module));
  268. }
  269. }
  270. get inputs() {
  271. return [];
  272. }
  273. get outputs() {
  274. return [];
  275. }
  276. get nodes() {
  277. return this._nodes;
  278. }
  279. }
  280. chainer.Parameter = class {
  281. constructor(name, args) {
  282. this._name = name;
  283. this._arguments = args;
  284. }
  285. get name() {
  286. return this._name;
  287. }
  288. get visible() {
  289. return true;
  290. }
  291. get arguments() {
  292. return this._arguments;
  293. }
  294. }
  295. chainer.Argument = class {
  296. constructor(id, initializer) {
  297. if (typeof id !== 'string') {
  298. throw new chainer.Error("Invalid argument identifier '" + JSON.stringify(id) + "'.");
  299. }
  300. this._id = id;
  301. this._initializer = initializer || null;
  302. }
  303. get id() {
  304. return this._id;
  305. }
  306. get type() {
  307. return this._initializer.type;
  308. }
  309. get initializer() {
  310. return this._initializer;
  311. }
  312. }
  313. chainer.Node = class {
  314. constructor(module) {
  315. this._name = module.name;
  316. this._inputs = [];
  317. for (const parameter of module.parameters) {
  318. const name = [ this._name, parameter.name ].join('/');
  319. const initializer = new chainer.Tensor(name, parameter.dataType, parameter.shape, parameter.data, parameter.byteOrder);
  320. this._inputs.push(new chainer.Parameter(parameter.name, [
  321. new chainer.Argument(name, initializer)
  322. ]));
  323. }
  324. }
  325. get operator() {
  326. return 'Module';
  327. }
  328. get name() {
  329. return this._name;
  330. }
  331. get metadata() {
  332. return null;
  333. }
  334. get inputs() {
  335. return this._inputs;
  336. }
  337. get outputs() {
  338. return [];
  339. }
  340. get attributes() {
  341. return [];
  342. }
  343. }
  344. chainer.Tensor = class {
  345. constructor(name, dataType, shape, data, byteOrder) {
  346. this._name = name;
  347. this._type = new chainer.TensorType(dataType, new chainer.TensorShape(shape));
  348. this._shape = shape;
  349. this._data = data;
  350. this._byteOrder = byteOrder;
  351. }
  352. get kind() {
  353. return '';
  354. }
  355. get name() {
  356. return this._name;
  357. }
  358. get type(){
  359. return this._type;
  360. }
  361. get state() {
  362. return this._context().state;
  363. }
  364. get value() {
  365. let context = this._context();
  366. if (context.state) {
  367. return null;
  368. }
  369. context.limit = Number.MAX_SAFE_INTEGER;
  370. return this._decode(context, 0);
  371. }
  372. toString() {
  373. let context = this._context();
  374. if (context.state) {
  375. return '';
  376. }
  377. context.limit = 10000;
  378. let value = this._decode(context, 0);
  379. return chainer.Tensor._stringify(value, '', ' ');
  380. }
  381. _context() {
  382. let context = {};
  383. context.index = 0;
  384. context.count = 0;
  385. context.state = null;
  386. if (this._byteOrder !== '<' && this._byteOrder !== '>') {
  387. context.state = 'Tensor byte order is not supported.';
  388. return context;
  389. }
  390. if (this._reference) {
  391. context.state = 'Tensor reference not implemented.';
  392. return context;
  393. }
  394. if (!this._data || this._data.length == 0) {
  395. context.state = 'Tensor data is empty.';
  396. return context;
  397. }
  398. switch (this._type.dataType) {
  399. case 'float16':
  400. context.itemSize = 2;
  401. break;
  402. case 'float32':
  403. context.itemSize = 4;
  404. break;
  405. case 'float64':
  406. context.itemSize = 8;
  407. break;
  408. case 'int64':
  409. context.itemSize = 8;
  410. break;
  411. }
  412. if (!context.itemSize) {
  413. context.state = 'Tensor data type is not supported.';
  414. return context;
  415. }
  416. context.dimensions = this._type.shape.dimensions;
  417. context.dataType = this._type.dataType;
  418. context.littleEndian = this._byteOrder == '<';
  419. context.data = this._data;
  420. context.rawData = new DataView(this._data.buffer, this._data.byteOffset, this._data.byteLength);
  421. return context;
  422. }
  423. _decode(context, dimension) {
  424. const littleEndian = context.littleEndian;
  425. let shape = context.dimensions;
  426. if (shape.length == 0) {
  427. shape = [ 1 ];
  428. }
  429. let results = [];
  430. let size = shape[dimension];
  431. if (dimension == shape.length - 1) {
  432. for (let i = 0; i < size; i++) {
  433. if (context.count > context.limit) {
  434. results.push('...');
  435. return results;
  436. }
  437. if (context.rawData) {
  438. switch (context.dataType) {
  439. case 'float16':
  440. results.push(context.rawData.getFloat16(context.index, littleEndian));
  441. break;
  442. case 'float32':
  443. results.push(context.rawData.getFloat32(context.index, littleEndian));
  444. break;
  445. case 'float64':
  446. results.push(context.rawData.getFloat64(context.index, littleEndian));
  447. break;
  448. case 'int64':
  449. results.push(long.Long.fromBytes(context.data.subarray(context.index, context.index + 8), true, littleEndian));
  450. break;
  451. }
  452. context.index += context.itemSize;
  453. context.count++;
  454. }
  455. }
  456. }
  457. else {
  458. for (let j = 0; j < size; j++) {
  459. if (context.count > context.limit) {
  460. results.push('...');
  461. return results;
  462. }
  463. results.push(this._decode(context, dimension + 1));
  464. }
  465. }
  466. if (context.dimensions.length == 0) {
  467. return results[0];
  468. }
  469. return results;
  470. }
  471. static _stringify(value, indentation, indent) {
  472. if (Array.isArray(value)) {
  473. let result = [];
  474. result.push(indentation + '[');
  475. const items = value.map((item) => chainer.Tensor._stringify(item, indentation + indent, indent));
  476. if (items.length > 0) {
  477. result.push(items.join(',\n'));
  478. }
  479. result.push(indentation + ']');
  480. return result.join('\n');
  481. }
  482. if (typeof value == 'string') {
  483. return indentation + value;
  484. }
  485. if (value == Infinity) {
  486. return indentation + 'Infinity';
  487. }
  488. if (value == -Infinity) {
  489. return indentation + '-Infinity';
  490. }
  491. if (isNaN(value)) {
  492. return indentation + 'NaN';
  493. }
  494. return indentation + value.toString();
  495. }
  496. }
  497. chainer.TensorType = class {
  498. constructor(dataType, shape) {
  499. this._dataType = dataType;
  500. this._shape = shape;
  501. }
  502. get dataType() {
  503. return this._dataType || '?';
  504. }
  505. get shape() {
  506. return this._shape;
  507. }
  508. toString() {
  509. return this.dataType + this._shape.toString();
  510. }
  511. }
  512. chainer.TensorShape = class {
  513. constructor(dimensions) {
  514. this._dimensions = dimensions;
  515. }
  516. get dimensions() {
  517. return this._dimensions;
  518. }
  519. toString() {
  520. if (!this._dimensions || this._dimensions.length == 0) {
  521. return '';
  522. }
  523. return '[' + this._dimensions.join(',') + ']';
  524. }
  525. }
  526. chainer.Error = class extends Error {
  527. constructor(message) {
  528. super(message);
  529. this.name = 'Error loading Chainer model.';
  530. }
  531. }
  532. if (typeof module !== 'undefined' && typeof module.exports === 'object') {
  533. module.exports.ModelFactory = chainer.ModelFactory;
  534. }