npz.js 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525
  1. /* jshint esversion: 6 */
  2. // Experimental
  3. var npz = npz || {};
  4. var python = python || require('./python');
  /**
   * Factory that recognizes and opens NumPy containers: ZIP archives whose
   * entries are all '.npy' files, and pickles whose unnamed root object is a
   * collection of numpy.ndarray values.
   */
  npz.ModelFactory = class {

      /**
       * Decides whether this factory can open the given context.
       *
       * @param {object} context - host context exposing entries(kind) and tags(kind).
       * @returns {boolean} true for a non-empty ZIP containing only '.npy' entries,
       *   or for a pickle with a single '' tag that npz.Utility.weights accepts.
       */
      match(context) {
          const entries = context.entries('zip');
          if (entries.length > 0 && entries.every((entry) => entry.name.endsWith('.npy'))) {
              return true;
          }
          const tags = context.tags('pkl');
          if (tags.size === 1 && tags.keys().next().value === '') {
              if (npz.Utility.weights(tags.values().next().value)) {
                  return true;
              }
          }
          return false;
      }

      /**
       * Loads the './numpy' module through the context and builds the model.
       *
       * @param {object} context - host context exposing require, entries and tags.
       * @returns {Promise} resolves to an npz.Model whose format is
       *   'NumPy Weights' (pickle) or 'NumPy Zip' (archive).
       * @throws {npz.Error} (as a rejection) for a non-'.npy' entry or an
       *   unsupported '|' data type.
       */
      open(context) {
          return context.require('./numpy').then((numpy) => {
              const tags = context.tags('pkl');
              const groups = new Map();
              let format = '';
              if (tags.size === 1) {
                  format = 'NumPy Weights';
                  npz.ModelFactory._readWeights(tags.values().next().value, groups);
              }
              else {
                  format = 'NumPy Zip';
                  npz.ModelFactory._readArchive(context.entries('zip'), numpy, groups);
              }
              return new npz.Model(format, groups.values());
          });
      }

      // Returns the group registered under 'name', creating it on first use.
      static _group(groups, name) {
          if (!groups.has(name)) {
              groups.set(name, { name: name, parameters: [] });
          }
          return groups.get(name);
      }

      // Fills 'groups' from a pickled weights collection.
      static _readWeights(root, groups) {
          const weights = npz.Utility.weights(root);
          const keys = Array.from(weights.keys());
          // Use '.' only when every key contains '.' and at least one key has no '_'.
          let separator = '_';
          if (keys.every((key) => key.indexOf('.') !== -1) &&
              !keys.every((key) => key.indexOf('_') !== -1)) {
              separator = '.';
          }
          for (const pair of weights) {
              const name = pair[0];
              const value = pair[1];
              const parts = name.split(separator);
              // A key without a separator becomes its own group with a '?' parameter.
              const parameterName = parts.length > 1 ? parts.pop() : '?';
              const group = npz.ModelFactory._group(groups, parts.join(separator));
              group.parameters.push({
                  name: parameterName,
                  tensor: {
                      name: name,
                      byteOrder: value.dtype.byteorder,
                      dataType: value.dtype.name,
                      shape: value.shape,
                      data: value.data
                  }
              });
          }
      }

      // Fills 'groups' from the '.npy' entries of a ZIP archive.
      static _readArchive(entries, numpy, groups) {
          const dataTypeMap = new Map([
              [ 'i1', 'int8' ], [ 'i2', 'int16' ], [ 'i4', 'int32' ], [ 'i8', 'int64' ],
              [ 'u1', 'uint8' ], [ 'u2', 'uint16' ], [ 'u4', 'uint32' ], [ 'u8', 'uint64' ],
              [ 'f2', 'float16' ], [ 'f4', 'float32' ], [ 'f8', 'float64' ]
          ]);
          const execution = new python.Execution(null);
          for (const entry of entries) {
              if (!entry.name.endsWith('.npy')) {
                  throw new npz.Error("Invalid file name '" + entry.name + "'.");
              }
              const name = entry.name.replace(/\.npy$/, '');
              const parts = name.split('/');
              const parameterName = parts.pop();
              const group = npz.ModelFactory._group(groups, parts.join('/'));
              let array = new numpy.Array(entry.data);
              let byteOrder = array.byteOrder;
              if (byteOrder === '|') {
                  if (array.dataType === 'O') {
                      // Object arrays carry a pickle; only the dtype name of the
                      // unpickled root is kept, shape and data are not decoded.
                      const unpickler = new python.Unpickler(array.data);
                      const root = unpickler.load((name, args) => execution.invoke(name, args));
                      array = { dataType: root.dtype.name, shape: null, data: null, byteOrder: '|' };
                  }
                  else if (array.dataType === 'i1' || array.dataType === 'u1') {
                      // NumPy writes '|' (byte order not applicable) for single-byte
                      // dtypes. Report them as little-endian so the tensor decoder,
                      // which only accepts '<' and '>', can read them.
                      // NOTE(review): assumes ./numpy reports these dtypes as 'i1'/'u1',
                      // as the dataTypeMap keys imply - confirm against that module.
                      byteOrder = '<';
                  }
                  else {
                      throw new npz.Error("Invalid data type '" + array.dataType + "'.");
                  }
              }
              group.parameters.push({
                  name: parameterName,
                  tensor: {
                      name: name,
                      byteOrder: array === null ? byteOrder : (array.byteOrder === '|' && byteOrder === '<' ? byteOrder : array.byteOrder),
                      dataType: dataTypeMap.has(array.dataType) ? dataTypeMap.get(array.dataType) : array.dataType,
                      shape: array.shape,
                      data: array.data
                  }
              });
          }
      }
  };
  /**
   * Top-level model object: a format label plus a single graph built from
   * the supplied groups.
   */
  npz.Model = class {

      /**
       * @param {string} format - label reported by the 'format' getter.
       * @param {Iterable} groups - groups handed unchanged to npz.Graph.
       */
      constructor(format, groups) {
          this._format = format;
          // The model always holds exactly one graph.
          this._graphs = [ new npz.Graph(groups) ];
      }

      /** @returns {Array} the list containing the single graph. */
      get graphs() {
          return this._graphs;
      }

      /** @returns {string} the format label given at construction. */
      get format() {
          return this._format;
      }
  };
  /**
   * Graph made of one npz.Node per group; it exposes no inputs or outputs.
   */
  npz.Graph = class {

      /**
       * @param {Iterable} groups - groups to wrap, one node each, in iteration order.
       */
      constructor(groups) {
          this._nodes = Array.from(groups, (group) => new npz.Node(group));
      }

      /** @returns {Array} the nodes created from the groups. */
      get nodes() {
          return this._nodes;
      }

      /** @returns {Array} always a new empty array. */
      get inputs() {
          return [];
      }

      /** @returns {Array} always a new empty array. */
      get outputs() {
          return [];
      }
  };
  /**
   * Named parameter of a node together with its list of arguments.
   */
  npz.Parameter = class {

      /**
       * @param {string} name - parameter name.
       * @param {Array} args - arguments attached to this parameter.
       */
      constructor(name, args) {
          Object.assign(this, { _name: name, _arguments: args });
      }

      /** @returns {boolean} always true. */
      get visible() {
          return true;
      }

      /** @returns {Array} the arguments given at construction. */
      get arguments() {
          return this._arguments;
      }

      /** @returns {string} the parameter name. */
      get name() {
          return this._name;
      }
  };
  /**
   * Named value attached to a parameter, optionally backed by an initializer.
   */
  npz.Argument = class {

      /**
       * @param {string} name - argument identifier; must be a string.
       * @param {object} [initializer] - object exposing a 'type' property;
       *   stored as null when omitted or falsy.
       * @throws {npz.Error} when name is not a string.
       */
      constructor(name, initializer) {
          if (typeof name !== 'string') {
              throw new npz.Error("Invalid argument identifier '" + JSON.stringify(name) + "'.");
          }
          this._name = name;
          this._initializer = initializer || null;
      }

      /** @returns {string} the argument identifier. */
      get name() {
          return this._name;
      }

      /**
       * @returns {*} the initializer's type, or null when there is no
       *   initializer (the constructor permits that, so it must not throw here).
       */
      get type() {
          return this._initializer ? this._initializer.type : null;
      }

      /** @returns {object|null} the initializer, or null when none was given. */
      get initializer() {
          return this._initializer;
      }
  };
  /**
   * Node representing one group: every parameter of the group becomes an
   * input whose single argument is initialized with an npz.Tensor.
   */
  npz.Node = class {

      /**
       * @param {object} group - object with a 'name' and a 'parameters' list;
       *   each parameter has a 'name' and a 'tensor' description
       *   (name, dataType, shape, data, byteOrder).
       */
      constructor(group) {
          const prefix = group.name;
          this._name = prefix;
          this._inputs = [];
          for (const parameter of group.parameters) {
              const description = parameter.tensor;
              // Tensor names are qualified with the group name when there is one.
              let qualified = parameter.name;
              if (prefix) {
                  qualified = [ prefix, parameter.name ].join('/');
              }
              const initializer = new npz.Tensor(
                  qualified,
                  description.dataType,
                  description.shape,
                  description.data,
                  description.byteOrder);
              const argument = new npz.Argument(description.name || '', initializer);
              this._inputs.push(new npz.Parameter(parameter.name, [ argument ]));
          }
      }

      /** @returns {string} the group name. */
      get name() {
          return this._name;
      }

      /** @returns {string} always 'Module'. */
      get type() {
          return 'Module';
      }

      /** @returns {null} no metadata is available for these nodes. */
      get metadata() {
          return null;
      }

      /** @returns {Array} one npz.Parameter per group parameter. */
      get inputs() {
          return this._inputs;
      }

      /** @returns {Array} always a new empty array. */
      get outputs() {
          return [];
      }

      /** @returns {Array} always a new empty array. */
      get attributes() {
          return [];
      }
  };
  /**
   * Tensor backed by raw array bytes. Values are decoded lazily, on each
   * access, from the byte buffer according to data type, shape and byte order.
   */
  npz.Tensor = class {

      /**
       * @param {string} name - tensor name.
       * @param {string} dataType - element type name, e.g. 'float32'.
       * @param {Array} shape - dimensions; wrapped in npz.TensorShape.
       * @param {Uint8Array} data - raw bytes (must expose buffer, byteOffset, byteLength).
       * @param {string} byteOrder - '<' or '>'; anything else is reported as unsupported.
       */
      constructor(name, dataType, shape, data, byteOrder) {
          this._name = name;
          this._type = new npz.TensorType(dataType, new npz.TensorShape(shape));
          this._shape = shape;
          this._data = data;
          this._byteOrder = byteOrder;
      }

      /** @returns {string} always 'NumPy Array'. */
      get kind() {
          return 'NumPy Array';
      }

      /** @returns {string} the tensor name. */
      get name() {
          return this._name;
      }

      /** @returns {npz.TensorType} the element type and shape. */
      get type() {
          return this._type;
      }

      /** @returns {string|null} why the tensor cannot be decoded, or null when it can. */
      get state() {
          return this._context().state;
      }

      /** @returns {*} the fully decoded (nested) values, or null when not decodable. */
      get value() {
          const context = this._context();
          if (!context.state) {
              context.limit = Number.MAX_SAFE_INTEGER;
              return this._decode(context, 0);
          }
          return null;
      }

      /** @returns {string} an indented listing truncated with '...', or '' when not decodable. */
      toString() {
          const context = this._context();
          if (!context.state) {
              context.limit = 10000;
              return npz.Tensor._stringify(this._decode(context, 0), '', ' ');
          }
          return '';
      }

      // Maps each decodable data type to [ item size in bytes, DataView accessor name ].
      static _layouts() {
          return new Map([
              [ 'float16', [ 2, 'getFloat16' ] ],
              [ 'float32', [ 4, 'getFloat32' ] ],
              [ 'float64', [ 8, 'getFloat64' ] ],
              [ 'int8', [ 1, 'getInt8' ] ],
              [ 'int16', [ 2, 'getInt16' ] ],
              [ 'int32', [ 4, 'getInt32' ] ],
              [ 'int64', [ 8, 'getInt64' ] ],
              [ 'uint8', [ 1, 'getUint8' ] ],
              [ 'uint16', [ 2, 'getUint16' ] ],
              [ 'uint32', [ 4, 'getUint32' ] ]
          ]);
      }

      // Builds a fresh decoding context; 'state' is set when decoding is not possible.
      _context() {
          const context = { index: 0, count: 0, state: null };
          const order = this._byteOrder;
          if (!(order === '<' || order === '>')) {
              context.state = 'Tensor byte order is not supported.';
              return context;
          }
          if (this._reference) {
              context.state = 'Tensor reference not implemented.';
              return context;
          }
          const bytes = this._data;
          if (!bytes || bytes.length == 0) {
              context.state = 'Tensor data is empty.';
              return context;
          }
          const layout = npz.Tensor._layouts().get(this._type.dataType);
          if (!layout) {
              context.state = 'Tensor data type is not supported.';
              return context;
          }
          context.itemSize = layout[0];
          context.accessor = layout[1];
          context.dimensions = this._type.shape.dimensions;
          context.dataType = this._type.dataType;
          context.littleEndian = order === '<';
          context.data = bytes;
          context.rawData = new DataView(bytes.buffer, bytes.byteOffset, bytes.byteLength);
          return context;
      }

      // Recursively decodes one dimension; a zero-dimensional tensor is read as
      // a single element and returned unwrapped.
      _decode(context, dimension) {
          const scalar = context.dimensions.length === 0;
          const shape = scalar ? [ 1 ] : context.dimensions;
          const innermost = dimension === shape.length - 1;
          const size = shape[dimension];
          const results = [];
          for (let i = 0; i < size; i++) {
              if (context.count > context.limit) {
                  // Element budget exhausted: mark the truncation and stop.
                  results.push('...');
                  return results;
              }
              if (!innermost) {
                  results.push(this._decode(context, dimension + 1));
                  continue;
              }
              if (context.rawData) {
                  results.push(context.rawData[context.accessor](context.index, context.littleEndian));
                  context.index += context.itemSize;
                  context.count++;
              }
          }
          return scalar ? results[0] : results;
      }

      // Formats a decoded value; nested arrays get one 'indent' per level.
      static _stringify(value, indentation, indent) {
          if (Array.isArray(value)) {
              const lines = [ indentation + '[' ];
              const items = value.map((item) => npz.Tensor._stringify(item, indentation + indent, indent));
              if (items.length > 0) {
                  lines.push(items.join(',\n'));
              }
              lines.push(indentation + ']');
              return lines.join('\n');
          }
          let text;
          if (typeof value == 'string') {
              text = value;
          }
          else if (value == Infinity) {
              text = 'Infinity';
          }
          else if (value == -Infinity) {
              text = '-Infinity';
          }
          else if (isNaN(value)) {
              text = 'NaN';
          }
          else {
              text = value.toString();
          }
          return indentation + text;
      }
  };
  /**
   * Element data type of a tensor paired with its shape.
   */
  npz.TensorType = class {

      /**
       * @param {string} dataType - element type name; may be empty or missing.
       * @param {npz.TensorShape} shape - shape object providing toString().
       */
      constructor(dataType, shape) {
          this._shape = shape;
          this._dataType = dataType;
      }

      /** @returns {npz.TensorShape} the shape given at construction. */
      get shape() {
          return this._shape;
      }

      /** @returns {string} the data type name, or '?' when none was given. */
      get dataType() {
          if (this._dataType) {
              return this._dataType;
          }
          return '?';
      }

      /** @returns {string} the data type name followed by the shape text. */
      toString() {
          const prefix = this.dataType;
          const suffix = this._shape.toString();
          return prefix + suffix;
      }
  };
  /**
   * Shape of a tensor, kept as the list of dimensions it was given.
   */
  npz.TensorShape = class {

      /**
       * @param {Array} dimensions - dimension sizes; may be null or empty.
       */
      constructor(dimensions) {
          this._dimensions = dimensions;
      }

      /** @returns {Array} the dimensions exactly as given at construction. */
      get dimensions() {
          return this._dimensions;
      }

      /** @returns {string} '[d0,d1,...]', or '' when there are no dimensions. */
      toString() {
          const dimensions = this._dimensions;
          const hasDimensions = dimensions && dimensions.length != 0;
          return hasDimensions ? '[' + dimensions.join(',') + ']' : '';
      }
  };
  /**
   * Helpers for recognizing unpickled numpy.ndarray values and collections of them.
   */
  npz.Utility = class {

      /**
       * @param {*} obj - candidate value.
       * @returns {boolean} true when obj is tagged with module 'numpy' and name 'ndarray'.
       */
      static isTensor(obj) {
          return Boolean(obj) && obj.__module__ === 'numpy' && obj.__name__ === 'ndarray';
      }

      /**
       * Extracts a name -> tensor map from an unpickled object. The object itself
       * is tried first and then its 'blobs' property, each first as a dictionary
       * (Map or plain object) and then as a list (keys become '0', '1', ...).
       *
       * @param {*} obj - unpickled root object; may be null or undefined.
       * @returns {Map|null} the weights, or null when no candidate consists of tensors only.
       */
      static weights(obj) {
          // A pickled None must be reported as "no weights", not as a TypeError
          // raised while reading obj['blobs'].
          if (obj === null || obj === undefined) {
              return null;
          }
          const keys = [ '', 'blobs' ];
          const candidates = keys.map((key) => key === '' ? obj : obj[key]);
          // A candidate that fails validation must not abort the search, otherwise
          // the 'blobs' candidate could never be reached for a plain object.
          for (const candidate of candidates) {
              const weights = npz.Utility._dictionary(candidate);
              if (weights) {
                  return weights;
              }
          }
          for (const candidate of candidates) {
              const weights = npz.Utility._list(candidate);
              if (weights) {
                  return weights;
              }
          }
          return null;
      }

      // Reads a Map or plain object whose values are all tensors; null otherwise.
      static _dictionary(candidate) {
          // Primitives have no entries and are never a weights dictionary.
          if (!candidate || typeof candidate !== 'object' || Array.isArray(candidate)) {
              return null;
          }
          const weights = new Map();
          if (candidate instanceof Map) {
              for (const pair of candidate) {
                  if (!npz.Utility.isTensor(pair[1])) {
                      return null;
                  }
                  weights.set(pair[0], pair[1]);
              }
              return weights;
          }
          for (const name in candidate) {
              // 'weight_order' and 'lr' are skipped rather than treated as weights.
              if (name != 'weight_order' && name != 'lr') {
                  const value = candidate[name];
                  if (!name || !npz.Utility.isTensor(value)) {
                      return null;
                  }
                  weights.set(name, value);
              }
          }
          return weights;
      }

      // Reads an array whose items are all tensors, keyed by index; null otherwise.
      static _list(candidate) {
          if (!candidate || !Array.isArray(candidate)) {
              return null;
          }
          const weights = new Map();
          for (let i = 0; i < candidate.length; i++) {
              const value = candidate[i];
              if (!npz.Utility.isTensor(value)) {
                  return null;
              }
              weights.set(i.toString(), value);
          }
          return weights;
      }
  };
  /**
   * Error raised when a NumPy archive or weights pickle cannot be loaded.
   */
  npz.Error = class extends Error {

      /**
       * @param {string} message - description of the failure.
       */
      constructor(message) {
          super(message);
          // This loader reports 'NumPy Weights' / 'NumPy Zip' formats, so the
          // error is labelled as a NumPy failure rather than a Chainer one.
          this.name = 'Error loading NumPy model.';
      }
  };
  // Publish the factory to CommonJS consumers; skipped when no 'module'
  // object with an object-valued 'exports' is present.
  if (typeof module !== 'undefined') {
      if (typeof module.exports === 'object') {
          module.exports.ModelFactory = npz.ModelFactory;
      }
  }