lightgbm.js 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226
  // Module namespace: every class in this file (ModelFactory, Model, Graph,
  // Parameter, Argument, Node, Attribute, Error) is attached as a property of it.
  // NOTE(review): kept as `var` — presumably so the file can also be loaded as a
  // classic script without clashing with other files' top-level bindings;
  // confirm how the host loads modules before switching to `const`.
  1. var lightgbm = {};
  2. var python = require('./python');
  /**
   * Entry point used by the host to detect and load LightGBM models.
   */
  lightgbm.ModelFactory = class {

    /**
     * Detects whether `context` holds a LightGBM model.
     * @param {object} context - host context; `stream` (may be absent) must
     *   provide `length` and `peek(n)`, and `open('pkl')` presumably returns the
     *   content decoded as a pickle (or a falsy value when it is not one).
     * @returns {string|null} 'lightgbm.text' when the stream starts with the
     *   bytes "tree\n", 'lightgbm.pickle' when the decoded object's
     *   `__class__.__module__` starts with 'lightgbm.', otherwise null.
     */
    match(context) {
      const stream = context.stream;
      // ASCII "tree\n" — the header this importer uses to recognize text models.
      const signature = [ 0x74, 0x72, 0x65, 0x65, 0x0A ];
      if (stream && stream.length >= signature.length) {
        const head = stream.peek(signature.length);
        if (head.every((value, index) => value === signature[index])) {
          return 'lightgbm.text';
        }
      }
      const obj = context.open('pkl');
      const moduleName = obj && obj.__class__ && obj.__class__.__module__;
      if (moduleName && moduleName.startsWith('lightgbm.')) {
        return 'lightgbm.pickle';
      }
      return null;
    }

    /**
     * Loads the model previously recognized by match().
     * @param {object} context - the same host context passed to match().
     * @param {string} match - identifier returned by match().
     * @returns {Promise<lightgbm.Model>} resolves with the wrapped model; rejects
     *   with lightgbm.Error for an unknown identifier, or with whatever error the
     *   loading code throws.
     */
    async open(context, match) {
      // `async` turns any synchronous throw below into a rejected promise, which
      // replaces the hand-written `new Promise` + try/catch/reject wrapper.
      switch (match) {
        case 'lightgbm.pickle': {
          const obj = context.open('pkl');
          return new lightgbm.Model(obj, 'LightGBM Pickle');
        }
        case 'lightgbm.text': {
          const buffer = context.stream.peek();
          const modelText = new TextDecoder('utf-8').decode(buffer);
          // Instantiate `lightgbm.basic.Booster` through python.Execution and let
          // it parse the decoded text.
          const execution = new python.Execution();
          const booster = execution.invoke('lightgbm.basic.Booster', []);
          booster.LoadModelFromString(modelText);
          return new lightgbm.Model(booster, 'LightGBM');
        }
        default: {
          throw new lightgbm.Error(`Unsupported LightGBM format '${match}'.`);
        }
      }
    }
  };
  /**
   * Top-level model wrapper: a format label plus exactly one graph.
   */
  lightgbm.Model = class {

    /**
     * @param {object} obj - loaded Booster or unpickled LightGBM object.
     * @param {string} format - base format label; ' <version>' is appended
     *   when `obj.version` is truthy.
     */
    constructor(obj, format) {
      const version = obj && obj.version;
      const suffix = version ? ' ' + version : '';
      this._format = format + suffix;
      this._graphs = [ new lightgbm.Graph(obj) ];
    }

    get graphs() {
      return this._graphs;
    }

    get format() {
      return this._format;
    }
  };
  /**
   * Graph for a LightGBM model: a single node wrapping the model, fed by one
   * argument per feature name.
   */
  lightgbm.Graph = class {

    /**
     * @param {object} model - object with optional `feature_names` and
     *   `feature_infos` sequences, read at matching indices.
     */
    constructor(model) {
      this._inputs = [];
      this._outputs = [];
      this._nodes = [];
      const names = model.feature_names || [];
      const infos = model.feature_infos;
      // With 1000 or more features the graph-level inputs stay empty; the node
      // below still receives every feature argument.
      const listInputs = names.length < 1000;
      const features = Array.from(names, (name, index) => {
        // feature_infos[index] (or null when missing) becomes the quantization.
        const info = infos && index < infos.length ? infos[index] : null;
        const argument = new lightgbm.Argument(name, info);
        if (listInputs) {
          this._inputs.push(new lightgbm.Parameter(name, [ argument ]));
        }
        return argument;
      });
      this._nodes.push(new lightgbm.Node(model, features));
    }

    get nodes() {
      return this._nodes;
    }

    get outputs() {
      return this._outputs;
    }

    get inputs() {
      return this._inputs;
    }
  };
  /**
   * Named, always-visible group of arguments (used for graph and node inputs).
   */
  lightgbm.Parameter = class {

    /**
     * @param {string} name - display name of the parameter.
     * @param {Array<lightgbm.Argument>} args - arguments carried by it.
     */
    constructor(name, args) {
      this._name = name;
      this._arguments = args;
    }

    get arguments() {
      return this._arguments;
    }

    get visible() {
      return true;
    }

    get name() {
      return this._name;
    }
  };
  /**
   * Named value flowing through the graph; here, one per model feature.
   */
  lightgbm.Argument = class {

    /**
     * @param {string} name - identifier; anything other than a string throws.
     * @param {*} quantization - feature info attached to this argument.
     * @throws {lightgbm.Error} when `name` is not a string.
     */
    constructor(name, quantization) {
      if (typeof name !== 'string') {
        const shown = JSON.stringify(name);
        throw new lightgbm.Error(`Invalid argument identifier '${shown}'.`);
      }
      this._name = name;
      this._quantization = quantization;
    }

    get initializer() {
      return null;
    }

    get quantization() {
      return this._quantization;
    }

    get type() {
      return null;
    }

    get name() {
      return this._name;
    }
  };
  /**
   * Single node standing for the whole model: one 'features' input and the
   * model's remaining own enumerable properties as attributes.
   */
  lightgbm.Node = class {

    /**
     * @param {object} model - object exposing `__class__.__module__` and
     *   `__class__.__name__`, which form the node type name.
     * @param {Array<lightgbm.Argument>} args - feature arguments for the
     *   'features' input.
     */
    constructor(model, args) {
      const cls = model.__class__;
      this._type = { name: cls.__module__ + '.' + cls.__name__ };
      this._inputs = [ new lightgbm.Parameter('features', args) ];
      this._outputs = [];
      this._attributes = [];
      // Keys never listed as attributes; version, feature_names and
      // feature_infos are already surfaced by Model/Graph.
      const skipped = new Set([ 'tree', 'version', 'feature_names', 'feature_infos' ]);
      for (const [ key, value ] of Object.entries(model)) {
        if (value !== undefined && !skipped.has(key)) {
          this._attributes.push(new lightgbm.Attribute(key, value));
        }
      }
    }

    get attributes() {
      return this._attributes;
    }

    get outputs() {
      return this._outputs;
    }

    get inputs() {
      return this._inputs;
    }

    get name() {
      return '';
    }

    get type() {
      return this._type;
    }
  };
  /**
   * Name/value pair describing one model property on a node.
   */
  lightgbm.Attribute = class {

    /**
     * @param {string} name - property key.
     * @param {*} value - property value, stored as-is.
     */
    constructor(name, value) {
      Object.assign(this, { _name: name, _value: value });
    }

    get value() {
      return this._value;
    }

    get name() {
      return this._name;
    }
  };
  /**
   * Error raised by this importer; its `name` carries the user-facing
   * "Error loading LightGBM model." title.
   */
  lightgbm.Error = class extends Error {

    /**
     * @param {string} text - detail message.
     */
    constructor(text) {
      super(text);
      this.name = 'Error loading LightGBM model.';
    }
  };
  // When a CommonJS `module` object with an `exports` object is present,
  // publish the factory on it. Only ModelFactory is exported.
  if (typeof module !== 'undefined' && typeof module.exports === 'object') {
    const { ModelFactory } = lightgbm;
    module.exports.ModelFactory = ModelFactory;
  }