// pickle.js (8.5 KB)
  1. // Experimental
  2. const pickle = {};
  3. pickle.ModelFactory = class {
  4. match(context) {
  5. const stream = context.stream;
  6. const signature = [ 0x80, undefined, 0x8a, 0x0a, 0x6c, 0xfc, 0x9c, 0x46, 0xf9, 0x20, 0x6a, 0xa8, 0x50, 0x19 ];
  7. if (stream && signature.length <= stream.length && stream.peek(signature.length).every((value, index) => signature[index] === undefined || signature[index] === value)) {
  8. // Reject PyTorch models with .pkl file extension.
  9. return;
  10. }
  11. const obj = context.peek('pkl');
  12. if (obj !== undefined) {
  13. const name = obj && obj.__class__ && obj.__class__.__module__ && obj.__class__.__name__ ? `${obj.__class__.__module__}.${obj.__class__.__name__}` : '';
  14. if (!name.startsWith('__torch__.')) {
  15. context.type = 'pickle';
  16. context.target = obj;
  17. return;
  18. }
  19. }
  20. }
  21. async open(context) {
  22. let format = 'Pickle';
  23. const obj = context.target;
  24. if (obj === null || obj === undefined) {
  25. context.exception(new pickle.Error("Unsupported Pickle null object."));
  26. } else if (obj instanceof Error) {
  27. throw obj;
  28. } else if (Array.isArray(obj)) {
  29. if (obj.length > 0 && obj[0] && obj.every((item) => item && item.__class__ && obj[0].__class__ && item.__class__.__module__ === obj[0].__class__.__module__ && item.__class__.__name__ === obj[0].__class__.__name__)) {
  30. const type = `${obj[0].__class__.__module__}.${obj[0].__class__.__name__}`;
  31. context.exception(new pickle.Error(`Unsupported Pickle '${type}' array object.`));
  32. } else if (obj.length > 0) {
  33. context.exception(new pickle.Error("Unsupported Pickle array object."));
  34. }
  35. } else if (obj && obj.__class__) {
  36. const formats = new Map([
  37. [ 'cuml.ensemble.randomforestclassifier.RandomForestClassifier', 'cuML' ]
  38. ]);
  39. const type = `${obj.__class__.__module__}.${obj.__class__.__name__}`;
  40. if (formats.has(type)) {
  41. format = formats.get(type);
  42. } else {
  43. context.exception(new pickle.Error(`Unsupported Pickle type '${type}'.`));
  44. }
  45. } else {
  46. context.exception(new pickle.Error('Unsupported Pickle object.'));
  47. }
  48. return new pickle.Model(obj, format);
  49. }
  50. };
  51. pickle.Model = class {
  52. constructor(value, format) {
  53. this.format = format;
  54. this.graphs = [ new pickle.Graph(value) ];
  55. }
  56. };
  57. pickle.Graph = class {
  58. constructor(obj) {
  59. this.inputs = [];
  60. this.outputs = [];
  61. this.nodes = [];
  62. if (Array.isArray(obj) && (obj.every((item) => item.__class__) || (obj.every((item) => Array.isArray(item))))) {
  63. for (const item of obj) {
  64. this.nodes.push(new pickle.Node(item));
  65. }
  66. } else if (obj && obj instanceof Map && !Array.from(obj.values()).some((value) => typeof value === 'string' || typeof value === 'number')) {
  67. for (const [name, value] of obj) {
  68. const node = new pickle.Node(value, name);
  69. this.nodes.push(node);
  70. }
  71. } else if (obj && obj.__class__) {
  72. this.nodes.push(new pickle.Node(obj));
  73. } else if (obj && Object(obj) === obj) {
  74. this.nodes.push(new pickle.Node(obj));
  75. }
  76. }
  77. };
  78. pickle.Node = class {
  79. constructor(obj, name, stack) {
  80. const type = obj.__class__ ? `${obj.__class__.__module__}.${obj.__class__.__name__}` : 'builtins.object';
  81. this.type = { name: type };
  82. this.name = name || '';
  83. this.inputs = [];
  84. this.outputs = [];
  85. this.attributes = [];
  86. const isArray = (obj) => {
  87. return obj && obj.__class__ && obj.__class__.__module__ === 'numpy' && obj.__class__.__name__ === 'ndarray';
  88. };
  89. const isObject = (obj) => {
  90. if (obj && typeof obj === 'object') {
  91. const proto = Object.getPrototypeOf(obj);
  92. return proto === Object.prototype || proto === null;
  93. }
  94. return false;
  95. };
  96. const entries = obj instanceof Map ? Array.from(obj) : Object.entries(obj);
  97. for (const [name, value] of entries) {
  98. if (name === '__class__') {
  99. continue;
  100. } else if (value && isArray(value)) {
  101. const tensor = new pickle.Tensor(value);
  102. const attribute = new pickle.Argument(name, tensor, 'tensor');
  103. this.attributes.push(attribute);
  104. } else if (Array.isArray(value) && value.length > 0 && value.every((obj) => isArray(obj))) {
  105. const tensors = value.map((obj) => new pickle.Tensor(obj));
  106. const attribute = new pickle.Argument(name, tensors, 'tensor[]');
  107. this.attributes.push(attribute);
  108. } else if (value && value.__class__ && value.__class__.__module__ === 'builtins' && (value.__class__.__name__ === 'function' || value.__class__.__name__ === 'type')) {
  109. const obj = {};
  110. obj.__class__ = value;
  111. const node = new pickle.Node(obj, '', stack);
  112. const attribute = new pickle.Argument(name, node, 'object');
  113. this.attributes.push(attribute);
  114. } else {
  115. stack = stack || new Set();
  116. if (value && Array.isArray(value) && value.every((obj) => typeof obj === 'string')) {
  117. const attribute = new pickle.Argument(name, value, 'string[]');
  118. this.attributes.push(attribute);
  119. } else if (value && Array.isArray(value) && value.every((obj) => typeof obj === 'number')) {
  120. const attribute = new pickle.Argument(name, value);
  121. this.attributes.push(attribute);
  122. } else if (value && Array.isArray(value) && value.length > 0 && value.every((obj) => obj && (obj.__class__ || obj === Object(obj)))) {
  123. const values = value.filter((value) => !stack.has(value));
  124. const nodes = values.map((value) => {
  125. stack.add(value);
  126. const node = new pickle.Node(value, '', stack);
  127. stack.delete(value);
  128. return node;
  129. });
  130. const attribute = new pickle.Argument(name, nodes, 'object[]');
  131. this.attributes.push(attribute);
  132. } else if (value && (value.__class__ || isObject(value))) {
  133. if (!stack.has(value)) {
  134. stack.add(value);
  135. const node = new pickle.Node(value, '', stack);
  136. const attribute = new pickle.Argument(name, node, 'object');
  137. this.attributes.push(attribute);
  138. stack.delete(value);
  139. }
  140. } else {
  141. const attribute = new pickle.Argument(name, value);
  142. this.attributes.push(attribute);
  143. }
  144. }
  145. }
  146. }
  147. };
  148. pickle.Argument = class {
  149. constructor(name, value, type, visible) {
  150. this.name = name.toString();
  151. this.value = value;
  152. if (type) {
  153. this.type = type;
  154. }
  155. if (visible === false) {
  156. this.visible = visible;
  157. }
  158. }
  159. };
  160. pickle.Tensor = class {
  161. constructor(array) {
  162. this.type = new pickle.TensorType(array.dtype.__name__, new pickle.TensorShape(array.shape));
  163. this.stride = array.strides.map((stride) => stride / array.itemsize);
  164. this.layout = this.type.dataType == 'string' || this.type.dataType == 'object' ? '|' : array.dtype.byteorder;
  165. this.values = this.type.dataType == 'string' || this.type.dataType == 'object' ? array.tolist() : array.tobytes();
  166. }
  167. };
  168. pickle.TensorType = class {
  169. constructor(dataType, shape) {
  170. this.dataType = dataType;
  171. this.shape = shape;
  172. }
  173. toString() {
  174. return this.dataType + this.shape.toString();
  175. }
  176. };
  177. pickle.TensorShape = class {
  178. constructor(dimensions) {
  179. this.dimensions = dimensions;
  180. }
  181. toString() {
  182. return this.dimensions ? (`[${this.dimensions.map((dimension) => dimension.toString()).join(',')}]`) : '';
  183. }
  184. };
  185. pickle.Error = class extends Error {
  186. constructor(message) {
  187. super(message);
  188. this.name = 'Error loading Pickle model.';
  189. }
  190. };
  191. export const ModelFactory = pickle.ModelFactory;