// lightgbm.js (5.0 KB)

  1. import * as python from './python.js';
  2. const lightgbm = {};
  3. lightgbm.ModelFactory = class {
  4. async match(context) {
  5. const stream = context.stream;
  6. const signature = [0x74, 0x72, 0x65, 0x65, 0x0A];
  7. if (stream && stream.length >= signature.length && stream.peek(signature.length).every((value, index) => value === signature[index])) {
  8. return context.set('lightgbm.text');
  9. }
  10. const obj = await context.peek('pkl');
  11. if (obj && obj.__class__ && obj.__class__.__module__ && obj.__class__.__module__.startsWith('lightgbm.')) {
  12. return context.set('lightgbm.pickle', obj);
  13. }
  14. return null;
  15. }
  16. async open(context) {
  17. switch (context.type) {
  18. case 'lightgbm.pickle': {
  19. const obj = context.value;
  20. return new lightgbm.Model(obj, 'LightGBM Pickle');
  21. }
  22. case 'lightgbm.text': {
  23. const stream = context.stream;
  24. const buffer = stream.peek();
  25. const decoder = new TextDecoder('utf-8');
  26. const model_str = decoder.decode(buffer);
  27. const execution = new python.Execution();
  28. const obj = execution.invoke('lightgbm.basic.Booster', []);
  29. obj.LoadModelFromString(model_str);
  30. return new lightgbm.Model(obj, 'LightGBM');
  31. }
  32. default: {
  33. throw new lightgbm.Error(`Unsupported LightGBM format '${context.type}'.`);
  34. }
  35. }
  36. }
  37. };
  38. lightgbm.Model = class {
  39. constructor(obj, format) {
  40. this.format = format + (obj && obj.version ? ` ${obj.version}` : '');
  41. this.modules = [new lightgbm.Graph(obj)];
  42. }
  43. };
  44. lightgbm.Graph = class {
  45. constructor(model) {
  46. this.inputs = [];
  47. this.outputs = [];
  48. this.nodes = [];
  49. const values = [];
  50. const feature_names = model.feature_names || [];
  51. for (let i = 0; i < feature_names.length; i++) {
  52. const name = feature_names[i];
  53. // const info = model.feature_infos && i < model.feature_infos.length ? model.feature_infos[i] : null;
  54. const value = new lightgbm.Value(name);
  55. values.push(value);
  56. if (feature_names.length < 1000) {
  57. const argument = new lightgbm.Argument(name, [value]);
  58. this.inputs.push(argument);
  59. }
  60. }
  61. const node = new lightgbm.Node(model, values);
  62. this.nodes.push(node);
  63. }
  64. };
  65. lightgbm.Argument = class {
  66. constructor(name, value, type = null) {
  67. this.name = name;
  68. this.value = value;
  69. this.type = type;
  70. }
  71. };
  72. lightgbm.Value = class {
  73. constructor(name) {
  74. if (typeof name !== 'string') {
  75. throw new lightgbm.Error(`Invalid value identifier '${JSON.stringify(name)}'.`);
  76. }
  77. this.name = name;
  78. }
  79. };
  80. lightgbm.Node = class {
  81. constructor(obj, values, stack) {
  82. const type = obj && obj.__class__ ? `${obj.__class__.__module__}.${obj.__class__.__name__}` : 'builtins.object';
  83. this.name = '';
  84. this.type = { name: type };
  85. this.inputs = [];
  86. this.outputs = [];
  87. this.attributes = [];
  88. if (values) {
  89. const argument = new lightgbm.Argument('features', values);
  90. this.inputs.push(argument);
  91. }
  92. const isObject = (obj) => {
  93. if (obj && typeof obj === 'object') {
  94. const proto = Object.getPrototypeOf(obj);
  95. return proto === Object.prototype || proto === null;
  96. }
  97. return false;
  98. };
  99. stack = stack || new Set();
  100. const entries = Object.entries(obj).filter(([key, value]) => value !== undefined && key !== 'feature_names' && key !== 'feature_infos');
  101. for (const [key, value] of entries) {
  102. if (Array.isArray(value) && value.every((obj) => isObject(obj))) {
  103. const values = value.filter((obj) => !stack.has(obj));
  104. const nodes = values.map((obj) => {
  105. stack.add(obj);
  106. const node = new lightgbm.Node(obj, null, stack);
  107. stack.delete(obj);
  108. return node;
  109. });
  110. const attribute = new lightgbm.Argument(key, nodes, 'object[]');
  111. this.attributes.push(attribute);
  112. continue;
  113. } else if (isObject(value) && !stack.has(value)) {
  114. stack.add(obj);
  115. const node = new lightgbm.Node(obj, null, stack);
  116. stack.delete(obj);
  117. const attribute = new lightgbm.Argument(key, node, 'object');
  118. this.attributes.push(attribute);
  119. } else {
  120. const attribute = new lightgbm.Argument(key, value);
  121. this.attributes.push(attribute);
  122. }
  123. }
  124. }
  125. };
  126. lightgbm.Error = class extends Error {
  127. constructor(message) {
  128. super(message);
  129. this.name = 'Error loading LightGBM model.';
  130. }
  131. };
  132. export const ModelFactory = lightgbm.ModelFactory;