lightgbm.js 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231
  1. var lightgbm = lightgbm || {};
  2. var python = python || require('./python');
  3. lightgbm.ModelFactory = class {
  4. match(context) {
  5. try {
  6. const stream = context.stream;
  7. const signature = [ 0x74, 0x72, 0x65, 0x65, 0x0A ];
  8. if (stream && stream.length >= signature.length && stream.peek(signature.length).every((value, index) => value === signature[index])) {
  9. return 'lightgbm.text';
  10. }
  11. }
  12. catch (err) {
  13. // continue regardless of error
  14. }
  15. const obj = context.open('pkl');
  16. if (obj && obj.__class__ && obj.__class__.__module__ && obj.__class__.__module__.startsWith('lightgbm.')) {
  17. return 'lightgbm.pickle';
  18. }
  19. return '';
  20. }
  21. open(context, match) {
  22. return new Promise((resolve, reject) => {
  23. try {
  24. let obj;
  25. let format;
  26. switch (match) {
  27. case 'lightgbm.pickle': {
  28. obj = context.open('pkl');
  29. format = 'LightGBM Pickle';
  30. break;
  31. }
  32. case 'lightgbm.text': {
  33. const stream = context.stream;
  34. const buffer = stream.peek();
  35. const decoder = new TextDecoder('utf-8');
  36. const model_str = decoder.decode(buffer);
  37. const execution = new python.Execution(null);
  38. obj = execution.invoke('lightgbm.basic.Booster', []);
  39. obj.LoadModelFromString(model_str);
  40. format = 'LightGBM';
  41. break;
  42. }
  43. default: {
  44. throw new lightgbm.Error("Unsupported LightGBM format '" + match + "'.");
  45. }
  46. }
  47. resolve(new lightgbm.Model(obj, format));
  48. }
  49. catch (err) {
  50. reject(err);
  51. }
  52. });
  53. }
  54. };
  55. lightgbm.Model = class {
  56. constructor(obj, format) {
  57. this._format = format + (obj && obj.version ? ' ' + obj.version : '');
  58. this._graphs = [ new lightgbm.Graph(obj) ];
  59. }
  60. get format() {
  61. return this._format;
  62. }
  63. get graphs() {
  64. return this._graphs;
  65. }
  66. };
  67. lightgbm.Graph = class {
  68. constructor(model) {
  69. this._inputs = [];
  70. this._outputs = [];
  71. this._nodes = [];
  72. const args = [];
  73. const feature_names = model.feature_names || [];
  74. for (let i = 0; i < feature_names.length; i++) {
  75. const name = feature_names[i];
  76. const info = model.feature_infos && i < model.feature_infos.length ? model.feature_infos[i] : null;
  77. const argument = new lightgbm.Argument(name, info);
  78. args.push(argument);
  79. if (feature_names.length < 1000) {
  80. this._inputs.push(new lightgbm.Parameter(name, [ argument ]));
  81. }
  82. }
  83. this._nodes.push(new lightgbm.Node(model, args));
  84. }
  85. get inputs() {
  86. return this._inputs;
  87. }
  88. get outputs() {
  89. return this._outputs;
  90. }
  91. get nodes() {
  92. return this._nodes;
  93. }
  94. };
  95. lightgbm.Parameter = class {
  96. constructor(name, args) {
  97. this._name = name;
  98. this._arguments = args;
  99. }
  100. get name() {
  101. return this._name;
  102. }
  103. get visible() {
  104. return true;
  105. }
  106. get arguments() {
  107. return this._arguments;
  108. }
  109. };
  110. lightgbm.Argument = class {
  111. constructor(name, quantization) {
  112. if (typeof name !== 'string') {
  113. throw new lightgbm.Error("Invalid argument identifier '" + JSON.stringify(name) + "'.");
  114. }
  115. this._name = name;
  116. this._quantization = quantization;
  117. }
  118. get name() {
  119. return this._name;
  120. }
  121. get type() {
  122. return null;
  123. }
  124. get quantization() {
  125. return this._quantization;
  126. }
  127. get initializer() {
  128. return null;
  129. }
  130. };
  131. lightgbm.Node = class {
  132. constructor(model, args) {
  133. const type = model.__class__.__module__ + '.' + model.__class__.__name__;
  134. this._type = { name: type };
  135. this._inputs = [];
  136. this._outputs = [];
  137. this._attributes = [];
  138. this._inputs.push(new lightgbm.Parameter('features', args));
  139. for (const entry of Object.entries(model)) {
  140. const key = entry[0];
  141. const value = entry[1];
  142. if (value === undefined) {
  143. continue;
  144. }
  145. switch (key) {
  146. case 'tree':
  147. case 'version':
  148. case 'feature_names':
  149. case 'feature_infos':
  150. break;
  151. default:
  152. this._attributes.push(new lightgbm.Attribute(key, value));
  153. }
  154. }
  155. }
  156. get type() {
  157. return this._type;
  158. }
  159. get name() {
  160. return '';
  161. }
  162. get inputs() {
  163. return this._inputs;
  164. }
  165. get outputs() {
  166. return this._outputs;
  167. }
  168. get attributes() {
  169. return this._attributes;
  170. }
  171. };
  172. lightgbm.Attribute = class {
  173. constructor(name, value) {
  174. this._name = name;
  175. this._value = value;
  176. }
  177. get name() {
  178. return this._name;
  179. }
  180. get value() {
  181. return this._value;
  182. }
  183. };
  184. lightgbm.Error = class extends Error {
  185. constructor(message) {
  186. super(message);
  187. this.name = 'Error loading LightGBM model.';
  188. }
  189. };
  190. if (typeof module !== 'undefined' && typeof module.exports === 'object') {
  191. module.exports.ModelFactory = lightgbm.ModelFactory;
  192. }