sklearn.js 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291
  // Experimental
  // Namespace object: every class of this module is attached to it as a property.
  const sklearn = {};
  sklearn.ModelFactory = class {
    /**
     * Checks whether the pickled object of `context` comes from a supported library.
     *
     * Known class-name prefixes are tried in order. The first hit is recorded on the
     * context via `context.set` as:
     *   '<format>'      - the object itself is an instance,
     *   '<format>.list' - a non-empty array whose items are all instances,
     *   '<format>.map'  - a non-empty Map / object whose values are all instances.
     *
     * @param {object} context - exposes `peek('pkl')` and `set(type, value)`.
     * @returns {Promise<*>} whatever `context.set` returns, or null when nothing matched.
     */
    async match(context) {
      const obj = await context.peek('pkl');
      // True when `candidate` has a Python class whose '<module>.<name>' starts with `prefix`.
      const isFrom = (candidate, prefix) => {
        const cls = candidate && candidate.__class__;
        if (!cls || !cls.__module__ || !cls.__name__) {
          return false;
        }
        return `${cls.__module__}.${cls.__name__}`.startsWith(prefix);
      };
      const prefixes = [
        ['sklearn.', 'sklearn'],
        ['xgboost.sklearn.', 'sklearn'],
        ['lightgbm.sklearn.', 'sklearn'],
        ['scipy.', 'scipy'],
        ['hmmlearn.', 'hmmlearn']
      ];
      const isMap = obj instanceof Map;
      const isContainer = isMap || Object(obj) === obj;
      for (const [prefix, format] of prefixes) {
        const matches = (item) => isFrom(item, prefix);
        if (matches(obj)) {
          return context.set(format, obj);
        }
        if (Array.isArray(obj) && obj.length > 0 && obj.every(matches)) {
          return context.set(`${format}.list`, obj);
        }
        if (isContainer) {
          const values = isMap ? Array.from(obj.values()) : Object.values(obj);
          if (values.length > 0 && values.every(matches)) {
            return context.set(`${format}.map`, obj);
          }
        }
      }
      return null;
    }

    /**
     * Builds the model for a context accepted by `match`.
     *
     * @param {object} context - exposes `metadata(file)`, `type` and `value`
     *   (presumably the tag and object recorded by `match` through `context.set`).
     * @returns {Promise<sklearn.Model>} the constructed model.
     */
    async open(context) {
      const metadata = await context.metadata('sklearn-metadata.json');
      const { type, value } = context;
      return new sklearn.Model(metadata, type, value);
    }
  };
  sklearn.Model = class {
    /**
     * Wraps the matched object(s) into one sklearn.Module per estimator.
     *
     * @param {object} metadata - metadata passed through to every sklearn.Module.
     * @param {string} type - format tag as produced by ModelFactory.match: 'sklearn', 'scipy'
     *   or 'hmmlearn', optionally suffixed with '.list' or '.map'.
     * @param {*} obj - the matched object; an array of objects for '.list', a Map or plain
     *   object of objects for '.map'.
     * @throws {sklearn.Error} when `type` is not a supported format tag.
     */
    constructor(metadata, type, obj) {
      const formats = new Map([
        ['sklearn', 'scikit-learn'],
        ['scipy', 'SciPy'],
        ['hmmlearn', 'hmmlearn']
      ]);
      this.format = formats.get(type.split('.').shift());
      this.modules = [];
      // Normalize every container shape to [moduleName, object] pairs.
      let entries = null;
      switch (type) {
        case 'sklearn':
        case 'scipy':
        case 'hmmlearn':
          entries = [['', obj]];
          break;
        // ModelFactory.match emits '.list' / '.map' tags for every format, hmmlearn included.
        case 'sklearn.list':
        case 'scipy.list':
        case 'hmmlearn.list':
          entries = Array.from(obj, (item, index) => [index.toString(), item]);
          break;
        case 'sklearn.map':
        case 'scipy.map':
        case 'hmmlearn.map':
          entries = obj instanceof Map ? Array.from(obj) : Object.entries(obj);
          break;
        default:
          throw new sklearn.Error(`Unsupported scikit-learn format '${type}'.`);
      }
      const versions = [];
      for (const [name, value] of entries) {
        this.modules.push(new sklearn.Module(metadata, name, value));
        if (value._sklearn_version) {
          versions.push(` v${value._sklearn_version}`);
        }
      }
      // Only label the format with a version when every versioned estimator agrees on it.
      if (versions.length > 0 && versions.every((version) => version === versions[0])) {
        this.format += versions[0];
      }
    }
  };
  sklearn.Module = class {
    /**
     * A graph holding a single node built from `obj`; it has no graph-level inputs or outputs.
     *
     * @param {object} metadata - passed through to sklearn.Node.
     * @param {string} [name=''] - module name.
     * @param {object} [obj=null] - object the single node is built from.
     */
    constructor(metadata, name = '', obj = null) {
      this.name = name;
      this.nodes = [new sklearn.Node(metadata, '', obj)];
      this.inputs = [];
      this.outputs = [];
    }
  };
  sklearn.Argument = class {
    /**
     * A named input of a node.
     *
     * @param {string} name - argument name.
     * @param {*} value - argument payload (tensor, node(s), bytes or plain value).
     * @param {?string} [type=null] - payload kind, e.g. 'tensor', 'object[]', 'attribute'.
     * @param {boolean} [visible=true] - whether the argument should be shown.
     */
    constructor(name, value, type = null, visible = true) {
      Object.assign(this, { name, value, type, visible });
    }
  };
  sklearn.Value = class {
    /**
     * A named value; when an initializer is given, its type takes precedence over `type`.
     *
     * @param {string} name - identifier; must be a string.
     * @param {*} type - value type used when there is no initializer.
     * @param {?object} [initializer=null] - optional initializer providing `.type`.
     * @throws {sklearn.Error} when `name` is not a string.
     */
    constructor(name, type, initializer = null) {
      const isIdentifier = typeof name === 'string';
      if (!isIdentifier) {
        throw new sklearn.Error(`Invalid value identifier '${JSON.stringify(name)}'.`);
      }
      this.name = name;
      if (initializer) {
        this.type = initializer.type;
      } else {
        this.type = type;
      }
      this.initializer = initializer;
    }
  };
  sklearn.Node = class {
    /**
     * Builds a graph node from an unpickled object, turning each of its own fields
     * (except `__class__`) into an input argument: tensor(s), byte array, string list,
     * nested node(s) or a plain attribute.
     *
     * @param {object} metadata - exposes `type(name)` and `attribute(type, name)` schema lookups.
     * @param {string} name - node name; a falsy value becomes ''.
     * @param {object} obj - object to describe; its `__class__` ('<module>.<name>') selects the
     *   node type, objects without one are typed 'builtins.dict'.
     * @param {Set<*>} [stack] - objects currently being expanded on the recursion path. Nested
     *   objects already in it are not expanded again, so cyclic object graphs terminate.
     */
    constructor(metadata, name, obj, stack) {
      this.name = name || '';
      const type = obj.__class__ ? `${obj.__class__.__module__}.${obj.__class__.__name__}` : 'builtins.dict';
      this.type = metadata.type(type) || { name: type };
      this.inputs = [];
      this.outputs = [];
      const path = stack || new Set();
      // Plain object literal (prototype Object.prototype or null).
      const isObject = (value) => {
        if (value && typeof value === 'object') {
          const proto = Object.getPrototypeOf(value);
          return proto === Object.prototype || proto === null;
        }
        return false;
      };
      // Builds a child node while `value` is marked as being on the recursion path.
      // The shared path must be passed as the 4th argument; dropping it disables cycle detection.
      const expand = (value, childName = '') => {
        path.add(value);
        try {
          return new sklearn.Node(metadata, childName, value, path);
        } finally {
          path.delete(value);
        }
      };
      // True when `value` equals the schema default (element-wise for arrays).
      const isDefault = (value, defaultValue) => {
        if (!Array.isArray(value)) {
          return value === defaultValue;
        }
        if (Array.isArray(defaultValue)) {
          return value.length === defaultValue.length && value.every((item, index) => item === defaultValue[index]);
        }
        return value.every((item) => item === defaultValue);
      };
      // Maps one field of `obj` to an argument; the branch order decides precedence.
      const toArgument = (key, value) => {
        if (value && sklearn.Utility.isTensor(value)) {
          return new sklearn.Argument(key, new sklearn.Tensor(value), 'tensor');
        }
        if (Array.isArray(value) && value.length > 0 && value.every((item) => sklearn.Utility.isTensor(item))) {
          return new sklearn.Argument(key, value.map((item) => new sklearn.Tensor(item)), 'tensor[]');
        }
        if (sklearn.Utility.isType(value, 'builtins.bytearray')) {
          return new sklearn.Argument(key, Array.from(value), 'byte[]');
        }
        if (Array.isArray(value) && value.every((item) => typeof item === 'string')) {
          return new sklearn.Argument(key, value, 'string[]');
        }
        if (Array.isArray(value) && value.every((item) => typeof item === 'number')) {
          return new sklearn.Argument(key, value, 'attribute');
        }
        if (sklearn.Utility.isType(value, 'builtins.function') || sklearn.Utility.isType(value, 'builtins.type')) {
          // A class/function reference: described as a node whose __class__ is that reference.
          return new sklearn.Argument(key, new sklearn.Node(metadata, '', { __class__: value }, path), 'object');
        }
        if (sklearn.Utility.isType(value, 'builtins.list') && value.every((item) => Array.isArray(item) && item.length === 2 && typeof item[0] === 'string')) {
          // (name, object) pairs - presumably named steps such as a pipeline's; each becomes a named node.
          // Pairs whose object is already on the recursion path are skipped to break cycles.
          const nodes = value.filter(([, item]) => !path.has(item)).map(([itemName, item]) => expand(item, itemName));
          return new sklearn.Argument(key, nodes, 'object[]');
        }
        if (Array.isArray(value) && value.length > 0 && value.every((item) => item && (item.__class__ || item === Object(item)))) {
          const nodes = value.filter((item) => !path.has(item)).map((item) => expand(item));
          return new sklearn.Argument(key, nodes, 'object[]');
        }
        if (value && (value.__class__ || isObject(value)) && !path.has(value)) {
          return new sklearn.Argument(key, expand(value), 'object');
        }
        let argumentType = 'attribute';
        let visible = true;
        // The schema is looked up by this node's class name (`type`), not by the argument kind.
        const schema = metadata.attribute(type, key);
        if (schema) {
          if (schema.type) {
            argumentType = schema.type;
          }
          if (schema.visible === false || (schema.optional && value === null)) {
            visible = false;
          } else if (schema.default !== undefined) {
            visible = !isDefault(value, schema.default);
          }
        }
        return new sklearn.Argument(key, value, argumentType, visible);
      };
      if (type === 'builtins.bytearray') {
        this.inputs.push(new sklearn.Argument('value', Array.from(obj), 'byte[]'));
      } else {
        for (const [key, value] of Object.entries(obj)) {
          if (key !== '__class__') {
            this.inputs.push(toArgument(key, value));
          }
        }
      }
    }
  };
  sklearn.Tensor = class {
    /**
     * Describes a numpy-style array as a tensor.
     *
     * @param {object} array - array exposing `dtype` (`__name__`, `byteorder`), `shape`,
     *   `strides`, `itemsize`, `tobytes()` and `flatten().tolist()`.
     */
    constructor(array) {
      this.type = new sklearn.TensorType(array.dtype.__name__, new sklearn.TensorShape(array.shape));
      const { dataType } = this.type;
      // Strides are divided by the item size to express them in elements rather than bytes.
      const byteStrides = array.strides;
      this.stride = Array.isArray(byteStrides) ? byteStrides.map((bytes) => bytes / array.itemsize) : null;
      const isText = dataType === 'string' || dataType === 'object';
      this.encoding = isText ? '|' : array.dtype.byteorder;
      // String/object/void elements are kept as a flat list; everything else as raw bytes.
      this.values = isText || dataType === 'void' ? array.flatten().tolist() : array.tobytes();
    }
  };
  sklearn.TensorType = class {
    /**
     * @param {string} dataType - element type name.
     * @param {sklearn.TensorShape} shape - tensor dimensions.
     */
    constructor(dataType, shape) {
      Object.assign(this, { dataType, shape });
    }

    /** @returns {string} the data type followed by the shape text, e.g. 'float32[2,3]'. */
    toString() {
      return `${this.dataType}${this.shape.toString()}`;
    }
  };
  sklearn.TensorShape = class {
    /** @param {?Array<*>} dimensions - dimension sizes, or a falsy value when unknown. */
    constructor(dimensions) {
      this.dimensions = dimensions;
    }

    /** @returns {string} '[d0,d1,...]', or '' when there are no dimensions. */
    toString() {
      if (!this.dimensions) {
        return '';
      }
      const parts = this.dimensions.map((dimension) => dimension.toString());
      return `[${parts.join(',')}]`;
    }
  };
  sklearn.Utility = class {
    /**
     * Tests whether `obj` is an instance of the Python class `name`.
     *
     * @param {*} obj - candidate; may be null/undefined or a primitive.
     * @param {string} name - fully qualified class name, e.g. 'numpy.ndarray'.
     * @returns {boolean} true when `obj.__class__` has a module and a name and
     *   '<module>.<name>' equals `name`; otherwise false (never another falsy value).
     */
    static isType(obj, name) {
      const cls = obj?.__class__;
      if (!cls || !cls.__module__ || !cls.__name__) {
        return false;
      }
      return `${cls.__module__}.${cls.__name__}` === name;
    }

    /**
     * Tests whether `obj` is a numpy array ('numpy.ndarray' or 'numpy.matrix').
     *
     * @param {*} obj - candidate object.
     * @returns {boolean}
     */
    static isTensor(obj) {
      return sklearn.Utility.isType(obj, 'numpy.ndarray') || sklearn.Utility.isType(obj, 'numpy.matrix');
    }
  };
  sklearn.Error = class extends Error {
    // Fixed label identifying failures raised by this loader.
    name = 'Error loading scikit-learn model.';

    /** @param {string} message - description of the failure. */
    constructor(message) {
      super(message);
    }
  };
  // Public entry point: the factory that detects and opens these pickled models.
  export const { ModelFactory } = sklearn;