lightgbm.js 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. import * as python from './python.js';
  // Namespace object; every LightGBM loader class in this module is attached to it.
  const lightgbm = {};
  // Entry point used by the host application to recognise and open LightGBM models.
  lightgbm.ModelFactory = class {

    // Detects LightGBM input and records the result on `context`:
    //  - text models: the stream begins with the bytes "tree\n"
    //    -> context.type = 'lightgbm.text'
    //  - pickles: context.peek('pkl') yields an object whose
    //    __class__.__module__ starts with 'lightgbm.'
    //    -> context.type = 'lightgbm.pickle', context.target = that object
    // Any other input leaves `context` unchanged.
    match(context) {
      const { stream } = context;
      const textHeader = [0x74, 0x72, 0x65, 0x65, 0x0A]; // "tree\n"
      const isText = stream && stream.length >= textHeader.length &&
        stream.peek(textHeader.length).every((byte, i) => byte === textHeader[i]);
      if (isText) {
        context.type = 'lightgbm.text';
        return;
      }
      const pickled = context.peek('pkl');
      const moduleName = pickled && pickled.__class__ && pickled.__class__.__module__;
      if (moduleName && moduleName.startsWith('lightgbm.')) {
        context.type = 'lightgbm.pickle';
        context.target = pickled;
      }
    }

    // Builds a lightgbm.Model for the type chosen by match().
    // Throws lightgbm.Error for any other context.type.
    async open(context) {
      const { type } = context;
      if (type === 'lightgbm.pickle') {
        return new lightgbm.Model(context.target, 'LightGBM Pickle');
      }
      if (type === 'lightgbm.text') {
        // The whole stream is decoded as UTF-8 and handed to a Booster created
        // through the Python execution shim (expected to parse the text model).
        const text = new TextDecoder('utf-8').decode(context.stream.peek());
        const booster = new python.Execution().invoke('lightgbm.basic.Booster', []);
        booster.LoadModelFromString(text);
        return new lightgbm.Model(booster, 'LightGBM');
      }
      throw new lightgbm.Error(`Unsupported LightGBM format '${type}'.`);
    }
  };
  // Top-level model wrapper: a format label plus exactly one graph built from `obj`.
  lightgbm.Model = class {

    // obj:    the LightGBM object to display; a truthy `obj.version` is
    //         appended to the label after a single space.
    // format: base format label supplied by the caller.
    constructor(obj, format) {
      const version = obj && obj.version;
      const suffix = version ? ` ${version}` : '';
      this.format = format + suffix;
      this.graphs = [new lightgbm.Graph(obj)];
    }
  };
  // The single graph shown for a LightGBM object.
  lightgbm.Graph = class {

    // model: the object passed to lightgbm.Model. `model.feature_names`, when
    //        present, must be a list of strings (lightgbm.Value rejects
    //        anything else).
    //
    // Each feature name becomes a Value. All of them are wired into one Node
    // describing `model` through its 'features' input. A graph-level input per
    // feature is only added when there are fewer than 1000 features; otherwise
    // `this.inputs` stays empty.
    constructor(model) {
      // Feature count at or above which per-feature graph inputs are omitted.
      // NOTE(review): presumably keeps very wide models manageable — confirm.
      const maxFeatureInputs = 1000;
      this.inputs = [];
      this.outputs = [];
      this.nodes = [];
      const featureNames = model.feature_names || [];
      // Independent of the loop index, so decide once instead of per feature.
      const exposeInputs = featureNames.length < maxFeatureInputs;
      const values = [];
      for (let i = 0; i < featureNames.length; i++) {
        const name = featureNames[i];
        const value = new lightgbm.Value(name);
        values.push(value);
        if (exposeInputs) {
          this.inputs.push(new lightgbm.Argument(name, [value]));
        }
      }
      this.nodes.push(new lightgbm.Node(model, values));
    }
  };
  // Named connection on a graph or node. Every caller in this module passes
  // an array of lightgbm.Value objects as `value`.
  lightgbm.Argument = class {

    constructor(name, value) {
      Object.assign(this, { name, value });
    }
  };
  // A named value (one feature) flowing into a node.
  lightgbm.Value = class {

    // Throws lightgbm.Error unless `name` is a string.
    constructor(name) {
      if (typeof name === 'string') {
        this.name = name;
        return;
      }
      throw new lightgbm.Error(`Invalid value identifier '${JSON.stringify(name)}'.`);
    }
  };
  // Describes one LightGBM object, recursively expanding nested plain objects.
  lightgbm.Node = class {

    // obj:    the object to describe. Its type name is "<module>.<name>" taken
    //         from `obj.__class__` when present, otherwise 'builtins.object'.
    // values: optional array of lightgbm.Value, exposed as the 'features' input.
    // stack:  only passed by recursive calls here; the set of objects currently
    //         being expanded by enclosing Node constructors, used to stop
    //         infinite recursion on cyclic data.
    //
    // Each own enumerable property of `obj` (skipping undefined values and the
    // 'feature_names' / 'feature_infos' keys) becomes an attribute:
    //   - array whose items are all plain objects -> 'object[]' of nested Nodes
    //     (items already being expanded further up are dropped)
    //   - plain object not already being expanded -> 'object' with a nested Node
    //   - anything else                           -> untyped, raw value
    constructor(obj, values, stack) {
      const type = obj && obj.__class__ ? `${obj.__class__.__module__}.${obj.__class__.__name__}` : 'builtins.object';
      this.name = '';
      this.type = { name: type };
      this.inputs = [];
      this.outputs = [];
      this.attributes = [];
      if (values) {
        this.inputs.push(new lightgbm.Argument('features', values));
      }
      // "Plain" = object literal or Object.create(null); instances of other
      // classes are shown as raw attribute values instead of being expanded.
      const isObject = (item) => {
        if (item && typeof item === 'object') {
          const proto = Object.getPrototypeOf(item);
          return proto === Object.prototype || proto === null;
        }
        return false;
      };
      const visiting = stack || new Set();
      // Expands `child` into a nested Node while it is marked in-progress, so a
      // reference back to it from inside its own subtree is not expanded again.
      const expand = (child) => {
        visiting.add(child);
        const node = new lightgbm.Node(child, null, visiting);
        visiting.delete(child);
        return node;
      };
      // The type line above tolerates a falsy `obj`; do the same here instead
      // of letting Object.entries throw a TypeError.
      const entries = obj ? Object.entries(obj) : [];
      for (const [key, value] of entries) {
        if (value === undefined || key === 'feature_names' || key === 'feature_infos') {
          continue;
        }
        if (Array.isArray(value) && value.every((item) => isObject(item))) {
          const nodes = value.filter((item) => !visiting.has(item)).map((item) => expand(item));
          this.attributes.push(new lightgbm.Attribute('object[]', key, nodes));
        } else if (isObject(value) && !visiting.has(value)) {
          // Recurse into the property value, not `obj`: expanding `obj` again
          // would revisit this same key forever.
          this.attributes.push(new lightgbm.Attribute('object', key, expand(value)));
        } else {
          this.attributes.push(new lightgbm.Attribute(null, key, value));
        }
      }
    }
  };
  // A named property shown on a node. In this module `type` is 'object'
  // (value is a Node), 'object[]' (value is an array of Nodes), or null
  // (value is the raw property value).
  lightgbm.Attribute = class {

    constructor(type, name, value) {
      [this.type, this.name, this.value] = [type, name, value];
    }
  };
  // Error thrown by this module for unsupported formats or invalid identifiers.
  // `name` is overwritten with a human-readable title instead of the class name.
  lightgbm.Error = class extends Error {

    constructor(message) {
      super(message);
      const title = 'Error loading LightGBM model.';
      this.name = title;
    }
  };
  // Public entry point of this module: the LightGBM model factory.
  export const { ModelFactory } = lightgbm;