pickle.js 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178
  1. /* jshint esversion: 6 */
  2. // Experimental
  3. var pickle = pickle || {};
  4. var python = python || require('./python');
  5. var zip = zip || require('./zip');
  /**
   * Model factory for generic Python pickle files.
   *
   * `match` decides whether a file should be handled here and `open`
   * wraps the unpickled root object in a `pickle.Model`.
   */
  pickle.ModelFactory = class {

      /**
       * Tests whether the file behind `context` should be opened by this factory.
       *
       * @param {object} context - Provides `stream` (with `length` and `peek`) and `open(kind)`.
       * @returns {boolean} `false` when the leading bytes match the PyTorch
       *     signature or when `context.open('pkl')` yields `undefined`; `true` otherwise.
       */
      match(context) {
          const stream = context.stream;
          // Byte pattern of PyTorch models; an `undefined` slot accepts any byte.
          const pytorch = [ 0x80, undefined, 0x8a, 0x0a, 0x6c, 0xfc, 0x9c, 0x46, 0xf9, 0x20, 0x6a, 0xa8, 0x50, 0x19 ];
          const matches = (byte, position) => {
              const expected = pytorch[position];
              return expected === undefined || expected === byte;
          };
          if (pytorch.length <= stream.length && stream.peek(pytorch.length).every(matches)) {
              // Reject PyTorch models with .pkl file extension.
              return false;
          }
          return context.open('pkl') !== undefined;
      }

      /**
       * Opens the pickle and resolves with a `pickle.Model` built from its root object.
       *
       * Unrecognized roots (null/undefined, arrays, unknown classes, anything
       * without a `__class__`) are passed to `context.exception` as a
       * `pickle.Error`; the promise still resolves with a model afterwards.
       *
       * @param {object} context - Provides `open(kind)` and `exception(error)`.
       * @returns {Promise<pickle.Model>} The model wrapping the root object.
       */
      open(context) {
          return new Promise((resolve) => {
              const report = (message) => {
                  context.exception(new pickle.Error(message));
              };
              // Root classes with a dedicated format label; everything else stays 'Pickle'.
              const formats = new Map([
                  [ 'cuml.ensemble.randomforestclassifier.RandomForestClassifier', 'cuML' ]
              ]);
              const root = context.open('pkl');
              let format = 'Pickle';
              if (root === null || root === undefined) {
                  report('Unknown Pickle null object.');
              }
              else if (Array.isArray(root)) {
                  report('Unknown Pickle array object.');
              }
              else if (root && root.__class__) {
                  const type = root.__class__.__module__ + "." + root.__class__.__name__;
                  if (formats.has(type)) {
                      format = formats.get(type);
                  }
                  else {
                      report("Unknown Pickle type '" + type + "'.");
                  }
              }
              else {
                  report('Unknown Pickle object.');
              }
              resolve(new pickle.Model(root, format));
          });
      }
  };
  /**
   * Model wrapper holding a format label and a single graph built from
   * the unpickled root value.
   */
  pickle.Model = class {

      /**
       * @param {*} value - Unpickled root value handed to `pickle.Graph`.
       * @param {string} format - Format label exposed through `format`.
       */
      constructor(value, format) {
          const graph = new pickle.Graph(value);
          this._format = format;
          this._graphs = [ graph ];
      }

      /** @returns {string} The format label given at construction. */
      get format() {
          return this._format;
      }

      /** @returns {pickle.Graph[]} Array containing the single graph of this model. */
      get graphs() {
          return this._graphs;
      }
  };
  /**
   * Graph built from an unpickled value. It has no inputs or outputs;
   * its nodes are derived from the value as described on the constructor.
   */
  pickle.Graph = class {

      /**
       * Builds the node list:
       *  - an array whose items all carry a `__class__` yields one node per item;
       *  - any other object (including arrays that contain items without a
       *    `__class__`, such as null, undefined or primitives) yields one node;
       *  - null, undefined and primitives yield no nodes.
       *
       * @param {*} obj - Unpickled root value.
       */
      constructor(obj) {
          this._inputs = [];
          this._outputs = [];
          this._nodes = [];
          // Guard against null/undefined items: reading `__class__` on them would throw.
          const hasClass = (item) => Boolean(item && item.__class__);
          if (Array.isArray(obj) && obj.every(hasClass)) {
              for (const item of obj) {
                  this._nodes.push(new pickle.Node(item));
              }
          }
          else if (obj && (obj.__class__ || Object(obj) === obj)) {
              // Single node for a class instance or any other non-primitive value.
              this._nodes.push(new pickle.Node(obj));
          }
      }

      /** @returns {Array} Graph inputs (always empty). */
      get inputs() {
          return this._inputs;
      }

      /** @returns {Array} Graph outputs (always empty). */
      get outputs() {
          return this._outputs;
      }

      /** @returns {pickle.Node[]} Nodes derived from the constructor argument. */
      get nodes() {
          return this._nodes;
      }
  };
  /**
   * Node describing one unpickled object: a type name plus one attribute
   * per own enumerable key (or a single `value` attribute for arrays).
   */
  pickle.Node = class {

      /**
       * @param {object|Array} obj - Object or array to describe.
       */
      constructor(obj) {
          this._inputs = [];
          this._outputs = [];
          this._attributes = [];
          if (Array.isArray(obj)) {
              // Arrays are shown as one 'List' node holding the whole array.
              this._type = { name: 'List' };
              this._attributes.push(new pickle.Attribute('value', obj));
              return;
          }
          // Qualified class name when available, generic 'Object' otherwise.
          let typeName = 'Object';
          if (obj.__class__) {
              typeName = obj.__class__.__module__ + '.' + obj.__class__.__name__;
          }
          this._type = { name: typeName };
          Object.keys(obj).forEach((key) => {
              this._attributes.push(new pickle.Attribute(key, obj[key]));
          });
      }

      /** @returns {{name: string}} Type descriptor of this node. */
      get type() {
          return this._type;
      }

      /** @returns {string} Node name (always the empty string). */
      get name() {
          return '';
      }

      /** @returns {Array} Node inputs (always empty). */
      get inputs() {
          return this._inputs;
      }

      /** @returns {Array} Node outputs (always empty). */
      get outputs() {
          return this._outputs;
      }

      /** @returns {pickle.Attribute[]} Attributes collected by the constructor. */
      get attributes() {
          return this._attributes;
      }
  };
  /**
   * Name/value pair attached to a node, with an optional type string
   * derived from the value's `__class__`.
   */
  pickle.Attribute = class {

      /**
       * @param {string} name - Attribute name.
       * @param {*} value - Attribute value; when it carries a `__class__`,
       *     the type becomes "<__module__>.<__name__>".
       */
      constructor(name, value) {
          this._name = name;
          this._value = value;
          const cls = value && value.__class__;
          if (cls) {
              // `_type` is only assigned for class instances and stays unset otherwise.
              this._type = cls.__module__ + '.' + cls.__name__;
          }
      }

      /** @returns {string} The attribute name. */
      get name() {
          return this._name;
      }

      /** @returns {*} The attribute value. */
      get value() {
          return this._value;
      }

      /** @returns {string|undefined} Qualified class name of the value, if any. */
      get type() {
          return this._type;
      }
  };
  /**
   * Error type used for problems found while loading a pickle model.
   */
  pickle.Error = class extends Error {

      /**
       * @param {string} message - Description of the problem.
       */
      constructor(message) {
          super(message);
          // Fixed `name` value assigned to every error of this type.
          this.name = 'Error loading Pickle model.';
      }
  };
  // Export the factory only when a CommonJS-style `module.exports` object is present.
  if (typeof module !== 'undefined') {
      if (typeof module.exports === 'object') {
          module.exports.ModelFactory = pickle.ModelFactory;
      }
  }