// pickle.js (18 KB)
  // Experimental

  // Namespace object: every class defined in this module is attached to it as a property.
  const pickle = {};

  /**
   * Factory that detects pickle payloads and wraps them in a `pickle.Model`.
   */
  pickle.ModelFactory = class {

      /**
       * Decides whether the pickle object exposed by `context` should be handled here.
       *
       * @param {object} context - Host context; this method uses `stream`, `peek('pkl')` and `set()`.
       * @returns {Promise<*>} The result of `context.set('pickle', obj)` on a match, otherwise `null`.
       */
      async match(context) {
          const { stream } = context;
          // Leading byte pattern to reject; the `undefined` slot accepts any byte.
          const signature = [0x80, undefined, 0x8a, 0x0a, 0x6c, 0xfc, 0x9c, 0x46, 0xf9, 0x20, 0x6a, 0xa8, 0x50, 0x19];
          if (stream && signature.length <= stream.length) {
              const head = stream.peek(signature.length);
              const matches = head.every((byte, index) => signature[index] === undefined || signature[index] === byte);
              if (matches) {
                  // Reject PyTorch models with .pkl file extension.
                  return null;
              }
          }
          const obj = await context.peek('pkl');
          if (obj === undefined) {
              return null;
          }
          const cls = obj && obj.__class__;
          const name = cls && cls.__module__ && cls.__name__ ? `${cls.__module__}.${cls.__name__}` : '';
          // Objects whose class lives under '__torch__.' are left to another handler.
          return name.startsWith('__torch__.') ? null : context.set('pickle', obj);
      }

      /**
       * Builds the model for the value previously stored on the context.
       *
       * Unsupported content (null value, unknown class) is reported through
       * `context.error()` and a model is still returned.
       *
       * @param {object} context - Host context; this method uses `value` and `error()`.
       * @returns {Promise<pickle.Model>} The model wrapping `context.value`.
       * @throws {Error} Rethrows `context.value` when it is itself an `Error` instance.
       */
      async open(context) {
          const obj = context.value;
          let format = 'Pickle';
          if (obj === null || obj === undefined) {
              context.error(new pickle.Error("Unsupported Pickle null object."));
          } else if (obj instanceof Error) {
              throw obj;
          } else if (!Array.isArray(obj) && obj && obj.__class__) {
              const type = `${obj.__class__.__module__}.${obj.__class__.__name__}`;
              // Fully qualified Python class name -> format label shown for the model.
              const formats = new Map([
                  ['cuml.ensemble.randomforestclassifier.RandomForestClassifier', 'cuML'],
                  ['shap.explainers._linear.LinearExplainer', 'SHAP'],
                  ['gensim.models.word2vec.Word2Vec', 'Gensim'],
                  ['ray.rllib.algorithms.ppo.ppo.PPOConfig', 'Ray RLlib'],
                  ['builtins.bytearray', 'Pickle'],
                  ['builtins.dict', 'Pickle'],
                  ['collections.OrderedDict', 'Pickle'],
                  ['numpy.ndarray', 'NumPy NDArray'],
              ]);
              const known = formats.get(type);
              if (known === undefined) {
                  context.error(new pickle.Error(`Unsupported Pickle type '${type}'.`));
              } else {
                  format = known;
              }
          }
          return new pickle.Model(obj, format);
      }
  };
  /**
   * Top-level model: a format label plus a single root module built from the value.
   */
  pickle.Model = class {

      /**
       * @param {*} value - Root object to display.
       * @param {string} format - Format label for the model.
       */
      constructor(value, format) {
          const root = new pickle.Module(null, value);
          this.format = format;
          this.modules = [root];
      }
  };
  /**
   * A group of nodes built from one object.
   */
  pickle.Module = class {

      /**
       * Nodes are produced by the first rule that applies:
       *   1. a weights map (passed in directly when `type === 'weights'`, otherwise
       *      derived with `pickle.Utility.weights`) yields one 'Weights' node per entry;
       *   2. a single tensor yields one node named after the tensor's class;
       *   3. an array whose items all carry `__class__`, or are all arrays, yields one node per item;
       *   4. any other object yields one node.
       *
       * @param {?string} [type=''] - Module type; the value 'weights' marks `obj` as a ready-made weights map.
       * @param {*} [obj=null] - Object to convert into nodes.
       */
      constructor(type = '', obj = null) {
          this.type = type;
          this.inputs = [];
          this.outputs = [];
          this.nodes = [];
          const weights = this.type === 'weights' ? obj : pickle.Utility.weights(obj);
          if (weights) {
              for (const [name, module] of weights) {
                  this.nodes.push(new pickle.Node(module, name, 'Weights'));
              }
          } else if (pickle.Utility.isTensor(obj)) {
              const tensorType = `${obj.__class__.__module__}.${obj.__class__.__name__}`;
              this.nodes.push(new pickle.Node({ value: obj }, null, tensorType));
          } else if (Array.isArray(obj) && (obj.every((item) => item && item.__class__) || obj.every((item) => Array.isArray(item)))) {
              // The `item &&` guard keeps lists containing null/undefined from throwing;
              // such lists fall through to the generic object rule below.
              for (const item of obj) {
                  this.nodes.push(new pickle.Node(item));
              }
          } else if (obj && (obj.__class__ || Object(obj) === obj)) {
              this.nodes.push(new pickle.Node(obj));
          }
      }
  };
  /**
   * A displayable node built from one object; each entry of the object becomes an argument.
   */
  pickle.Node = class {

      /**
       * @param {object} obj - Object (or Map) whose entries are converted into arguments.
       * @param {?string} [name] - Node name; falsy values become ''.
       * @param {?string} [type] - Type name; when not a string it is derived from `obj.__class__`,
       *     falling back to 'builtins.object'.
       * @param {Set} [stack] - Objects currently being expanded, used to stop on reference cycles.
       */
      constructor(obj, name, type, stack) {
          let typeName = type;
          if (typeof typeName !== 'string') {
              const cls = obj.__class__;
              typeName = cls ? `${cls.__module__}.${cls.__name__}` : 'builtins.object';
          }
          this.type = { name: typeName };
          this.name = name || '';
          this.inputs = [];
          const addInput = (...args) => this.inputs.push(new pickle.Argument(...args));
          if (typeName === 'builtins.bytearray') {
              addInput('value', Array.from(obj), 'byte[]');
              return;
          }
          const { Utility } = pickle;
          const weights = Utility.weights(obj);
          if (weights) {
              // The object is a weights container: its type becomes a nested module that keeps the type name.
              const label = this.type.name;
              this.type = new pickle.Module('weights', weights);
              this.type.name = label;
              return;
          }
          const tensorValue = (tensor) => new pickle.Value(tensor.__name__ || '', null, new pickle.Tensor(tensor));
          // Created lazily, the first time an entry reaches the generic branch.
          let chain = stack;
          const entries = obj instanceof Map ? Array.from(obj) : Object.entries(obj);
          for (const [key, value] of entries) {
              if (key === '__class__') {
                  continue;
              }
              const isList = Array.isArray(value);
              const cls = value && value.__class__;
              if (value && Utility.isTensor(value)) {
                  addInput(key, [tensorValue(value)]);
              } else if (isList && value.length > 0 && value.every((item) => Utility.isTensor(item))) {
                  addInput(key, value.map((item) => tensorValue(item)));
              } else if (cls && cls.__module__ === 'builtins' && (cls.__name__ === 'function' || cls.__name__ === 'type')) {
                  // Function and type references are shown as a node whose class is the reference itself.
                  const node = new pickle.Node({ __class__: value }, null, null, chain);
                  addInput(key, node, 'object');
              } else if (Utility.isByteArray(value)) {
                  addInput(key, Array.from(value), 'byte[]');
              } else {
                  chain = chain || new Set();
                  const visited = chain;
                  if (isList && value.every((item) => typeof item === 'string')) {
                      addInput(key, value, 'string[]');
                  } else if (isList && value.every((item) => typeof item === 'number')) {
                      addInput(key, value, 'attribute');
                  } else if (isList && value.length > 0 && value.every((item) => item && (item.__class__ || item === Object(item)))) {
                      // Items already on the expansion chain are dropped to avoid infinite recursion.
                      const pending = value.filter((item) => !visited.has(item));
                      const nodes = [];
                      for (const item of pending) {
                          visited.add(item);
                          nodes.push(new pickle.Node(item, null, null, visited));
                          visited.delete(item);
                      }
                      addInput(key, nodes, 'object[]');
                  } else if (value && (cls || Utility.isObject(value)) && !visited.has(value)) {
                      visited.add(value);
                      const node = new pickle.Node(value, null, null, visited);
                      // A '_metadata' entry is hidden only when it has the metadata shape.
                      const visible = key !== '_metadata' || !Utility.isMetadataObject(value);
                      addInput(key, node, 'object', visible);
                      visited.delete(value);
                  } else {
                      addInput(key, value, 'attribute');
                  }
              }
          }
      }
  };
  /**
   * A named input of a node.
   */
  pickle.Argument = class {

      /**
       * @param {*} name - Argument name. It is converted with `String()` so that
       *     non-string keys, including `null` and `undefined`, do not throw.
       * @param {*} value - Argument payload (values, nodes, or a raw attribute).
       * @param {?string} [type=null] - Kind tag such as 'object', 'object[]', 'attribute', 'byte[]' or 'string[]'.
       * @param {boolean} [visible=true] - Whether the argument is shown.
       */
      constructor(name, value, type = null, visible = true) {
          // `name.toString()` throws for null/undefined; String() yields the same text for every other value.
          this.name = String(name);
          this.value = value;
          this.type = type;
          this.visible = visible;
      }
  };
  /**
   * A named value, optionally backed by an initializer tensor.
   */
  pickle.Value = class {

      /**
       * @param {string} name - Value identifier.
       * @param {*} type - Type used when the initializer does not provide one.
       * @param {?object} [initializer=null] - Backing tensor; its `type` takes precedence when truthy.
       * @throws {pickle.Error} When `name` is not a string.
       */
      constructor(name, type, initializer = null) {
          if (typeof name !== 'string') {
              throw new pickle.Error(`Invalid value identifier '${JSON.stringify(name)}'.`);
          }
          this.name = name;
          if (initializer && initializer.type) {
              this.type = initializer.type;
          } else {
              this.type = type || null;
          }
          this.initializer = initializer;
      }
  };
  /**
   * Tensor view over an unpickled PyTorch tensor/parameter or a NumPy-style array.
   */
  pickle.Tensor = class {

      /**
       * @param {object} obj - Tensor object. Objects whose class module is 'torch' or
       *     'torch.nn.parameter' take the PyTorch path; everything else is read through
       *     the array interface (`dtype`, `shape`, `strides`, `itemsize`, ...).
       * @throws {pickle.Error} When a PyTorch tensor reports a layout other than 'torch.strided'.
       */
      constructor(obj) {
          const cls = obj.__class__;
          const isTorch = cls && (cls.__module__ === 'torch' || cls.__module__ === 'torch.nn.parameter');
          if (!isTorch) {
              // NumPy array
              this.type = new pickle.TensorType(obj.dtype.__name__, new pickle.TensorShape(obj.shape));
              // Strides are divided by the item size, i.e. expressed in elements.
              this.stride = Array.isArray(obj.strides) ? obj.strides.map((step) => step / obj.itemsize) : null;
              const { dataType } = this.type;
              const isBoxed = dataType === 'string' || dataType === 'object' || dataType === 'datetime';
              this.encoding = isBoxed ? '|' : obj.dtype.byteorder;
              this.values = (isBoxed || dataType === 'void') ? obj.flatten().tolist() : obj.tobytes();
              return;
          }
          // PyTorch tensor
          const isParameter = cls.__module__ === 'torch.nn.parameter' && cls.__name__ === 'Parameter';
          const tensor = isParameter ? obj.data : obj;
          const layout = tensor.layout ? tensor.layout.__str__() : null;
          const storage = tensor.storage();
          const size = tensor.size() || [];
          if (layout && layout !== 'torch.strided') {
              throw new pickle.Error(`Unsupported tensor layout '${layout}'.`);
          }
          this.type = new pickle.TensorType(storage.dtype.__reduce__(), new pickle.TensorShape(size));
          this.values = storage.data;
          this.encoding = '<';
          this.indices = null;
          this.stride = tensor.stride();
          const stride = this.stride;
          const offset = tensor.storage_offset();
          // Number of storage elements spanned by the tensor; stays 0 when any dimension is 0.
          let length = 0;
          if (!Array.isArray(stride)) {
              length = storage.size();
          } else if (size.every((dim) => dim !== 0)) {
              length = 1;
              size.forEach((dim, axis) => {
                  length += stride[axis] * (dim - 1);
              });
          }
          const data = this.values;
          if (data !== undefined) {
              if (offset !== 0 || length !== storage.size()) {
                  // Read only the spanned window, then restore the stream position.
                  const itemsize = storage.dtype.itemsize();
                  const position = data.position;
                  data.seek(itemsize * offset);
                  this.values = data.peek(itemsize * length);
                  data.seek(position);
              } else if (data) {
                  this.values = data.peek();
              }
          }
      }
  };
  /**
   * Pairs an element data type with a tensor shape.
   */
  pickle.TensorType = class {

      /**
       * @param {*} dataType - Element type name.
       * @param {object} shape - Shape object providing `toString()`.
       */
      constructor(dataType, shape) {
          this.dataType = dataType;
          this.shape = shape;
      }

      /**
       * @returns {string} The data type immediately followed by the shape text.
       */
      toString() {
          const suffix = this.shape.toString();
          return this.dataType + suffix;
      }
  };
  /**
   * Wraps the list of tensor dimensions.
   */
  pickle.TensorShape = class {

      /**
       * @param {?Array} dimensions - Dimension sizes; may be absent.
       */
      constructor(dimensions) {
          this.dimensions = dimensions;
      }

      /**
       * @returns {string} Comma-separated dimensions in brackets, or '' when there are no dimensions.
       */
      toString() {
          if (!this.dimensions) {
              return '';
          }
          const parts = this.dimensions.map((dimension) => dimension.toString());
          return `[${parts.join(',')}]`;
      }
  };
  /**
   * Static helpers for classifying unpickled objects and detecting weight containers.
   */
  pickle.Utility = class {

      /**
       * Tests whether a class object is, or derives from, the class with the given qualified name.
       *
       * The class itself is compared first; `__bases__` is then searched recursively,
       * so a subclass of the named class is recognised as well.
       *
       * @param {?object} value - Class object exposing `__module__`, `__name__` and optionally `__bases__`.
       * @param {string} name - Qualified name in the form 'module.Name'.
       * @returns {boolean}
       */
      static isSubclass(value, name) {
          if (!value) {
              return false;
          }
          if (value.__module__ && value.__name__ && name === `${value.__module__}.${value.__name__}`) {
              return true;
          }
          if (value.__bases__) {
              return value.__bases__.some((base) => pickle.Utility.isSubclass(base, name));
          }
          return false;
      }

      /**
       * @param {*} value - Object carrying a `__class__` property.
       * @param {string} name - Qualified class name.
       * @returns {boolean} True when the object's class is, or derives from, the named class.
       */
      static isInstance(value, name) {
          return value && value.__class__ ? pickle.Utility.isSubclass(value.__class__, name) : false;
      }

      /**
       * Tests whether an object has the metadata shape: an OrderedDict in which every
       * dict value holds exactly one entry keyed 'version'.
       *
       * NOTE(review): the version number itself is intentionally not constrained, so
       * metadata that was hidden before this key check stays hidden — confirm whether
       * only version 1 should qualify.
       *
       * @param {*} obj - Candidate object.
       * @returns {boolean}
       */
      static isMetadataObject(obj) {
          if (!pickle.Utility.isInstance(obj, 'collections.OrderedDict')) {
              return false;
          }
          for (const value of obj.values()) {
              if (pickle.Utility.isInstance(value, 'builtins.dict')) {
                  // Entries are [key, value] pairs, so the key is at entries[0][0].
                  const entries = Array.from(value);
                  if (entries.length !== 1 || !Array.isArray(entries[0]) || entries[0][0] !== 'version') {
                      return false;
                  }
              }
          }
          return true;
      }

      /**
       * @param {*} obj - Candidate object.
       * @returns {*} Truthy when the object's class is `builtins.bytearray`.
       */
      static isByteArray(obj) {
          return obj && obj.__class__ && obj.__class__.__module__ === 'builtins' && obj.__class__.__name__ === 'bytearray';
      }

      /**
       * @param {*} obj - Candidate value.
       * @returns {boolean} True for objects whose prototype is `Object.prototype` or `null`.
       */
      static isObject(obj) {
          if (obj && typeof obj === 'object') {
              const proto = Object.getPrototypeOf(obj);
              return proto === Object.prototype || proto === null;
          }
          return false;
      }

      /**
       * @param {*} obj - Candidate object.
       * @returns {*} Truthy when the object's class is numpy.ndarray, numpy.matrix, jax.Array,
       *     torch.nn.parameter.Parameter, or a class in module 'torch' whose name ends with 'Tensor'.
       */
      static isTensor(obj) {
          return obj && obj.__class__ && obj.__class__.__name__ &&
              ((obj.__class__.__module__ === 'numpy' && obj.__class__.__name__ === 'ndarray') ||
               (obj.__class__.__module__ === 'numpy' && obj.__class__.__name__ === 'matrix') ||
               (obj.__class__.__module__ === 'jax' && obj.__class__.__name__ === 'Array') ||
               (obj.__class__.__module__ === 'torch.nn.parameter' && obj.__class__.__name__ === 'Parameter') ||
               (obj.__class__.__module__ === 'torch' && obj.__class__.__name__.endsWith('Tensor')));
      }

      /**
       * Tries to interpret an object as a collection of named weights.
       *
       * Two layouts are recognised:
       *   - flat: a map whose keys encode a path ('a.b.weight', 'a|b|weight', or names
       *     ending in '_w' / '_b' / '_bn_s'); entries are grouped by the path prefix;
       *   - nested: a container of containers, each holding at least one tensor.
       *
       * Tensors placed in the result get a `__name__` property set on them.
       *
       * @param {*} obj - Candidate object.
       * @returns {?Map<string, object>} Map from module key to an object of its properties,
       *     or `null` when the object does not look like a weights container.
       */
      static weights(obj) {
          const cls = obj && obj.__class__;
          const type = cls && cls.__module__ && cls.__name__ ? `${cls.__module__}.${cls.__name__}` : null;
          // Only untyped objects and these container classes are considered.
          if (type && type !== 'builtins.dict' && type !== 'builtins.object' && type !== 'collections.OrderedDict' && type !== 'torch.nn.modules.module.Module') {
              return null;
          }
          if (pickle.Utility.isTensor(obj)) {
              return null;
          }
          const isContainer = (value) => Boolean(value) && !Array.isArray(value) && Object(value) === value;
          const isTensor = (value) => pickle.Utility.isTensor(value);
          let source = obj;
          // A plain object whose keys mostly look like tensor paths is treated like a map.
          if (!(source instanceof Map) && isContainer(source)) {
              const entries = Object.entries(source);
              const named = entries.filter(([name, value]) => (name.indexOf('.') !== -1 || name.indexOf('|') !== -1) && isTensor(value));
              if (named.length > 0 && (named.length / entries.length) >= 0.8) {
                  source = new Map(entries);
              }
          }
          if (source instanceof Map) {
              const flat = pickle.Utility._flatWeights(source);
              if (flat) {
                  return flat;
              }
          }
          if (isContainer(source)) {
              return pickle.Utility._nestedWeights(source);
          }
          return null;
      }

      // Groups a flat map of path-like keys by path prefix; returns null when the keys do not fit.
      static _flatWeights(source) {
          const entries = Array.from(source).filter(([name]) => name !== '_metadata');
          let dot = 0;
          let pipe = 0;
          let underscore = 0;
          let count = 0;
          let valid = true;
          for (const [name, value] of entries) {
              if (typeof name === 'string') {
                  count++;
                  if (name.indexOf('.') !== -1) {
                      dot++;
                  }
                  if (name.indexOf('|') !== -1) {
                      pipe++;
                  }
                  if (name.endsWith('_w') || name.endsWith('_b') || name.endsWith('_bn_s')) {
                      underscore++;
                  }
              }
              // A dict value that itself holds tensors means this is not a flat layout.
              if (pickle.Utility.isInstance(value, 'builtins.dict') && Array.from(value.values()).some((item) => pickle.Utility.isTensor(item))) {
                  valid = false;
              }
          }
          const uniform = dot >= count || pipe >= count || underscore >= count;
          if (!valid || count <= 1 || !uniform || (count / entries.length) < 0.8) {
              return null;
          }
          let separator = '_';
          if (dot >= pipe && dot >= underscore) {
              separator = '.';
          } else if (pipe >= underscore) {
              separator = '|';
          }
          const modules = new Map();
          for (const [name, value] of entries) {
              // Up to 20% of the keys may be non-strings; stringify them instead of throwing on split().
              const identifier = String(name);
              const path = identifier.split(separator);
              let property = path.pop();
              if (path.length > 1 && path[path.length - 1] === '_packed_params') {
                  property = `${path.pop()}.${property}`;
              }
              const key = path.join(separator);
              if (!modules.has(key)) {
                  modules.set(key, {});
              }
              if (pickle.Utility.isTensor(value)) {
                  value.__name__ = identifier;
              }
              modules.get(key)[property] = value;
          }
          return modules;
      }

      // Reads a container of containers; every inner container must hold at least one tensor.
      static _nestedWeights(source) {
          const entries = source instanceof Map ? Array.from(source) : Object.entries(source);
          if (entries.length === 0) {
              return null;
          }
          const modules = new Map();
          for (const [key, value] of entries) {
              // String() rather than key.toString() so a null key does not throw.
              const name = String(key);
              if (!value || Object(value) !== value || pickle.Utility.isTensor(value) || ArrayBuffer.isView(value)) {
                  return null;
              }
              if (!modules.has(name)) {
                  modules.set(name, {});
              }
              const module = modules.get(name);
              let tensor = false;
              const fields = value instanceof Map ? value : new Map(Object.entries(value));
              for (const [field, item] of fields) {
                  if (typeof field !== 'string' || field.indexOf('.') !== -1) {
                      return null;
                  }
                  if (field === '_metadata') {
                      continue;
                  }
                  if (typeof item === 'string' || typeof item === 'number') {
                      module[field] = item;
                  } else if (pickle.Utility.isTensor(item)) {
                      item.__name__ = field;
                      module[field] = item;
                      tensor = true;
                  }
              }
              if (!tensor) {
                  return null;
              }
          }
          return modules;
      }
  };
  /**
   * Error type raised by this module; `name` carries a fixed user-facing label.
   */
  pickle.Error = class extends Error {

      /**
       * @param {string} message - Description of the failure.
       */
      constructor(message) {
          super(message);
          const label = 'Error loading Pickle model.';
          this.name = label;
      }
  };
  // Public entry point of the module: the factory class under its plain name.
  export const { ModelFactory } = pickle;