ソースを参照

Update LightGBM prototype (#669)

Lutz Roeder 4 年 前
コミット
10f6583f56
2 ファイル変更、169 行追加、139 行削除
  1. 44 138
      source/lightgbm.js
  2. 125 1
      source/python.js

+ 44 - 138
source/lightgbm.js

@@ -1,16 +1,15 @@
 /* jshint esversion: 6 */
 
 var lightgbm = lightgbm || {};
-var base = base || require('./base');
+var python = python || require('./python');
 
 lightgbm.ModelFactory = class {
 
     match(context) {
         try {
             const stream = context.stream;
-            const reader = base.TextReader.open(stream.peek(), 65536);
-            const line = reader.read();
-            if (line === 'tree') {
+            const signature = [ 0x74, 0x72, 0x65, 0x65, 0x0A ];
+            if (stream.length >= signature.length && stream.peek(signature.length).every((value, index) => value === signature[index])) {
                 return 'lightgbm.text';
             }
         }
@@ -27,28 +26,26 @@ lightgbm.ModelFactory = class {
     open(context, match) {
         return new Promise((resolve, reject) => {
             try {
-                let model;
+                let obj;
                 let format;
                 switch (match) {
                     case 'lightgbm.pickle': {
+                        obj = context.open('pkl');
                         format = 'LightGBM Pickle';
-                        const obj = context.open('pkl');
-                        model = obj;
-                        if (model && model.handle && typeof model.handle === 'string') {
-                            const reader = base.TextReader.open(model.handle);
-                            model = new lightgbm.basic.Booster(reader);
-                        }
                         break;
                     }
                     case 'lightgbm.text': {
-                        format = 'LightGBM';
                         const stream = context.stream;
                         const buffer = stream.peek();
-                        const reader = base.TextReader.open(buffer);
-                        model = new lightgbm.basic.Booster(reader);
+                        const decoder = new TextDecoder('utf-8');
+                        const model_str = decoder.decode(buffer);
+                        const execution = new python.Execution(null);
+                        obj = execution.invoke('lightgbm.basic.Booster', []);
+                        obj.LoadModelFromString(model_str);
+                        format = 'LightGBM';
                     }
                 }
-                resolve(new lightgbm.Model(model, format));
+                resolve(new lightgbm.Model(obj, format));
             }
             catch (err) {
                 reject(err);
@@ -59,9 +56,9 @@ lightgbm.ModelFactory = class {
 
 lightgbm.Model = class {
 
-    constructor(model, format) {
-        this._format = format + (model.meta && model.meta.version ? ' ' + model.meta.version : '');
-        this._graphs = [ new lightgbm.Graph(model) ];
+    constructor(obj, format) {
+        this._format = format + (obj && obj.version ? ' ' + obj.version : '');
+        this._graphs = [ new lightgbm.Graph(obj) ];
     }
 
     get format() {
@@ -81,17 +78,16 @@ lightgbm.Graph = class {
         this._nodes = [];
 
         const args = [];
-        if (model.meta && model.meta.feature_names) {
-            const feature_names = model.meta.feature_names.split(' ').map((item) => item.trim());
-            for (const feature_name of feature_names) {
-                const arg = new lightgbm.Argument(feature_name);
-                args.push(arg);
-                if (feature_names.length < 1000) {
-                    this._inputs.push(new lightgbm.Parameter(feature_name, [ arg ]));
-                }
+        const feature_names = model.feature_names || [];
+        for (let i = 0; i < feature_names.length; i++) {
+            const name = feature_names[i];
+            const info = model.feature_infos && i < model.feature_infos.length ? model.feature_infos[i] : null;
+            const argument = new lightgbm.Argument(name, info);
+            args.push(argument);
+            if (feature_names.length < 1000) {
+                this._inputs.push(new lightgbm.Parameter(name, [ argument ]));
             }
         }
-
         this._nodes.push(new lightgbm.Node(model, args));
     }
 
@@ -130,11 +126,12 @@ lightgbm.Parameter = class {
 
 lightgbm.Argument = class {
 
-    constructor(name) {
+    constructor(name, quantization) {
         if (typeof name !== 'string') {
             throw new lightgbm.Error("Invalid argument identifier '" + JSON.stringify(name) + "'.");
         }
         this._name = name;
+        this._quantization = quantization;
     }
 
     get name() {
@@ -145,6 +142,10 @@ lightgbm.Argument = class {
         return null;
     }
 
+    get quantization() {
+        return this._quantization;
+    }
+
     get initializer() {
         return null;
     }
@@ -159,8 +160,21 @@ lightgbm.Node = class {
         this._outputs = [];
         this._attributes = [];
         this._inputs.push(new lightgbm.Parameter('features', args));
-        for (const key of Object.keys(model.params)) {
-            this._attributes.push(new lightgbm.Attribute(key, model.params[key]));
+        for (const entry of Object.entries(model)) {
+            const key = entry[0];
+            const value = entry[1];
+            if (value === undefined) {
+                continue;
+            }
+            switch (key) {
+                case 'tree':
+                case 'version':
+                case 'feature_names':
+                case 'feature_infos':
+                    break;
+                default:
+                    this._attributes.push(new lightgbm.Attribute(key, value));
+            }
         }
     }
 
@@ -201,114 +215,6 @@ lightgbm.Attribute = class {
     }
 };
 
-lightgbm.basic = {};
-
-lightgbm.basic.Booster = class {
-
-    constructor(reader) {
-
-        this.__class__ = {
-            __module__: 'lightgbm.basic',
-            __name__: 'Booster'
-        };
-
-        this.params = {};
-        this.feature_importances = {};
-        this.meta = {};
-        this.trees = [];
-
-        // GBDT::LoadModelFromString() in https://github.com/microsoft/LightGBM/blob/master/src/boosting/gbdt_model_text.cpp
-        const signature = reader.read();
-        if (!signature || signature.trim() !== 'tree') {
-            throw new lightgbm.Error("Invalid signature '" + signature.trim() + "'.");
-        }
-        let state = '';
-        let tree = null;
-        // let lineNumber = 0;
-        for (;;) {
-            // lineNumber++;
-            const text = reader.read();
-            if (text === undefined) {
-                break;
-            }
-            const line = text.trim();
-            if (line.length === 0) {
-                continue;
-            }
-            if (line.startsWith('Tree=')) {
-                state = 'tree';
-                tree = { index: parseInt(line.split('=').pop(), 10) };
-                this.trees.push(tree);
-                continue;
-            }
-            else if (line === 'parameters:') {
-                state = 'param';
-                continue;
-            }
-            else if (line === 'feature_importances:' || line === 'feature importances:') {
-                state = 'feature_importances';
-                continue;
-            }
-            else if (line === 'end of trees' || line === 'end of parameters') {
-                state = '';
-                continue;
-            }
-            else if (line.startsWith('pandas_categorical:')) {
-                state = 'pandas_categorical';
-                continue;
-            }
-            switch (state) {
-                case '': {
-                    const param = line.split('=');
-                    if (param.length !== 2 && !/^[A-Za-z0-9_]/.exec(param[0].trim())) {
-                        throw new lightgbm.Error("Invalid property '" + line + "'.");
-                    }
-                    const name = param[0].trim();
-                    const value = param.length > 1 ? param[1].trim() : undefined;
-                    this.meta[name] = value;
-                    break;
-                }
-                case 'param': {
-                    if (!line.startsWith('[') || !line.endsWith(']')) {
-                        throw new lightgbm.Error("Invalid parameter '" + line + "'.");
-                    }
-                    const param = line.substring(1, line.length - 2).split(':');
-                    if (param.length !== 2) {
-                        throw new lightgbm.Error("Invalid param '" + line + "'.");
-                    }
-                    const name = param[0].trim();
-                    const value = param[1].trim();
-                    this.params[name] = value;
-                    break;
-                }
-                case 'tree': {
-                    const param = line.split('=');
-                    if (param.length !== 2) {
-                        throw new lightgbm.Error("Invalid property '" + line + "'.");
-                    }
-                    const name = param[0].trim();
-                    const value = param[1].trim();
-                    tree[name] = value;
-                    break;
-                }
-                case 'feature_importances': {
-                    const param = line.split('=');
-                    if (param.length !== 2) {
-                        throw new lightgbm.Error("Invalid feature importance '" + line + "'.");
-                    }
-                    const name = param[0].trim();
-                    const value = param[1].trim();
-                    this.feature_importances[name] = value;
-                    break;
-                }
-                case 'pandas_categorical': {
-                    break;
-                }
-            }
-        }
-    }
-};
-
 lightgbm.Error = class extends Error {
 
     constructor(message) {

+ 125 - 1
source/python.js

@@ -1842,7 +1842,131 @@ python.Execution = class {
         });
         this.registerType('lightgbm.sklearn.LGBMRegressor', class {});
         this.registerType('lightgbm.sklearn.LGBMClassifier', class {});
-        this.registerType('lightgbm.basic.Booster', class {});
+        this.registerType('lightgbm.basic.Booster', class {
+            constructor() {
+                this.average_output = false;
+                this.models = [];
+                this.loaded_parameter = '';
+            }
+            __setstate__(state) {
+                if (typeof state.handle === 'string') {
+                    this.LoadModelFromString(state.handle);
+                    return;
+                }
+                Object.assign(this, state);
+            }
+            LoadModelFromString(model_str) {
+                const lines = model_str.split('\n');
+                const signature = lines.shift() || '?';
+                if (signature.trim() !== 'tree') {
+                    throw new python.Error("Invalid signature '" + signature.trim() + "'.");
+                }
+                // GBDT::LoadModelFromString() in https://github.com/microsoft/LightGBM/blob/master/src/boosting/gbdt_model_text.cpp
+                const key_vals = new Map();
+                while (lines.length > 0 && !lines[0].startsWith('Tree=')) {
+                    const cur_line = lines.shift().trim();
+                    if (cur_line.length > 0) {
+                        const strs = cur_line.split('=');
+                        if (strs.length === 1) {
+                            key_vals.set(strs[0], '');
+                        }
+                        else if (strs.length === 2) {
+                            key_vals.set(strs[0], strs[1]);
+                        }
+                        else if (strs.length > 2) {
+                            if (strs[0] === "feature_names") {
+                                key_vals.set(strs[0], cur_line.substring("feature_names=".length));
+                            }
+                            else if (strs[0] == 'monotone_constraints') {
+                                key_vals.set(strs[0], cur_line.substring('monotone_constraints='.length));
+                            }
+                            else {
+                                throw new python.Error('Wrong line: ' + cur_line.substring(0, Math.min(128, cur_line.length)));
+                            }
+                        }
+                    }
+                }
+                const atoi = (key, value) => {
+                    if (key_vals.has(key)) {
+                        return parseInt(key_vals.get(key), 10);
+                    }
+                    if (value !== undefined) {
+                        return value;
+                    }
+                    throw new python.Error('Model file does not specify ' + key + '.');
+                };
+                const list = (key, size) => {
+                    if (key_vals.has(key)) {
+                        const value = key_vals.get(key).split(' ');
+                        if (value.length !== size) {
+                            throw new python.Error('Wrong size of ' + key + '.');
+                        }
+                        return value;
+                    }
+                    throw new python.Error('Model file does not contain ' + key + '.');
+                };
+                this.version = key_vals.get('version') || '';
+                this.num_class = atoi('num_class');
+                this.num_tree_per_iteration = atoi('num_tree_per_iteration', this.num_class);
+                this.label_index = atoi('label_index');
+                this.max_feature_idx = atoi('max_feature_idx');
+                if (key_vals.has('average_output')) {
+                    this.average_output = true;
+                }
+                this.feature_names = list('feature_names', this.max_feature_idx + 1);
+                this.feature_infos = list('feature_infos', this.max_feature_idx + 1);
+                if (key_vals.has('monotone_constraints')) {
+                    this.monotone_constraints = list('monotone_constraints', this.max_feature_idx + 1, true);
+                }
+                if (key_vals.has('objective')) {
+                    this.objective = key_vals.get('objective');
+                }
+                let tree = null;
+                // let lineNumber = 0;
+                while (lines.length > 0) {
+                    // lineNumber++;
+                    const text = lines.shift();
+                    const line = text.trim();
+                    if (line.length === 0) {
+                        continue;
+                    }
+                    if (line.startsWith('Tree=')) {
+                        tree = { index: parseInt(line.split('=').pop(), 10) };
+                        this.models.push(tree);
+                        continue;
+                    }
+                    if (line === 'end of trees') {
+                        break;
+                    }
+                    const param = line.split('=');
+                    if (param.length !== 2) {
+                        throw new python.Error("Invalid property '" + line + "'.");
+                    }
+                    const name = param[0].trim();
+                    const value = param[1].trim();
+                    tree[name] = value;
+                }
+                const ss = [];
+                let is_inparameter = false;
+                while (lines.length > 0) {
+                    const text = lines.shift();
+                    const line = text.trim();
+                    if (line === 'parameters:') {
+                        is_inparameter = true;
+                        continue;
+                    }
+                    else if (line === 'end of parameters') {
+                        break;
+                    }
+                    else if (is_inparameter) {
+                        ss.push(line);
+                    }
+                }
+                if (ss.length > 0) {
+                    this.loaded_parameter = ss.join('\n');
+                }
+            }
+        });
         this.registerType('nolearn.lasagne.base.BatchIterator', class {});
         this.registerType('nolearn.lasagne.base.Layers', class {});
         this.registerType('nolearn.lasagne.base.NeuralNet', class {});