xgboost.js 2.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859
  1. // Experimental
  2. import * as python from './python.js';
  3. const xgboost = {};
  4. xgboost.ModelFactory = class {
  5. async match(context) {
  6. const obj = await context.peek('json');
  7. if (obj && obj.learner && obj.version && Object.keys(obj).length < 256) {
  8. return context.set('xgboost.json', obj);
  9. }
  10. const stream = context.stream;
  11. if (stream && stream.length > 4) {
  12. const buffer = stream.peek(4);
  13. if (buffer[0] === 0x7B && buffer[1] === 0x4C && buffer[2] === 0x00 && buffer[3] === 0x00) {
  14. return context.set('xgboost.ubj', stream);
  15. }
  16. const signature = String.fromCharCode.apply(null, buffer);
  17. if (signature.startsWith('binf')) {
  18. return context.set('xgboost.binf', stream);
  19. }
  20. if (signature.startsWith('bs64')) {
  21. return context.set('xgboost.bs64', stream);
  22. }
  23. const reader = await context.read('text', 0x100);
  24. const line = reader.read('\n');
  25. if (line !== undefined && line.trim() === 'booster[0]:') {
  26. return context.set('xgboost.text', stream);
  27. }
  28. }
  29. return null;
  30. }
  31. async open(context) {
  32. if (context.type === 'xgboost.json') {
  33. const execution = new python.Execution();
  34. const model = execution.invoke('xgboost.core.Booster', []);
  35. model.load_model(context.value);
  36. throw new xgboost.Error('File contains unsupported XGBoost JSON data.');
  37. }
  38. if (context.type === 'xgboost.text') {
  39. throw new xgboost.Error('File contains unsupported XGBoost text data.');
  40. }
  41. throw new xgboost.Error('File contains unsupported XGBoost data.');
  42. }
  43. };
  44. xgboost.Error = class extends Error {
  45. constructor(message) {
  46. super(message);
  47. this.name = 'Error loading XGBoost model.';
  48. }
  49. };
  50. export const ModelFactory = xgboost.ModelFactory;