|
|
@@ -1,16 +1,15 @@
|
|
|
/* jshint esversion: 6 */
|
|
|
|
|
|
var lightgbm = lightgbm || {};
|
|
|
-var base = base || require('./base');
|
|
|
+var python = python || require('./python');
|
|
|
|
|
|
lightgbm.ModelFactory = class {
|
|
|
|
|
|
match(context) {
|
|
|
try {
|
|
|
const stream = context.stream;
|
|
|
- const reader = base.TextReader.open(stream.peek(), 65536);
|
|
|
- const line = reader.read();
|
|
|
- if (line === 'tree') {
|
|
|
+ const signature = [ 0x74, 0x72, 0x65, 0x65, 0x0A ];
|
|
|
+ if (stream.length >= signature.length && stream.peek(signature.length).every((value, index) => value === signature[index])) {
|
|
|
return 'lightgbm.text';
|
|
|
}
|
|
|
}
|
|
|
@@ -27,28 +26,26 @@ lightgbm.ModelFactory = class {
|
|
|
open(context, match) {
|
|
|
return new Promise((resolve, reject) => {
|
|
|
try {
|
|
|
- let model;
|
|
|
+ let obj;
|
|
|
let format;
|
|
|
switch (match) {
|
|
|
case 'lightgbm.pickle': {
|
|
|
+ obj = context.open('pkl');
|
|
|
format = 'LightGBM Pickle';
|
|
|
- const obj = context.open('pkl');
|
|
|
- model = obj;
|
|
|
- if (model && model.handle && typeof model.handle === 'string') {
|
|
|
- const reader = base.TextReader.open(model.handle);
|
|
|
- model = new lightgbm.basic.Booster(reader);
|
|
|
- }
|
|
|
break;
|
|
|
}
|
|
|
case 'lightgbm.text': {
|
|
|
- format = 'LightGBM';
|
|
|
const stream = context.stream;
|
|
|
const buffer = stream.peek();
|
|
|
- const reader = base.TextReader.open(buffer);
|
|
|
- model = new lightgbm.basic.Booster(reader);
|
|
|
+ const decoder = new TextDecoder('utf-8');
|
|
|
+ const model_str = decoder.decode(buffer);
|
|
|
+ const execution = new python.Execution(null);
|
|
|
+ obj = execution.invoke('lightgbm.basic.Booster', []);
|
|
|
+ obj.LoadModelFromString(model_str);
|
|
|
+ format = 'LightGBM';
|
|
|
}
|
|
|
}
|
|
|
- resolve(new lightgbm.Model(model, format));
|
|
|
+ resolve(new lightgbm.Model(obj, format));
|
|
|
}
|
|
|
catch (err) {
|
|
|
reject(err);
|
|
|
@@ -59,9 +56,9 @@ lightgbm.ModelFactory = class {
|
|
|
|
|
|
lightgbm.Model = class {
|
|
|
|
|
|
- constructor(model, format) {
|
|
|
- this._format = format + (model.meta && model.meta.version ? ' ' + model.meta.version : '');
|
|
|
- this._graphs = [ new lightgbm.Graph(model) ];
|
|
|
+ constructor(obj, format) {
|
|
|
+ this._format = format + (obj && obj.version ? ' ' + obj.version : '');
|
|
|
+ this._graphs = [ new lightgbm.Graph(obj) ];
|
|
|
}
|
|
|
|
|
|
get format() {
|
|
|
@@ -81,17 +78,16 @@ lightgbm.Graph = class {
|
|
|
this._nodes = [];
|
|
|
|
|
|
const args = [];
|
|
|
- if (model.meta && model.meta.feature_names) {
|
|
|
- const feature_names = model.meta.feature_names.split(' ').map((item) => item.trim());
|
|
|
- for (const feature_name of feature_names) {
|
|
|
- const arg = new lightgbm.Argument(feature_name);
|
|
|
- args.push(arg);
|
|
|
- if (feature_names.length < 1000) {
|
|
|
- this._inputs.push(new lightgbm.Parameter(feature_name, [ arg ]));
|
|
|
- }
|
|
|
+ const feature_names = model.feature_names || [];
|
|
|
+ for (let i = 0; i < feature_names.length; i++) {
|
|
|
+ const name = feature_names[i];
|
|
|
+ const info = model.feature_infos && i < model.feature_infos.length ? model.feature_infos[i] : null;
|
|
|
+ const argument = new lightgbm.Argument(name, info);
|
|
|
+ args.push(argument);
|
|
|
+ if (feature_names.length < 1000) {
|
|
|
+ this._inputs.push(new lightgbm.Parameter(name, [ argument ]));
|
|
|
}
|
|
|
}
|
|
|
-
|
|
|
this._nodes.push(new lightgbm.Node(model, args));
|
|
|
}
|
|
|
|
|
|
@@ -130,11 +126,12 @@ lightgbm.Parameter = class {
|
|
|
|
|
|
lightgbm.Argument = class {
|
|
|
|
|
|
- constructor(name) {
|
|
|
+ constructor(name, quantization) {
|
|
|
if (typeof name !== 'string') {
|
|
|
throw new lightgbm.Error("Invalid argument identifier '" + JSON.stringify(name) + "'.");
|
|
|
}
|
|
|
this._name = name;
|
|
|
+ this._quantization = quantization;
|
|
|
}
|
|
|
|
|
|
get name() {
|
|
|
@@ -145,6 +142,10 @@ lightgbm.Argument = class {
|
|
|
return null;
|
|
|
}
|
|
|
|
|
|
+ get quantization() {
|
|
|
+ return this._quantization;
|
|
|
+ }
|
|
|
+
|
|
|
get initializer() {
|
|
|
return null;
|
|
|
}
|
|
|
@@ -159,8 +160,21 @@ lightgbm.Node = class {
|
|
|
this._outputs = [];
|
|
|
this._attributes = [];
|
|
|
this._inputs.push(new lightgbm.Parameter('features', args));
|
|
|
- for (const key of Object.keys(model.params)) {
|
|
|
- this._attributes.push(new lightgbm.Attribute(key, model.params[key]));
|
|
|
+ for (const entry of Object.entries(model)) {
|
|
|
+ const key = entry[0];
|
|
|
+ const value = entry[1];
|
|
|
+ if (value === undefined) {
|
|
|
+ continue;
|
|
|
+ }
|
|
|
+ switch (key) {
|
|
|
+ case 'tree':
|
|
|
+ case 'version':
|
|
|
+ case 'feature_names':
|
|
|
+ case 'feature_infos':
|
|
|
+ break;
|
|
|
+ default:
|
|
|
+ this._attributes.push(new lightgbm.Attribute(key, value));
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
|
|
|
@@ -201,114 +215,6 @@ lightgbm.Attribute = class {
|
|
|
}
|
|
|
};
|
|
|
|
|
|
-lightgbm.basic = {};
|
|
|
-
|
|
|
-lightgbm.basic.Booster = class {
|
|
|
-
|
|
|
- constructor(reader) {
|
|
|
-
|
|
|
- this.__class__ = {
|
|
|
- __module__: 'lightgbm.basic',
|
|
|
- __name__: 'Booster'
|
|
|
- };
|
|
|
-
|
|
|
- this.params = {};
|
|
|
- this.feature_importances = {};
|
|
|
- this.meta = {};
|
|
|
- this.trees = [];
|
|
|
-
|
|
|
- // GBDT::LoadModelFromString() in https://github.com/microsoft/LightGBM/blob/master/src/boosting/gbdt_model_text.cpp
|
|
|
- const signature = reader.read();
|
|
|
- if (!signature || signature.trim() !== 'tree') {
|
|
|
- throw new lightgbm.Error("Invalid signature '" + signature.trim() + "'.");
|
|
|
- }
|
|
|
- let state = '';
|
|
|
- let tree = null;
|
|
|
- // let lineNumber = 0;
|
|
|
- for (;;) {
|
|
|
- // lineNumber++;
|
|
|
- const text = reader.read();
|
|
|
- if (text === undefined) {
|
|
|
- break;
|
|
|
- }
|
|
|
- const line = text.trim();
|
|
|
- if (line.length === 0) {
|
|
|
- continue;
|
|
|
- }
|
|
|
- if (line.startsWith('Tree=')) {
|
|
|
- state = 'tree';
|
|
|
- tree = { index: parseInt(line.split('=').pop(), 10) };
|
|
|
- this.trees.push(tree);
|
|
|
- continue;
|
|
|
- }
|
|
|
- else if (line === 'parameters:') {
|
|
|
- state = 'param';
|
|
|
- continue;
|
|
|
- }
|
|
|
- else if (line === 'feature_importances:' || line === 'feature importances:') {
|
|
|
- state = 'feature_importances';
|
|
|
- continue;
|
|
|
- }
|
|
|
- else if (line === 'end of trees' || line === 'end of parameters') {
|
|
|
- state = '';
|
|
|
- continue;
|
|
|
- }
|
|
|
- else if (line.startsWith('pandas_categorical:')) {
|
|
|
- state = 'pandas_categorical';
|
|
|
- continue;
|
|
|
- }
|
|
|
- switch (state) {
|
|
|
- case '': {
|
|
|
- const param = line.split('=');
|
|
|
- if (param.length !== 2 && !/^[A-Za-z0-9_]/.exec(param[0].trim())) {
|
|
|
- throw new lightgbm.Error("Invalid property '" + line + "'.");
|
|
|
- }
|
|
|
- const name = param[0].trim();
|
|
|
- const value = param.length > 1 ? param[1].trim() : undefined;
|
|
|
- this.meta[name] = value;
|
|
|
- break;
|
|
|
- }
|
|
|
- case 'param': {
|
|
|
- if (!line.startsWith('[') || !line.endsWith(']')) {
|
|
|
- throw new lightgbm.Error("Invalid parameter '" + line + "'.");
|
|
|
- }
|
|
|
- const param = line.substring(1, line.length - 2).split(':');
|
|
|
- if (param.length !== 2) {
|
|
|
- throw new lightgbm.Error("Invalid param '" + line + "'.");
|
|
|
- }
|
|
|
- const name = param[0].trim();
|
|
|
- const value = param[1].trim();
|
|
|
- this.params[name] = value;
|
|
|
- break;
|
|
|
- }
|
|
|
- case 'tree': {
|
|
|
- const param = line.split('=');
|
|
|
- if (param.length !== 2) {
|
|
|
- throw new lightgbm.Error("Invalid property '" + line + "'.");
|
|
|
- }
|
|
|
- const name = param[0].trim();
|
|
|
- const value = param[1].trim();
|
|
|
- tree[name] = value;
|
|
|
- break;
|
|
|
- }
|
|
|
- case 'feature_importances': {
|
|
|
- const param = line.split('=');
|
|
|
- if (param.length !== 2) {
|
|
|
- throw new lightgbm.Error("Invalid feature importance '" + line + "'.");
|
|
|
- }
|
|
|
- const name = param[0].trim();
|
|
|
- const value = param[1].trim();
|
|
|
- this.feature_importances[name] = value;
|
|
|
- break;
|
|
|
- }
|
|
|
- case 'pandas_categorical': {
|
|
|
- break;
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
-};
|
|
|
-
|
|
|
lightgbm.Error = class extends Error {
|
|
|
|
|
|
constructor(message) {
|