pickle.js 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
  1. // Experimental
  2. var pickle = pickle || {};
  3. var python = python || require('./python');
  4. var zip = zip || require('./zip');
  // Recognizes generic Python pickle files and wraps the unpickled value in a
  // pickle.Model. Only an allow-list of top-level Python classes maps to a named
  // format; every other shape is reported through context.exception() and the
  // value is still handed to pickle.Model with the default 'Pickle' format.
  pickle.ModelFactory = class {

      // Returns 'pickle' when context.open('pkl') produces any value other than
      // undefined. Returns undefined for PyTorch models that merely use a .pkl
      // extension, or when nothing could be unpickled.
      match(context) {
          const stream = context.stream;
          // PROTO opcode (0x80), any protocol byte, then a LONG1 opcode (0x8a) with a
          // 10-byte payload (0x0a): the magic number at the start of a PyTorch
          // legacy-format file. Reject PyTorch models with .pkl file extension.
          const torchHeader = [ 0x80, undefined, 0x8a, 0x0a, 0x6c, 0xfc, 0x9c, 0x46, 0xf9, 0x20, 0x6a, 0xa8, 0x50, 0x19 ];
          const isTorch = stream &&
              torchHeader.length <= stream.length &&
              stream.peek(torchHeader.length).every((byte, offset) => torchHeader[offset] === undefined || torchHeader[offset] === byte);
          if (isTorch) {
              return undefined;
          }
          return context.open('pkl') === undefined ? undefined : 'pickle';
      }

      // Resolves to a pickle.Model built from context.open('pkl').
      // Unsupported values (null, arrays, unknown classes, plain objects) are
      // reported via context.exception(). The promise then still resolves,
      // assuming context.exception() returns instead of throwing.
      open(context) {
          return new Promise((resolve) => {
              const obj = context.open('pkl');
              let format = 'Pickle';
              // Every diagnostic ends with the file identifier.
              const report = (text) => {
                  context.exception(new pickle.Error(text + " in '" + context.identifier + "'."));
              };
              if (obj === null || obj === undefined) {
                  report('Unsupported Pickle null object');
              }
              else if (Array.isArray(obj)) {
                  const first = obj[0];
                  // True when the item has the same Python class (module + name) as the first item.
                  const sameClass = (item) => item && item.__class__ && first.__class__ &&
                      item.__class__.__module__ === first.__class__.__module__ &&
                      item.__class__.__name__ === first.__class__.__name__;
                  if (obj.length > 0 && first && obj.every(sameClass)) {
                      report("Unsupported Pickle '" + first.__class__.__module__ + "." + first.__class__.__name__ + "' array object");
                  }
                  else {
                      report('Unsupported Pickle array object');
                  }
              }
              else if (obj && obj.__class__) {
                  const type = obj.__class__.__module__ + "." + obj.__class__.__name__;
                  switch (type) {
                      case 'cuml.ensemble.randomforestclassifier.RandomForestClassifier':
                          format = 'cuML';
                          break;
                      default:
                          report("Unsupported Pickle type '" + type + "'");
                          break;
                  }
              }
              else {
                  report('Unsupported Pickle object');
              }
              resolve(new pickle.Model(obj, format));
          });
      }
  };
  // Model wrapper: a format label plus exactly one pickle.Graph built from the
  // unpickled value.
  pickle.Model = class {

      // value:  the unpickled object passed straight to pickle.Graph.
      // format: display label chosen by pickle.ModelFactory.open ('Pickle' or 'cuML').
      constructor(value, format) {
          this._format = format;
          const graph = new pickle.Graph(value);
          this._graphs = [ graph ];
      }

      get format() {
          return this._format;
      }

      get graphs() {
          return this._graphs;
      }
  };
  // The single graph of a pickle.Model. Inputs and outputs are always empty; the
  // unpickled value is turned into nodes as follows:
  //  - an array whose items all have a __class__ -> one node per item
  //  - a Map -> one node per entry, named by the entry key
  //  - any other object (including arrays with mixed or null items) -> one node
  //  - null, undefined and primitives -> no nodes
  pickle.Graph = class {

      constructor(obj) {
          this._inputs = [];
          this._outputs = [];
          this._nodes = [];
          // Guard each item: arrays can contain null entries, and reading
          // `__class__` on null would throw a TypeError and abort the load.
          if (Array.isArray(obj) && obj.every((item) => item && item.__class__)) {
              for (const item of obj) {
                  this._nodes.push(new pickle.Node(item));
              }
          }
          else if (obj instanceof Map) {
              for (const [ key, value ] of obj) {
                  this._nodes.push(new pickle.Node(value, key));
              }
          }
          else if (obj && (obj.__class__ || Object(obj) === obj)) {
              this._nodes.push(new pickle.Node(obj));
          }
      }

      get inputs() {
          return this._inputs;
      }

      get outputs() {
          return this._outputs;
      }

      get nodes() {
          return this._nodes;
      }
  };
  // A displayed node for one unpickled value:
  //  - arrays become a 'List' node with the whole array as its 'value' attribute
  //  - objects are typed by their Python class ("module.name", or 'Object' when
  //    there is no __class__) with one attribute per own enumerable property
  //  - null, undefined and primitives become an 'Object' node with a single
  //    'value' attribute
  pickle.Node = class {

      // obj:  the unpickled value to display.
      // name: optional node name (pickle.Graph passes the Map key); defaults to ''.
      constructor(obj, name) {
          this._name = name || '';
          this._inputs = [];
          this._outputs = [];
          this._attributes = [];
          if (Array.isArray(obj)) {
              this._type = { name: 'List' };
              this._attributes.push(new pickle.Attribute('value', obj));
          }
          else if (Object(obj) !== obj) {
              // null, undefined and primitives, e.g. Map values that are numbers,
              // strings or null. Reading `__class__` on null/undefined throws, and
              // Object.entries() on a string yields one entry per character, so
              // expose the value as a single attribute instead.
              this._type = { name: 'Object' };
              this._attributes.push(new pickle.Attribute('value', obj));
          }
          else {
              const type = obj.__class__ ? obj.__class__.__module__ + '.' + obj.__class__.__name__ : 'Object';
              this._type = { name: type };
              for (const [ key, value ] of Object.entries(obj)) {
                  this._attributes.push(new pickle.Attribute(key, value));
              }
          }
      }

      get type() {
          return this._type;
      }

      get name() {
          return this._name;
      }

      get inputs() {
          return this._inputs;
      }

      get outputs() {
          return this._outputs;
      }

      get attributes() {
          return this._attributes;
      }
  };
  // A named attribute on a pickle.Node. `type` is "module.name" when the value
  // carries a Python __class__, and undefined otherwise.
  pickle.Attribute = class {

      constructor(name, value) {
          this._name = name;
          this._value = value;
          const cls = value ? value.__class__ : undefined;
          if (cls) {
              this._type = cls.__module__ + '.' + cls.__name__;
          }
      }

      get name() {
          return this._name;
      }

      get value() {
          return this._value;
      }

      get type() {
          return this._type;
      }
  };
  /**
   * Error passed to context.exception() by pickle.ModelFactory.open.
   * The `name` doubles as a human-readable title for the failure.
   * @param {string} message - description of what could not be loaded.
   */
  pickle.Error = class extends Error {

      constructor(message) {
          super(message);
          this.name = 'Error loading Pickle model.';
      }
  };
  // Publish the factory only when loaded as a CommonJS module (e.g. under Node);
  // in other environments the `pickle` namespace object is used directly.
  if (typeof module !== 'undefined' && typeof module.exports === 'object') {
      Object.assign(module.exports, { ModelFactory: pickle.ModelFactory });
  }