npz.js 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554
  /* jshint esversion: 6 */
  // Experimental
  // Reuse `npz` when it is already truthy; otherwise start with an empty
  // namespace object. `var` is deliberate: with `const`/`let` the right-hand
  // side would read `npz` inside its temporal dead zone and throw.
  var npz = npz ? npz : {};
  npz.ModelFactory = class {

      /**
       * Reports whether `context` looks like a NumPy zip archive.
       *
       * The identifier's last '.'-separated part must be 'npz' (any case).
       * The zip must hold at least one entry, and every entry must end in '.npy'.
       * @param {object} context - exposes `identifier` and `entries(format)`.
       * @returns {boolean}
       */
      match(context) {
          const identifier = context.identifier;
          const extension = identifier.split('.').pop().toLowerCase();
          if (extension === 'npz') {
              const entries = context.entries('zip');
              return entries.length > 0 && entries.every((entry) => entry.name.endsWith('.npy'));
          }
          return false;
      }

      /**
       * Loads the archive into an `npz.Model` with one module per directory.
       *
       * Each '.npy' entry becomes a parameter named after its last path segment.
       * The remaining path, joined with '/', names its module. For example,
       * 'conv1/W.npy' lands in module 'conv1', and a top-level 'arr_0.npy'
       * lands in the module with the empty name.
       *
       * Object arrays (byte order '|', type 'O') are unpickled only to recover
       * their dtype name; their shape and data are dropped.
       * @param {object} context - exposes `identifier` and `entries(format)`.
       * @param {object} host - exposes `require(id)`, which returns a Promise.
       * @returns {Promise<npz.Model>} Rejects with an `npz.Error` whose message
       *     ends with " in '<identifier>'.".
       */
      open(context, host) {
          return host.require('./numpy').then((numpy) => {
              return host.require('./pickle').then((pickle) => {
                  const identifier = context.identifier;
                  try {
                      const functionCall = npz.ModelFactory._createFunctionCall();
                      const dataTypeMap = new Map([
                          [ 'b1', 'boolean' ],
                          [ 'i1', 'int8' ], [ 'i2', 'int16' ], [ 'i4', 'int32' ], [ 'i8', 'int64' ],
                          [ 'u1', 'uint8' ], [ 'u2', 'uint16' ], [ 'u4', 'uint32' ], [ 'u8', 'uint64' ],
                          [ 'f2', 'float16' ], [ 'f4', 'float32' ], [ 'f8', 'float64' ]
                      ]);
                      // NumPy writes single-byte dtypes with the '|' ("not applicable")
                      // byte-order marker (e.g. '|u1'). They are plain numeric data,
                      // not pickles, so they must not be rejected below.
                      const singleByteTypes = new Set([ 'b1', 'i1', 'u1' ]);
                      const modules = [];
                      const modulesMap = new Map();
                      for (const entry of context.entries('zip')) {
                          if (!entry.name.endsWith('.npy')) {
                              throw new npz.Error("Invalid file name '" + entry.name + "'.");
                          }
                          const parts = entry.name.replace(/\.npy$/, '').split('/');
                          const parameterName = parts.pop();
                          // Group by the full directory path. Top-level entries share the
                          // unnamed module.
                          const moduleName = parts.join('/');
                          if (!modulesMap.has(moduleName)) {
                              const newModule = { name: moduleName, parameters: [] };
                              modules.push(newModule);
                              modulesMap.set(moduleName, newModule);
                          }
                          const target = modulesMap.get(moduleName);
                          let array = new numpy.Array(entry.data);
                          if (array.byteOrder === '|' && !singleByteTypes.has(array.dataType)) {
                              if (array.dataType !== 'O') {
                                  throw new npz.Error("Invalid data type '" + array.dataType + "'.");
                              }
                              // Object arrays are stored as a pickle. The unpickled root is
                              // expected to expose `dtype.name` (presumably the descriptor
                              // built by `_reconstruct.__read__`) — TODO confirm in pickle.js.
                              const unpickler = new pickle.Unpickler(array.data);
                              const root = unpickler.load(functionCall);
                              array = { dataType: root.dtype.name, shape: null, data: null, byteOrder: '|' };
                          }
                          target.parameters.push({
                              name: parameterName,
                              dataType: dataTypeMap.has(array.dataType) ? dataTypeMap.get(array.dataType) : array.dataType,
                              shape: array.shape,
                              data: array.data,
                              byteOrder: array.byteOrder
                          });
                      }
                      return new npz.Model(modules, 'NumPy Zip');
                  }
                  catch (error) {
                      // String(error) also copes with `null`/`undefined` being thrown.
                      const message = error && error.message ? error.message : String(error);
                      throw new npz.Error(message.replace(/\.$/, '') + " in '" + identifier + "'.");
                  }
              });
          });
      }

      /**
       * Builds the resolver the unpickler uses for global names ('module.name').
       *
       * Names in the function table are called directly. Names in the
       * constructor table are applied to a fresh `{ __type__: name }` object,
       * which is returned. Any other name throws `npz.Error`.
       * @returns {function(string, Array): *}
       */
      static _createFunctionCall() {
          const functionTable = new Map();
          const constructorTable = new Map();
          functionTable.set('_codecs.encode', function(obj /*, encoding */) {
              // Returns the first argument unchanged; the encoding argument is ignored.
              return obj;
          });
          constructorTable.set('numpy.core.multiarray._reconstruct', function(subtype, shape, dtype) {
              this.subtype = subtype;
              this.shape = shape;
              this.dtype = dtype;
              // Copies the entries of the pickled state tuple onto the object.
              this.__setstate__ = function(state) {
                  this.version = state[0];
                  this.shape = state[1];
                  this.typecode = state[2];
                  this.is_f_order = state[3];
                  this.rawdata = state[4];
              };
              // Converts the reconstructed object into a plain
              // { __type__, dtype, shape, data } descriptor.
              this.__read__ = function(unpickler) {
                  const array = {};
                  array.__type__ = this.subtype;
                  array.dtype = this.typecode;
                  array.shape = this.shape;
                  // Expected byte length: itemsize times the product of all dimensions.
                  let size = array.dtype.itemsize;
                  for (let i = 0; i < array.shape.length; i++) {
                      size = size * array.shape[i];
                  }
                  if (typeof this.rawdata === 'string') {
                      array.data = unpickler.unescape(this.rawdata, size);
                      if (array.data.length !== size) {
                          throw new npz.Error('Invalid string array data size.');
                      }
                  }
                  else {
                      // TODO: the size of non-string raw data is not validated
                      // (the check was disabled upstream).
                      array.data = this.rawdata;
                  }
                  return array;
              };
          });
          constructorTable.set('numpy.dtype', function(obj, align, copy) {
              switch (obj) {
                  case 'b1': this.name = 'boolean'; this.itemsize = 1; break;
                  case 'i1': this.name = 'int8'; this.itemsize = 1; break;
                  case 'i2': this.name = 'int16'; this.itemsize = 2; break;
                  case 'i4': this.name = 'int32'; this.itemsize = 4; break;
                  case 'i8': this.name = 'int64'; this.itemsize = 8; break;
                  case 'u1': this.name = 'uint8'; this.itemsize = 1; break;
                  case 'u2': this.name = 'uint16'; this.itemsize = 2; break;
                  case 'u4': this.name = 'uint32'; this.itemsize = 4; break;
                  case 'u8': this.name = 'uint64'; this.itemsize = 8; break;
                  case 'f2': this.name = 'float16'; this.itemsize = 2; break;
                  case 'f4': this.name = 'float32'; this.itemsize = 4; break;
                  case 'f8': this.name = 'float64'; this.itemsize = 8; break;
                  default:
                      // Variable-size codes: a letter followed by the item size in bytes.
                      if (obj.startsWith('V')) {
                          this.itemsize = Number(obj.substring(1));
                          this.name = 'void' + (this.itemsize * 8).toString();
                      }
                      else if (obj.startsWith('O')) {
                          this.itemsize = Number(obj.substring(1));
                          this.name = 'object';
                      }
                      else if (obj.startsWith('S')) {
                          this.itemsize = Number(obj.substring(1));
                          this.name = 'string';
                      }
                      else if (obj.startsWith('U')) {
                          this.itemsize = Number(obj.substring(1));
                          this.name = 'string';
                      }
                      else if (obj.startsWith('M')) {
                          this.itemsize = Number(obj.substring(1));
                          this.name = 'datetime';
                      }
                      else {
                          throw new npz.Error("Unknown dtype '" + obj.toString() + "'.");
                      }
                      break;
              }
              this.align = align;
              this.copy = copy;
              // Only the 8-element state tuple is understood.
              this.__setstate__ = function(state) {
                  switch (state.length) {
                      case 8:
                          this.version = state[0];
                          this.byteorder = state[1];
                          this.subarray = state[2];
                          this.names = state[3];
                          this.fields = state[4];
                          this.elsize = state[5];
                          this.alignment = state[6];
                          this.int_dtypeflags = state[7];
                          break;
                      default:
                          throw new npz.Error("Unknown numpy.dtype setstate length '" + state.length.toString() + "'.");
                  }
              };
          });
          return (name, args) => {
              if (functionTable.has(name)) {
                  return functionTable.get(name).apply(null, args);
              }
              if (!constructorTable.has(name)) {
                  throw new npz.Error("Unknown function '" + name + "'.");
              }
              const obj = { __type__: name };
              constructorTable.get(name).apply(obj, args);
              return obj;
          };
      }
  };
  npz.Model = class {

      /**
       * Top-level model object: a format label plus exactly one graph.
       * @param {Array<object>} modules - module descriptors handed to `npz.Graph`.
       * @param {string} format - display name of the file format.
       */
      constructor(modules, format) {
          const graph = new npz.Graph(modules);
          this._format = format;
          this._graphs = [ graph ];
      }

      /** @returns {string} The format label given to the constructor. */
      get format() {
          return this._format;
      }

      /** @returns {Array<npz.Graph>} A one-element list holding the model's graph. */
      get graphs() {
          return this._graphs;
      }
  };
  npz.Graph = class {

      /**
       * A graph with one `npz.Node` per module and no graph-level inputs or outputs.
       * @param {Iterable<object>} modules - module descriptors, in display order.
       */
      constructor(modules) {
          this._nodes = Array.from(modules, (item) => new npz.Node(item));
      }

      /** @returns {Array} Always a new empty list. */
      get inputs() {
          return [];
      }

      /** @returns {Array} Always a new empty list. */
      get outputs() {
          return [];
      }

      /** @returns {Array<npz.Node>} The nodes built in the constructor. */
      get nodes() {
          return this._nodes;
      }
  };
  npz.Parameter = class {

      /**
       * A named, always-visible input slot holding one or more arguments.
       * @param {string} name - slot name.
       * @param {Array<npz.Argument>} args - arguments bound to the slot.
       */
      constructor(name, args) {
          Object.assign(this, { _name: name, _arguments: args });
      }

      get name() { return this._name; }

      get visible() { return true; }

      get arguments() { return this._arguments; }
  };
  npz.Argument = class {

      /**
       * A named value, optionally backed by a constant tensor.
       * @param {string} name - argument identifier; must be a string.
       * @param {npz.Tensor} [initializer] - constant backing the value; falsy values become null.
       * @throws {npz.Error} When `name` is not a string.
       */
      constructor(name, initializer) {
          if (typeof name !== 'string') {
              throw new npz.Error("Invalid argument identifier '" + JSON.stringify(name) + "'.");
          }
          this._name = name;
          this._initializer = initializer || null;
      }

      get name() {
          return this._name;
      }

      /**
       * @returns {*} The initializer's type, or null when the argument has no initializer.
       */
      get type() {
          return this._initializer ? this._initializer.type : null;
      }

      get initializer() {
          return this._initializer;
      }
  };
  npz.Node = class {

      /**
       * One node per module.
       *
       * Each module parameter becomes an input whose single argument carries
       * the parameter's tensor as its initializer.
       * @param {object} module - descriptor with `name` and `parameters`, where each
       *     parameter is { name, dataType, shape, data, byteOrder }.
       */
      constructor(module) {
          this._name = module.name;
          this._inputs = Array.from(module.parameters, (parameter) => {
              // Qualify the tensor name with the module name unless the module is unnamed.
              const qualifiedName = this._name ? [ this._name, parameter.name ].join('/') : parameter.name;
              const tensor = new npz.Tensor(qualifiedName, parameter.dataType, parameter.shape, parameter.data, parameter.byteOrder);
              return new npz.Parameter(parameter.name, [ new npz.Argument(qualifiedName, tensor) ]);
          });
      }

      get type() { return 'Module'; }

      get name() { return this._name; }

      get metadata() { return null; }

      get inputs() { return this._inputs; }

      get outputs() { return []; }

      get attributes() { return []; }
  };
  npz.Tensor = class {

      /**
       * A constant tensor whose values are decoded lazily from raw bytes.
       * @param {string} name - tensor name.
       * @param {string} dataType - element type name such as 'float32' or 'uint8'.
       * @param {Array<number>|null} shape - dimensions; null when unknown.
       * @param {Uint8Array|null} data - raw element bytes (assumed to be a typed-array
       *     view — TODO confirm against numpy.Array).
       * @param {string} byteOrder - '<' little-endian, '>' big-endian, '|' no byte order.
       */
      constructor(name, dataType, shape, data, byteOrder) {
          this._name = name;
          this._type = new npz.TensorType(dataType, new npz.TensorShape(shape));
          this._shape = shape;
          this._data = data;
          this._byteOrder = byteOrder;
      }

      get kind() {
          return '';
      }

      get name() {
          return this._name;
      }

      get type() {
          return this._type;
      }

      /**
       * @returns {string|null} Why the values cannot be decoded, or null when they can.
       */
      get state() {
          return this._context().state;
      }

      /**
       * Decodes every element into nested arrays, or into a single value for a
       * 0-dimensional tensor.
       * @returns {*} The decoded values, or null when `state` is set.
       */
      get value() {
          const context = this._context();
          if (context.state) {
              return null;
          }
          context.limit = Number.MAX_SAFE_INTEGER;
          return this._decode(context, 0);
      }

      /**
       * Renders the values as indented text.
       *
       * Output stops after roughly 10000 elements and appends '...'.
       * Returns '' when `state` is set.
       */
      toString() {
          const context = this._context();
          if (context.state) {
              return '';
          }
          context.limit = 10000;
          const value = this._decode(context, 0);
          return npz.Tensor._stringify(value, '', ' ');
      }

      /**
       * Validates the tensor and prepares the decoding cursor.
       *
       * Checks run in this order: byte order, reference, empty data, data type,
       * data size. The first failure is recorded in `state`.
       */
      _context() {
          const context = {};
          context.index = 0;
          context.count = 0;
          context.state = null;
          const dataType = this._type.dataType;
          const itemSize = npz.Tensor._itemSize(dataType);
          // Single-byte elements have no byte order, so '|' is acceptable for them.
          const byteOrderFree = this._byteOrder === '|' && itemSize === 1;
          if (this._byteOrder !== '<' && this._byteOrder !== '>' && !byteOrderFree) {
              context.state = 'Tensor byte order is not supported.';
              return context;
          }
          // `_reference` is not assigned anywhere in this class.
          if (this._reference) {
              context.state = 'Tensor reference not implemented.';
              return context;
          }
          if (!this._data || this._data.length === 0) {
              context.state = 'Tensor data is empty.';
              return context;
          }
          if (itemSize === 0) {
              context.state = 'Tensor data type is not supported.';
              return context;
          }
          const dimensions = this._type.shape.dimensions;
          const elements = (dimensions || []).reduce((product, dimension) => product * dimension, 1);
          // Refuse to decode rather than let DataView throw a RangeError part-way through.
          if (this._data.byteLength < elements * itemSize) {
              context.state = 'Tensor data is too small.';
              return context;
          }
          context.itemSize = itemSize;
          context.dimensions = dimensions;
          context.dataType = dataType;
          context.littleEndian = this._byteOrder === '<';
          context.data = this._data;
          context.rawData = new DataView(this._data.buffer, this._data.byteOffset, this._data.byteLength);
          return context;
      }

      /**
       * Recursively decodes `dimension` and everything below it.
       *
       * Reads elements sequentially, with the last dimension varying fastest.
       * Once more than `context.limit` elements have been read, '...' is
       * appended and that level returns early.
       */
      _decode(context, dimension) {
          const shape = context.dimensions.length === 0 ? [ 1 ] : context.dimensions;
          const size = shape[dimension];
          const results = [];
          if (dimension === shape.length - 1) {
              for (let i = 0; i < size; i++) {
                  if (context.count > context.limit) {
                      results.push('...');
                      return results;
                  }
                  results.push(npz.Tensor._read(context));
                  context.index += context.itemSize;
                  context.count++;
              }
          }
          else {
              for (let i = 0; i < size; i++) {
                  if (context.count > context.limit) {
                      results.push('...');
                      return results;
                  }
                  results.push(this._decode(context, dimension + 1));
              }
          }
          // A 0-dimensional tensor is decoded as shape [1]; unwrap its single value.
          return context.dimensions.length === 0 ? results[0] : results;
      }

      /** Returns the element size in bytes for a supported data type, or 0. */
      static _itemSize(dataType) {
          switch (dataType) {
              case 'boolean':
              case 'int8':
              case 'uint8':
                  return 1;
              case 'float16':
              case 'int16':
              case 'uint16':
                  return 2;
              case 'float32':
              case 'int32':
              case 'uint32':
                  return 4;
              case 'float64':
              case 'int64':
                  return 8;
              default:
                  return 0;
          }
      }

      /** Reads the element at `context.index` using the tensor's byte order. */
      static _read(context) {
          const view = context.rawData;
          const offset = context.index;
          const littleEndian = context.littleEndian;
          switch (context.dataType) {
              case 'boolean': return view.getUint8(offset) !== 0;
              case 'int8': return view.getInt8(offset);
              case 'uint8': return view.getUint8(offset);
              case 'int16': return view.getInt16(offset, littleEndian);
              case 'uint16': return view.getUint16(offset, littleEndian);
              case 'int32': return view.getInt32(offset, littleEndian);
              case 'uint32': return view.getUint32(offset, littleEndian);
              // NOTE(review): `getFloat16` only exists in recent runtimes; older ones need a polyfill.
              case 'float16': return view.getFloat16(offset, littleEndian);
              case 'float32': return view.getFloat32(offset, littleEndian);
              case 'float64': return view.getFloat64(offset, littleEndian);
              // NOTE(review): `getInt64` is not a standard DataView method (the standard one
              // is `getBigInt64`); presumably it is polyfilled by the host — confirm.
              case 'int64': return view.getInt64(offset, littleEndian);
              default: throw new npz.Error("Unsupported tensor data type '" + context.dataType + "'.");
          }
      }

      /**
       * Formats nested arrays one element per line, indenting each level by `indent`.
       * Strings (such as the '...' marker) are emitted as-is.
       */
      static _stringify(value, indentation, indent) {
          if (Array.isArray(value)) {
              const result = [];
              result.push(indentation + '[');
              const items = value.map((item) => npz.Tensor._stringify(item, indentation + indent, indent));
              if (items.length > 0) {
                  result.push(items.join(',\n'));
              }
              result.push(indentation + ']');
              return result.join('\n');
          }
          if (typeof value === 'string') {
              return indentation + value;
          }
          if (value === Infinity) {
              return indentation + 'Infinity';
          }
          if (value === -Infinity) {
              return indentation + '-Infinity';
          }
          // The global (coercing) isNaN is kept on purpose: values may be non-number objects.
          if (isNaN(value)) {
              return indentation + 'NaN';
          }
          return indentation + value.toString();
      }
  };
  npz.TensorType = class {

      /**
       * Pairs an element type name with a shape.
       * @param {string} dataType - element type name; falsy values are reported as '?'.
       * @param {npz.TensorShape} shape - shape whose `toString()` is appended in `toString()`.
       */
      constructor(dataType, shape) {
          this._shape = shape;
          this._dataType = dataType;
      }

      get dataType() {
          return this._dataType ? this._dataType : '?';
      }

      get shape() {
          return this._shape;
      }

      /** @returns {string} The data type immediately followed by the shape text, e.g. 'float32[2,3]'. */
      toString() {
          const suffix = this._shape.toString();
          return this.dataType + suffix;
      }
  };
  npz.TensorShape = class {

      /**
       * @param {Array<number>|null} dimensions - dimension sizes; null when unknown.
       */
      constructor(dimensions) {
          this._dimensions = dimensions;
      }

      get dimensions() {
          return this._dimensions;
      }

      /** @returns {string} '[d0,d1,...]', or '' when there are no dimensions. */
      toString() {
          const dimensions = this._dimensions;
          const hasDimensions = Boolean(dimensions) && dimensions.length !== 0;
          return hasDimensions ? '[' + dimensions.join(',') + ']' : '';
      }
  };
  npz.Error = class extends Error {

      /**
       * Error raised while loading a NumPy zip (.npz) archive.
       * The name matches the 'NumPy Zip' format label used by the loader.
       * @param {string} message - description of the failure.
       */
      constructor(message) {
          super(message);
          this.name = 'Error loading NumPy Zip model.';
      }
  };
  // Publish the factory when a CommonJS-style `module.exports` object is present;
  // otherwise this is a no-op.
  if (!(typeof module === 'undefined' || typeof module.exports !== 'object')) {
      module.exports.ModelFactory = npz.ModelFactory;
  }