// pickle.js (18 KB)

  // Experimental

  // Namespace object; every class defined in this file is attached to it.
  const pickle = {};

  /**
   * Entry point used to recognize and open generic Python pickle files.
   */
  pickle.ModelFactory = class {

      /**
       * Decides whether the given context holds a pickle file this module should handle.
       *
       * @param {object} context - Host context; this method reads `context.stream` and
       *     calls `context.peek('pkl')` and `context.set('pickle', obj)`.
       * @returns {Promise<*>} Whatever `context.set` returns when the file is accepted,
       *     otherwise `null`.
       */
      async match(context) {
          const { stream } = context;
          // `undefined` entries are wildcards: any byte is accepted at that position.
          const magic = [0x80, undefined, 0x8a, 0x0a, 0x6c, 0xfc, 0x9c, 0x46, 0xf9, 0x20, 0x6a, 0xa8, 0x50, 0x19];
          if (stream && magic.length <= stream.length) {
              const head = stream.peek(magic.length);
              const matches = head.every((byte, position) => magic[position] === undefined || magic[position] === byte);
              if (matches) {
                  // Reject PyTorch models with .pkl file extension.
                  return null;
              }
          }
          const obj = await context.peek('pkl');
          if (obj === undefined) {
              return null;
          }
          const cls = obj && obj.__class__;
          const qualified = cls && cls.__module__ && cls.__name__ ? `${cls.__module__}.${cls.__name__}` : '';
          // Objects whose class lives under `__torch__.` are left to another loader.
          if (qualified.startsWith('__torch__.')) {
              return null;
          }
          return context.set('pickle', obj);
      }

      /**
       * Builds a `pickle.Model` from the object stored on the context.
       *
       * Problems with the object are reported through `context.error`; only an object
       * that is itself an `Error` is rethrown.
       *
       * @param {object} context - Host context; this method reads `context.value` and
       *     calls `context.error`.
       * @returns {Promise<pickle.Model>} The model wrapping `context.value`.
       * @throws {Error} When `context.value` is an `Error` instance.
       */
      async open(context) {
          const obj = context.value;
          let format = 'Pickle';
          if (obj === null || obj === undefined) {
              context.error(new pickle.Error("Unsupported Pickle null object."));
          } else if (obj instanceof Error) {
              throw obj;
          } else if (!Array.isArray(obj) && obj && obj.__class__) {
              // Display label for each Python type that is known to be viewable.
              const labels = new Map([
                  ['cuml.ensemble.randomforestclassifier.RandomForestClassifier', 'cuML'],
                  ['shap.explainers._linear.LinearExplainer', 'SHAP'],
                  ['gensim.models.word2vec.Word2Vec', 'Gensim'],
                  ['builtins.bytearray', 'Pickle'],
                  ['builtins.dict', 'Pickle'],
                  ['collections.OrderedDict', 'Pickle'],
                  ['numpy.ndarray', 'NumPy NDArray'],
              ]);
              const type = `${obj.__class__.__module__}.${obj.__class__.__name__}`;
              const label = labels.get(type);
              if (label === undefined) {
                  context.error(new pickle.Error(`Unsupported Pickle type '${type}'.`));
              } else {
                  format = label;
              }
          }
          return new pickle.Model(obj, format);
      }
  };
  /**
   * Top-level model: a format label plus a single module built from the unpickled value.
   */
  pickle.Model = class {

      /**
       * @param {*} value - Unpickled root object handed to `pickle.Module`.
       * @param {string} format - Display name of the detected format.
       */
      constructor(value, format) {
          this.format = format;
          // The root module is created with a null type.
          const root = new pickle.Module(null, value);
          this.modules = [root];
      }
  };
  /**
   * A group of nodes derived from one unpickled object.
   */
  pickle.Module = class {

      /**
       * @param {?string} [type=''] - Module type. The value 'weights' means `obj` is
       *     already an iterable of `[name, module]` pairs and is used as-is.
       * @param {*} [obj=null] - Unpickled object to turn into nodes.
       */
      constructor(type = '', obj = null) {
          this.type = type;
          this.inputs = [];
          this.outputs = [];
          this.nodes = [];
          const weights = this.type === 'weights' ? obj : pickle.Utility.weights(obj);
          if (weights) {
              // One 'Weights' node per named group.
              for (const [name, module] of weights) {
                  this.nodes.push(new pickle.Node(module, name, 'Weights'));
              }
              return;
          }
          if (pickle.Utility.isTensor(obj)) {
              // A bare tensor becomes a single node typed after the tensor's class.
              const tensorType = `${obj.__class__.__module__}.${obj.__class__.__name__}`;
              this.nodes.push(new pickle.Node({ value: obj }, null, tensorType));
              return;
          }
          if (Array.isArray(obj)) {
              // Null-safe: a list containing null/undefined items must not throw here,
              // it simply does not qualify as a list of class instances.
              const allInstances = obj.every((item) => item && item.__class__);
              if (allInstances || obj.every((item) => Array.isArray(item))) {
                  for (const item of obj) {
                      this.nodes.push(new pickle.Node(item));
                  }
                  return;
              }
          }
          // Any remaining object (class instance or plain object/array) is one node.
          if (obj && (obj.__class__ || Object(obj) === obj)) {
              this.nodes.push(new pickle.Node(obj));
          }
      }
  };
  /**
   * Graph node describing one unpickled object; its members become `inputs`.
   */
  pickle.Node = class {

      /**
       * @param {object} obj - Object (plain object, Map or class instance) to describe.
       * @param {?string} name - Node name; falsy values become ''.
       * @param {?string} type - Explicit type name. When not a string, the name is derived
       *     from `obj.__class__`, falling back to 'builtins.object'.
       * @param {Set} [stack] - Objects currently being expanded by enclosing nodes; used
       *     to avoid descending into an object that is already on the path.
       */
      constructor(obj, name, type, stack) {
          let typeName = type;
          if (typeof typeName !== 'string') {
              typeName = obj.__class__ ? `${obj.__class__.__module__}.${obj.__class__.__name__}` : 'builtins.object';
          }
          this.type = { name: typeName };
          this.name = name || '';
          this.inputs = [];
          if (typeName === 'builtins.bytearray') {
              // A byte array node exposes its content as a single 'value' argument.
              this.inputs.push(new pickle.Argument('value', Array.from(obj), 'byte[]'));
              return;
          }
          const weights = pickle.Utility.weights(obj);
          if (weights) {
              // Weight containers are represented by a nested module used as the node type.
              const group = new pickle.Module('weights', weights);
              group.name = typeName;
              this.type = group;
              return;
          }
          // Created lazily: members handled before the first generic member still see
          // the caller's value (possibly undefined).
          let chain = stack;
          const members = obj instanceof Map ? Array.from(obj) : Object.entries(obj);
          for (const [key, value] of members) {
              if (key === '__class__') {
                  continue;
              }
              if (value && pickle.Utility.isTensor(value)) {
                  const identifier = value.__name__ || '';
                  const initializer = new pickle.Tensor(value);
                  const values = [new pickle.Value(identifier, null, initializer)];
                  this.inputs.push(new pickle.Argument(key, values));
                  continue;
              }
              if (Array.isArray(value) && value.length > 0 && value.every((item) => pickle.Utility.isTensor(item))) {
                  const values = value.map((item) => new pickle.Value(item.__name__ || '', null, new pickle.Tensor(item)));
                  this.inputs.push(new pickle.Argument(key, values));
                  continue;
              }
              if (value && value.__class__ && value.__class__.__module__ === 'builtins' && (value.__class__.__name__ === 'function' || value.__class__.__name__ === 'type')) {
                  // Functions and types are shown as an empty node whose class is the value itself.
                  const wrapper = {};
                  wrapper.__class__ = value;
                  const node = new pickle.Node(wrapper, null, null, chain);
                  this.inputs.push(new pickle.Argument(key, node, 'object'));
                  continue;
              }
              if (pickle.Utility.isByteArray(value)) {
                  this.inputs.push(new pickle.Argument(key, Array.from(value), 'byte[]'));
                  continue;
              }
              chain = chain || new Set();
              const isList = value && Array.isArray(value);
              if (isList && value.every((item) => typeof item === 'string')) {
                  this.inputs.push(new pickle.Argument(key, value, 'string[]'));
              } else if (isList && value.every((item) => typeof item === 'number')) {
                  this.inputs.push(new pickle.Argument(key, value, 'attribute'));
              } else if (isList && value.length > 0 && value.every((item) => item && (item.__class__ || item === Object(item)))) {
                  const active = chain;
                  // Items already on the expansion path are dropped rather than revisited.
                  const pending = value.filter((item) => !active.has(item));
                  const nodes = pending.map((item) => {
                      active.add(item);
                      const node = new pickle.Node(item, null, null, active);
                      active.delete(item);
                      return node;
                  });
                  this.inputs.push(new pickle.Argument(key, nodes, 'object[]'));
              } else if (value && (value.__class__ || pickle.Utility.isObject(value)) && !chain.has(value)) {
                  chain.add(value);
                  const node = new pickle.Node(value, null, null, chain);
                  // Only a '_metadata' member recognized as a metadata object is hidden.
                  const visible = key !== '_metadata' || !pickle.Utility.isMetadataObject(value);
                  this.inputs.push(new pickle.Argument(key, node, 'object', visible));
                  chain.delete(value);
              } else {
                  this.inputs.push(new pickle.Argument(key, value, 'attribute'));
              }
          }
      }
  };
  /**
   * A named input of a node.
   */
  pickle.Argument = class {

      /**
       * @param {*} name - Argument name; converted to a string. Unlike `name.toString()`,
       *     the conversion also accepts `null` and `undefined` (yielding 'null' and
       *     'undefined'), so a container keyed by such a value does not throw.
       * @param {*} value - Argument payload.
       * @param {?string} [type=null] - Kind tag for the payload (callers in this file
       *     pass values such as 'object', 'byte[]' or 'attribute').
       * @param {boolean} [visible=true] - Whether the argument is shown.
       */
      constructor(name, value, type = null, visible = true) {
          this.name = String(name);
          this.value = value;
          this.type = type;
          this.visible = visible;
      }
  };
  /**
   * A named value, optionally backed by an initializer (tensor).
   */
  pickle.Value = class {

      /**
       * @param {string} name - Value identifier; must be a string.
       * @param {*} type - Fallback type used when the initializer has no type.
       * @param {?object} [initializer=null] - Backing tensor, if any.
       * @throws {pickle.Error} When `name` is not a string.
       */
      constructor(name, type, initializer = null) {
          if (typeof name !== 'string') {
              throw new pickle.Error(`Invalid value identifier '${JSON.stringify(name)}'.`);
          }
          // The initializer's own type wins over the explicitly passed one.
          const inferred = initializer && initializer.type;
          this.name = name;
          this.type = inferred ? inferred : (type || null);
          this.initializer = initializer;
      }
  };
  /**
   * Tensor view over either a PyTorch tensor/parameter or a NumPy-style array object.
   */
  pickle.Tensor = class {

      /**
       * @param {object} obj - Object whose class module is 'torch' or 'torch.nn.parameter'
       *     (handled as a PyTorch tensor); anything else is handled as a NumPy array.
       * @throws {pickle.Error} When a PyTorch tensor reports a layout other than
       *     'torch.strided'.
       */
      constructor(obj) {
          const cls = obj.__class__;
          if (cls && (cls.__module__ === 'torch' || cls.__module__ === 'torch.nn.parameter')) {
              // PyTorch tensor; a Parameter wraps the actual tensor in `data`.
              const tensor = cls.__module__ === 'torch.nn.parameter' && cls.__name__ === 'Parameter' ? obj.data : obj;
              const layout = tensor.layout ? tensor.layout.__str__() : null;
              const storage = tensor.storage();
              const size = tensor.size() || [];
              if (layout && layout !== 'torch.strided') {
                  throw new pickle.Error(`Unsupported tensor layout '${layout}'.`);
              }
              this.type = new pickle.TensorType(storage.dtype.__reduce__(), new pickle.TensorShape(size));
              this.values = storage.data;
              this.encoding = '<';
              this.indices = null;
              this.stride = tensor.stride();
              const stride = this.stride;
              const offset = tensor.storage_offset();
              // Number of storage elements spanned by the view: 1 + sum(stride[i] * (size[i] - 1)),
              // or 0 when any dimension is empty.
              let length = 0;
              if (!Array.isArray(stride)) {
                  length = storage.size();
              } else if (size.every((dimension) => dimension !== 0)) {
                  length = size.reduce((total, dimension, index) => total + stride[index] * (dimension - 1), 1);
              }
              const stream = this.values;
              // Storage without data (null/undefined) is left untouched instead of being
              // dereferenced.
              if (stream) {
                  if (offset !== 0 || length !== storage.size()) {
                      const itemsize = storage.dtype.itemsize();
                      const position = stream.position;
                      try {
                          stream.seek(itemsize * offset);
                          this.values = stream.peek(itemsize * length);
                      } finally {
                          // The storage stream may be shared; always put it back where it was.
                          stream.seek(position);
                      }
                  } else {
                      this.values = stream.peek();
                  }
              }
          } else {
              // NumPy array
              const array = obj;
              this.type = new pickle.TensorType(array.dtype.__name__, new pickle.TensorShape(array.shape));
              // Strides are divided by the item size (presumably bytes -> elements).
              this.stride = Array.isArray(array.strides) ? array.strides.map((stride) => stride / array.itemsize) : null;
              const dataType = this.type.dataType;
              const textual = dataType === 'string' || dataType === 'object' || dataType === 'datetime';
              this.encoding = textual ? '|' : array.dtype.byteorder;
              this.values = textual || dataType === 'void' ? array.flatten().tolist() : array.tobytes();
          }
      }
  };
  /**
   * Pairs a tensor's element type with its shape.
   */
  pickle.TensorType = class {

      /**
       * @param {*} dataType - Element type name.
       * @param {object} shape - Shape object; its `toString()` is used for display.
       */
      constructor(dataType, shape) {
          this.dataType = dataType;
          this.shape = shape;
      }

      /**
       * @returns {string} The data type immediately followed by the shape's text.
       */
      toString() {
          const suffix = this.shape.toString();
          return this.dataType + suffix;
      }
  };
  /**
   * Holds the dimensions of a tensor.
   */
  pickle.TensorShape = class {

      /**
       * @param {?Array} dimensions - Dimension sizes; may be falsy when unknown.
       */
      constructor(dimensions) {
          this.dimensions = dimensions;
      }

      /**
       * @returns {string} '[d0,d1,...]' when dimensions are present, otherwise ''.
       */
      toString() {
          if (!this.dimensions) {
              return '';
          }
          const parts = this.dimensions.map((dimension) => dimension.toString());
          return `[${parts.join(',')}]`;
      }
  };
  /**
   * Static helpers for classifying unpickled objects and grouping weight dictionaries.
   */
  pickle.Utility = class {

      /**
       * Checks a class descriptor against a qualified name.
       *
       * NOTE(review): when the descriptor has both `__module__` and `__name__`, only an
       * exact match is tested and `__bases__` is never consulted; bases are only walked
       * for descriptors lacking a module or name. Kept as-is because callers may rely on
       * it - confirm whether base-class matching was intended.
       *
       * @param {?object} value - Class descriptor.
       * @param {string} name - Qualified name in 'module.name' form.
       * @returns {boolean} Whether the descriptor matches.
       */
      static isSubclass(value, name) {
          if (!value) {
              return false;
          }
          if (value.__module__ && value.__name__) {
              return name === `${value.__module__}.${value.__name__}`;
          }
          if (value.__bases__) {
              return value.__bases__.some((base) => pickle.Utility.isSubclass(base, name));
          }
          return false;
      }

      /**
       * @param {*} value - Object carrying a `__class__` descriptor.
       * @param {string} name - Qualified class name.
       * @returns {boolean} Whether `value.__class__` matches `name` per `isSubclass`.
       */
      static isInstance(value, name) {
          return value && value.__class__ ? pickle.Utility.isSubclass(value.__class__, name) : false;
      }

      /**
       * Recognizes a metadata dictionary: a 'collections.OrderedDict' in which every
       * 'builtins.dict' value consists of exactly one entry keyed 'version'.
       * Values that are not 'builtins.dict' instances are ignored.
       *
       * @param {*} obj - Candidate object.
       * @returns {boolean} Whether `obj` looks like such a metadata dictionary.
       */
      static isMetadataObject(obj) {
          if (!pickle.Utility.isInstance(obj, 'collections.OrderedDict')) {
              return false;
          }
          for (const value of obj.values()) {
              if (pickle.Utility.isInstance(value, 'builtins.dict')) {
                  // Entries are [key, value] pairs, so the key is entries[0][0].
                  const entries = Array.from(value);
                  if (entries.length !== 1 || entries[0][0] !== 'version') {
                      return false;
                  }
              }
          }
          return true;
      }

      /**
       * @param {*} obj - Candidate object.
       * @returns {*} Truthy when the object's class is 'builtins.bytearray'.
       */
      static isByteArray(obj) {
          return obj && obj.__class__ && obj.__class__.__module__ === 'builtins' && obj.__class__.__name__ === 'bytearray';
      }

      /**
       * @param {*} obj - Candidate value.
       * @returns {boolean} Whether `obj` is an object whose prototype is `Object.prototype`
       *     or `null`.
       */
      static isObject(obj) {
          if (obj && typeof obj === 'object') {
              const proto = Object.getPrototypeOf(obj);
              return proto === Object.prototype || proto === null;
          }
          return false;
      }

      /**
       * @param {*} obj - Candidate object.
       * @returns {*} Truthy when the object's class is one of the recognized tensor
       *     classes: numpy.ndarray, numpy.matrix, jax.Array, torch.nn.parameter.Parameter,
       *     or any class in module 'torch' whose name ends with 'Tensor'.
       */
      static isTensor(obj) {
          return obj && obj.__class__ && obj.__class__.__name__ &&
              ((obj.__class__.__module__ === 'numpy' && obj.__class__.__name__ === 'ndarray') ||
              (obj.__class__.__module__ === 'numpy' && obj.__class__.__name__ === 'matrix') ||
              (obj.__class__.__module__ === 'jax' && obj.__class__.__name__ === 'Array') ||
              (obj.__class__.__module__ === 'torch.nn.parameter' && obj.__class__.__name__ === 'Parameter') ||
              (obj.__class__.__module__ === 'torch' && obj.__class__.__name__.endsWith('Tensor')));
      }

      /**
       * Tries to interpret `obj` as a weight dictionary and group its tensors by module.
       *
       * Two layouts are recognized:
       *  1. A flat mapping whose keys are paths ('a.b.weight', 'a|b|weight', or names
       *     ending in '_w' / '_b' / '_bn_s'); entries are grouped by path prefix.
       *  2. A two-level mapping (module name -> { parameter name -> tensor }).
       *
       * Recognized tensors get their `__name__` set as a side effect.
       *
       * @param {*} obj - Candidate object (Map or plain object).
       * @returns {?Map<string, object>} Module name -> { property -> value }, or `null`
       *     when `obj` is not recognized as a weight dictionary.
       */
      static weights(obj) {
          const type = obj && obj.__class__ && obj.__class__.__module__ && obj.__class__.__name__ ? `${obj.__class__.__module__}.${obj.__class__.__name__}` : null;
          if (type && type !== 'builtins.dict' && type !== 'builtins.object' && type !== 'collections.OrderedDict' && type !== 'torch.nn.modules.module.Module') {
              return null;
          }
          if (pickle.Utility.isTensor(obj)) {
              return null;
          }
          let source = obj;
          if (source instanceof Map === false && source && !Array.isArray(source) && Object(source) === source) {
              // A plain object is treated like a Map when at least 80% of its entries are
              // tensors stored under path-like keys.
              const entries = Object.entries(source);
              const named = entries.filter(([name, value]) => (typeof name === 'string' && (name.indexOf('.') !== -1 || name.indexOf('|') !== -1)) && pickle.Utility.isTensor(value));
              if (named.length > 0 && (named.length / entries.length) >= 0.8) {
                  source = new Map(entries);
              }
          }
          if (source instanceof Map) {
              const entries = Array.from(source).filter(([name]) => name !== '_metadata');
              let dot = 0;
              let pipe = 0;
              let underscore = 0;
              let count = 0;
              let valid = true;
              for (const [name, value] of entries) {
                  if (typeof name === 'string') {
                      count++;
                      if (name.indexOf('.') !== -1) {
                          dot++;
                      }
                      if (name.indexOf('|') !== -1) {
                          pipe++;
                      }
                      if (name.endsWith('_w') || name.endsWith('_b') || name.endsWith('_bn_s')) {
                          underscore++;
                      }
                  }
                  // A nested dict that itself holds tensors means this is not a flat layout.
                  if (pickle.Utility.isInstance(value, 'builtins.dict') && Array.from(value.values()).some((item) => pickle.Utility.isTensor(item))) {
                      valid = false;
                  }
              }
              if (valid && count > 1 && (dot >= count || pipe >= count || underscore >= count) && (count / entries.length) >= 0.8) {
                  let separator = '_';
                  if (dot >= pipe && dot >= underscore) {
                      separator = '.';
                  } else if (pipe >= underscore) {
                      separator = '|';
                  }
                  const modules = new Map();
                  for (const [name, value] of entries) {
                      // Up to 20% of the keys may be non-strings; stringify so they can be
                      // split and later used as value identifiers.
                      const identifier = String(name);
                      const path = identifier.split(separator);
                      let property = path.pop();
                      if (path.length > 1 && path[path.length - 1] === '_packed_params') {
                          property = `${path.pop()}.${property}`;
                      }
                      const key = path.join(separator);
                      if (!modules.has(key)) {
                          modules.set(key, {});
                      }
                      const module = modules.get(key);
                      if (pickle.Utility.isTensor(value)) {
                          value.__name__ = identifier;
                      }
                      module[property] = value;
                  }
                  return modules;
              }
          }
          if (source && !Array.isArray(source) && Object(source) === source) {
              const modules = new Map();
              const entries = source instanceof Map ? Array.from(source) : Object.entries(source);
              if (entries.length > 0) {
                  for (const [key, value] of entries) {
                      const name = String(key);
                      // Every top-level value must itself be a non-tensor container.
                      if (!value || Object(value) !== value || pickle.Utility.isTensor(value) || ArrayBuffer.isView(value)) {
                          return null;
                      }
                      if (!modules.has(name)) {
                          modules.set(name, {});
                      }
                      const module = modules.get(name);
                      let tensor = false;
                      const children = value instanceof Map ? value : new Map(Object.entries(value));
                      for (const [childName, child] of children) {
                          if (typeof childName !== 'string') {
                              return null;
                          }
                          if (childName.indexOf('.') !== -1) {
                              return null;
                          }
                          if (childName === '_metadata') {
                              continue;
                          }
                          if (typeof child === 'string' || typeof child === 'number') {
                              module[childName] = child;
                              continue;
                          }
                          if (pickle.Utility.isTensor(child)) {
                              child.__name__ = childName;
                              module[childName] = child;
                              tensor = true;
                          }
                      }
                      // A module without any tensor disqualifies the whole object.
                      if (!tensor) {
                          return null;
                      }
                  }
                  return modules;
              }
          }
          return null;
      }
  };
  /**
   * Error type raised by the pickle loader; carries a fixed, user-facing `name`.
   */
  pickle.Error = class extends Error {

      /**
       * @param {string} message - Description of the failure.
       */
      constructor(message) {
          super(message);
          const label = 'Error loading Pickle model.';
          this.name = label;
      }
  };
  // Public entry point of this module.
  const { ModelFactory } = pickle;
  export { ModelFactory };