|
@@ -17,14 +17,33 @@ lightgbm.ModelFactory = class {
|
|
|
catch (err) {
|
|
catch (err) {
|
|
|
// continue regardless of error
|
|
// continue regardless of error
|
|
|
}
|
|
}
|
|
|
|
|
+ const tags = context.tags('pkl');
|
|
|
|
|
+ if (tags.size === 1 && tags.keys().next().value.startsWith('lightgbm.')) {
|
|
|
|
|
+ return true;
|
|
|
|
|
+ }
|
|
|
return false;
|
|
return false;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
open(context) {
|
|
open(context) {
|
|
|
return new Promise((resolve, reject) => {
|
|
return new Promise((resolve, reject) => {
|
|
|
try {
|
|
try {
|
|
|
- const booster = new lightgbm.basic.Booster(context.stream);
|
|
|
|
|
- resolve(new lightgbm.Model(booster));
|
|
|
|
|
|
|
+ const tags = context.tags('pkl');
|
|
|
|
|
+ let model;
|
|
|
|
|
+ let format;
|
|
|
|
|
+ if (tags.size === 1) {
|
|
|
|
|
+ format = 'LightGBM Pickle';
|
|
|
|
|
+ model = tags.values().next().value;
|
|
|
|
|
+ if (model && model.handle && typeof model.handle === 'string') {
|
|
|
|
|
+ const reader = base.TextReader.create(model.handle);
|
|
|
|
|
+ model = new lightgbm.basic.Booster(reader);
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ else {
|
|
|
|
|
+ format = 'LightGBM';
|
|
|
|
|
+ const reader = base.TextReader.create(context.stream.peek());
|
|
|
|
|
+ model = new lightgbm.basic.Booster(reader);
|
|
|
|
|
+ }
|
|
|
|
|
+ resolve(new lightgbm.Model(model, format));
|
|
|
}
|
|
}
|
|
|
catch (err) {
|
|
catch (err) {
|
|
|
reject(err);
|
|
reject(err);
|
|
@@ -35,13 +54,13 @@ lightgbm.ModelFactory = class {
|
|
|
|
|
|
|
|
lightgbm.Model = class {
|
|
lightgbm.Model = class {
|
|
|
|
|
|
|
|
- constructor(model) {
|
|
|
|
|
- this._version = model.meta.version;
|
|
|
|
|
|
|
+ constructor(model, format) {
|
|
|
|
|
+ this._format = format + (model.meta && model.meta.version ? ' ' + model.meta.version : '');
|
|
|
this._graphs = [ new lightgbm.Graph(model) ];
|
|
this._graphs = [ new lightgbm.Graph(model) ];
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
get format() {
|
|
get format() {
|
|
|
- return 'LightGBM' + (this._version ? ' ' + this._version : '');
|
|
|
|
|
|
|
+ return this._format;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
get graphs() {
|
|
get graphs() {
|
|
@@ -57,7 +76,7 @@ lightgbm.Graph = class {
|
|
|
this._nodes = [];
|
|
this._nodes = [];
|
|
|
|
|
|
|
|
const args = [];
|
|
const args = [];
|
|
|
- if (model.meta.feature_names) {
|
|
|
|
|
|
|
+ if (model.meta && model.meta.feature_names) {
|
|
|
const feature_names = model.meta.feature_names.split(' ').map((item) => item.trim());
|
|
const feature_names = model.meta.feature_names.split(' ').map((item) => item.trim());
|
|
|
for (const feature_name of feature_names) {
|
|
for (const feature_name of feature_names) {
|
|
|
const arg = new lightgbm.Argument(feature_name);
|
|
const arg = new lightgbm.Argument(feature_name);
|
|
@@ -180,7 +199,7 @@ lightgbm.basic = {};
|
|
|
|
|
|
|
|
lightgbm.basic.Booster = class {
|
|
lightgbm.basic.Booster = class {
|
|
|
|
|
|
|
|
- constructor(stream) {
|
|
|
|
|
|
|
+ constructor(reader) {
|
|
|
|
|
|
|
|
this.__module__ = 'lightgbm.basic';
|
|
this.__module__ = 'lightgbm.basic';
|
|
|
this.__name__ = 'Booster';
|
|
this.__name__ = 'Booster';
|
|
@@ -191,8 +210,6 @@ lightgbm.basic.Booster = class {
|
|
|
this.trees = [];
|
|
this.trees = [];
|
|
|
|
|
|
|
|
// GBDT::LoadModelFromString() in https://github.com/microsoft/LightGBM/blob/master/src/boosting/gbdt_model_text.cpp
|
|
// GBDT::LoadModelFromString() in https://github.com/microsoft/LightGBM/blob/master/src/boosting/gbdt_model_text.cpp
|
|
|
- const reader = base.TextReader.create(stream.peek());
|
|
|
|
|
-
|
|
|
|
|
const signature = reader.read();
|
|
const signature = reader.read();
|
|
|
if (!signature || signature.trim() !== 'tree') {
|
|
if (!signature || signature.trim() !== 'tree') {
|
|
|
throw new lightgbm.Error("Invalid signature '" + signature.trim() + "'.");
|
|
throw new lightgbm.Error("Invalid signature '" + signature.trim() + "'.");
|
|
@@ -220,7 +237,7 @@ lightgbm.basic.Booster = class {
|
|
|
state = 'param';
|
|
state = 'param';
|
|
|
continue;
|
|
continue;
|
|
|
}
|
|
}
|
|
|
- else if (line === 'feature_importances:') {
|
|
|
|
|
|
|
+ else if (line === 'feature_importances:' || line === 'feature importances:') {
|
|
|
state = 'feature_importances';
|
|
state = 'feature_importances';
|
|
|
continue;
|
|
continue;
|
|
|
}
|
|
}
|