lightgbm.js 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320
  1. /* jshint esversion: 6 */
  2. var lightgbm = lightgbm || {};
  3. var base = base || require('./base');
  4. lightgbm.ModelFactory = class {
  5. match(context) {
  6. try {
  7. const stream = context.stream;
  8. const reader = base.TextReader.open(stream.peek(), 65536);
  9. const line = reader.read();
  10. if (line === 'tree') {
  11. return true;
  12. }
  13. }
  14. catch (err) {
  15. // continue regardless of error
  16. }
  17. const obj = context.open('pkl');
  18. if (obj && obj.__class__ && obj.__class__.__module__ && obj.__class__.__module__.startsWith('lightgbm.')) {
  19. return true;
  20. }
  21. return false;
  22. }
  23. open(context) {
  24. return new Promise((resolve, reject) => {
  25. try {
  26. let model;
  27. let format;
  28. const obj = context.open('pkl');
  29. if (obj) {
  30. format = 'LightGBM Pickle';
  31. model = obj;
  32. if (model && model.handle && typeof model.handle === 'string') {
  33. const reader = base.TextReader.open(model.handle);
  34. model = new lightgbm.basic.Booster(reader);
  35. }
  36. }
  37. else {
  38. format = 'LightGBM';
  39. const stream = context.stream;
  40. const buffer = stream.peek();
  41. const reader = base.TextReader.open(buffer);
  42. model = new lightgbm.basic.Booster(reader);
  43. }
  44. resolve(new lightgbm.Model(model, format));
  45. }
  46. catch (err) {
  47. reject(err);
  48. }
  49. });
  50. }
  51. };
  52. lightgbm.Model = class {
  53. constructor(model, format) {
  54. this._format = format + (model.meta && model.meta.version ? ' ' + model.meta.version : '');
  55. this._graphs = [ new lightgbm.Graph(model) ];
  56. }
  57. get format() {
  58. return this._format;
  59. }
  60. get graphs() {
  61. return this._graphs;
  62. }
  63. };
  64. lightgbm.Graph = class {
  65. constructor(model) {
  66. this._inputs = [];
  67. this._outputs = [];
  68. this._nodes = [];
  69. const args = [];
  70. if (model.meta && model.meta.feature_names) {
  71. const feature_names = model.meta.feature_names.split(' ').map((item) => item.trim());
  72. for (const feature_name of feature_names) {
  73. const arg = new lightgbm.Argument(feature_name);
  74. args.push(arg);
  75. if (feature_names.length < 1000) {
  76. this._inputs.push(new lightgbm.Parameter(feature_name, [ arg ]));
  77. }
  78. }
  79. }
  80. this._nodes.push(new lightgbm.Node(model, args));
  81. }
  82. get inputs() {
  83. return this._inputs;
  84. }
  85. get outputs() {
  86. return this._outputs;
  87. }
  88. get nodes() {
  89. return this._nodes;
  90. }
  91. };
  92. lightgbm.Parameter = class {
  93. constructor(name, args) {
  94. this._name = name;
  95. this._arguments = args;
  96. }
  97. get name() {
  98. return this._name;
  99. }
  100. get visible() {
  101. return true;
  102. }
  103. get arguments() {
  104. return this._arguments;
  105. }
  106. };
  107. lightgbm.Argument = class {
  108. constructor(name) {
  109. if (typeof name !== 'string') {
  110. throw new lightgbm.Error("Invalid argument identifier '" + JSON.stringify(name) + "'.");
  111. }
  112. this._name = name;
  113. }
  114. get name() {
  115. return this._name;
  116. }
  117. get type() {
  118. return null;
  119. }
  120. get initializer() {
  121. return null;
  122. }
  123. };
  124. lightgbm.Node = class {
  125. constructor(model, args) {
  126. this._type = model.__class__.__module__ + '.' + model.__class__.__name__;
  127. this._inputs = [];
  128. this._outputs = [];
  129. this._attributes = [];
  130. this._inputs.push(new lightgbm.Parameter('features', args));
  131. for (const key of Object.keys(model.params)) {
  132. this._attributes.push(new lightgbm.Attribute(key, model.params[key]));
  133. }
  134. }
  135. get type() {
  136. return this._type;
  137. }
  138. get name() {
  139. return '';
  140. }
  141. get inputs() {
  142. return this._inputs;
  143. }
  144. get outputs() {
  145. return this._outputs;
  146. }
  147. get attributes() {
  148. return this._attributes;
  149. }
  150. };
  151. lightgbm.Attribute = class {
  152. constructor(name, value) {
  153. this._name = name;
  154. this._value = value;
  155. }
  156. get name() {
  157. return this._name;
  158. }
  159. get value() {
  160. return this._value;
  161. }
  162. };
  163. lightgbm.basic = {};
  164. lightgbm.basic.Booster = class {
  165. constructor(reader) {
  166. this.__class__ = {
  167. __module__: 'lightgbm.basic',
  168. __name__: 'Booster'
  169. };
  170. this.params = {};
  171. this.feature_importances = {};
  172. this.meta = {};
  173. this.trees = [];
  174. // GBDT::LoadModelFromString() in https://github.com/microsoft/LightGBM/blob/master/src/boosting/gbdt_model_text.cpp
  175. const signature = reader.read();
  176. if (!signature || signature.trim() !== 'tree') {
  177. throw new lightgbm.Error("Invalid signature '" + signature.trim() + "'.");
  178. }
  179. let state = '';
  180. let tree = null;
  181. // let lineNumber = 0;
  182. for (;;) {
  183. // lineNumber++;
  184. const text = reader.read();
  185. if (text === undefined) {
  186. break;
  187. }
  188. const line = text.trim();
  189. if (line.length === 0) {
  190. continue;
  191. }
  192. if (line.startsWith('Tree=')) {
  193. state = 'tree';
  194. tree = { index: parseInt(line.split('=').pop(), 10) };
  195. this.trees.push(tree);
  196. continue;
  197. }
  198. else if (line === 'parameters:') {
  199. state = 'param';
  200. continue;
  201. }
  202. else if (line === 'feature_importances:' || line === 'feature importances:') {
  203. state = 'feature_importances';
  204. continue;
  205. }
  206. else if (line === 'end of trees' || line === 'end of parameters') {
  207. state = '';
  208. continue;
  209. }
  210. else if (line.startsWith('pandas_categorical:')) {
  211. state = 'pandas_categorical';
  212. continue;
  213. }
  214. switch (state) {
  215. case '': {
  216. const param = line.split('=');
  217. if (param.length !== 2 && !/^[A-Za-z0-9_]/.exec(param[0].trim())) {
  218. throw new lightgbm.Error("Invalid property '" + line + "'.");
  219. }
  220. const name = param[0].trim();
  221. const value = param.length > 1 ? param[1].trim() : undefined;
  222. this.meta[name] = value;
  223. break;
  224. }
  225. case 'param': {
  226. if (!line.startsWith('[') || !line.endsWith(']')) {
  227. throw new lightgbm.Error("Invalid parameter '" + line + "'.");
  228. }
  229. const param = line.substring(1, line.length - 2).split(':');
  230. if (param.length !== 2) {
  231. throw new lightgbm.Error("Invalid param '" + line + "'.");
  232. }
  233. const name = param[0].trim();
  234. const value = param[1].trim();
  235. this.params[name] = value;
  236. break;
  237. }
  238. case 'tree': {
  239. const param = line.split('=');
  240. if (param.length !== 2) {
  241. throw new lightgbm.Error("Invalid property '" + line + "'.");
  242. }
  243. const name = param[0].trim();
  244. const value = param[1].trim();
  245. tree[name] = value;
  246. break;
  247. }
  248. case 'feature_importances': {
  249. const param = line.split('=');
  250. if (param.length !== 2) {
  251. throw new lightgbm.Error("Invalid feature importance '" + line + "'.");
  252. }
  253. const name = param[0].trim();
  254. const value = param[1].trim();
  255. this.feature_importances[name] = value;
  256. break;
  257. }
  258. case 'pandas_categorical': {
  259. break;
  260. }
  261. }
  262. }
  263. }
  264. };
  265. lightgbm.Error = class extends Error {
  266. constructor(message) {
  267. super(message);
  268. this.name = 'Error loading LightGBM model.';
  269. }
  270. };
  271. if (typeof module !== 'undefined' && typeof module.exports === 'object') {
  272. module.exports.ModelFactory = lightgbm.ModelFactory;
  273. }