Browse Source

Add LightGBM Pickle test file (#669)

Lutz Roeder 5 years ago
parent
commit
a880135b56
3 changed files with 35 additions and 11 deletions
  1. 27 10
      source/lightgbm.js
  2. 1 1
      source/view.js
  3. 7 0
      test/models.json

+ 27 - 10
source/lightgbm.js

@@ -17,14 +17,33 @@ lightgbm.ModelFactory = class {
         catch (err) {
             // continue regardless of error
         }
+        const tags = context.tags('pkl');
+        if (tags.size === 1 && tags.keys().next().value.startsWith('lightgbm.')) {
+            return true;
+        }
         return false;
     }
 
     open(context) {
         return new Promise((resolve, reject) => {
             try {
-                const booster = new lightgbm.basic.Booster(context.stream);
-                resolve(new lightgbm.Model(booster));
+                const tags = context.tags('pkl');
+                let model;
+                let format;
+                if (tags.size === 1) {
+                    format = 'LightGBM Pickle';
+                    model = tags.values().next().value;
+                    if (model && model.handle && typeof model.handle === 'string') {
+                        const reader = base.TextReader.create(model.handle);
+                        model = new lightgbm.basic.Booster(reader);
+                    }
+                }
+                else {
+                    format = 'LightGBM';
+                    const reader = base.TextReader.create(context.stream.peek());
+                    model = new lightgbm.basic.Booster(reader);
+                }
+                resolve(new lightgbm.Model(model, format));
             }
             catch (err) {
                 reject(err);
@@ -35,13 +54,13 @@ lightgbm.ModelFactory = class {
 
 lightgbm.Model = class {
 
-    constructor(model) {
-        this._version = model.meta.version;
+    constructor(model, format) {
+        this._format = format + (model.meta && model.meta.version ? ' ' + model.meta.version : '');
         this._graphs = [ new lightgbm.Graph(model) ];
     }
 
     get format() {
-        return 'LightGBM' + (this._version ? ' ' + this._version : '');
+        return this._format;
     }
 
     get graphs() {
@@ -57,7 +76,7 @@ lightgbm.Graph = class {
         this._nodes = [];
 
         const args = [];
-        if (model.meta.feature_names) {
+        if (model.meta && model.meta.feature_names) {
             const feature_names = model.meta.feature_names.split(' ').map((item) => item.trim());
             for (const feature_name of feature_names) {
                 const arg = new lightgbm.Argument(feature_name);
@@ -180,7 +199,7 @@ lightgbm.basic = {};
 
 lightgbm.basic.Booster = class {
 
-    constructor(stream) {
+    constructor(reader) {
 
         this.__module__ = 'lightgbm.basic';
         this.__name__ = 'Booster';
@@ -191,8 +210,6 @@ lightgbm.basic.Booster = class {
         this.trees = [];
 
         // GBDT::LoadModelFromString() in https://github.com/microsoft/LightGBM/blob/master/src/boosting/gbdt_model_text.cpp
-        const reader = base.TextReader.create(stream.peek());
-
         const signature = reader.read();
         if (!signature || signature.trim() !== 'tree') {
             throw new lightgbm.Error("Invalid signature '" + signature.trim() + "'.");
@@ -220,7 +237,7 @@ lightgbm.basic.Booster = class {
                 state = 'param';
                 continue;
             }
-            else if (line === 'feature_importances:') {
+            else if (line === 'feature_importances:' || line === 'feature importances:') {
                 state = 'feature_importances';
                 continue;
             }

+ 1 - 1
source/view.js

@@ -1328,6 +1328,7 @@ view.ModelFactoryService = class {
         this.register('./uff', [ '.uff', '.pb', '.pbtxt', '.uff.txt', '.trt', '.engine' ]);
         this.register('./npz', [ '.npz', '.pkl' ]);
         this.register('./lasagne', [ '.pkl', '.pickle', '.joblib', '.model', '.pkl.z', '.joblib.z' ]);
+        this.register('./lightgbm', [ '.txt', '.pkl' ]);
         this.register('./sklearn', [ '.pkl', '.pickle', '.joblib', '.model', '.meta', '.pb', '.pt', '.h5', '.pkl.z', '.joblib.z' ]);
         this.register('./pickle', [ '.pkl', '.pickle', '.joblib', '.model', '.meta', '.pb', '.pt', '.h5', '.pkl.z', '.joblib.z' ]);
         this.register('./cntk', [ '.model', '.cntk', '.cmf', '.dnn' ]);
@@ -1351,7 +1352,6 @@ view.ModelFactoryService = class {
         this.register('./dl4j', [ '.zip' ]);
         this.register('./mlnet', [ '.zip' ]);
         this.register('./acuity', [ '.json' ]);
-        this.register('./lightgbm', [ '.txt' ]);
     }
 
     register(id, extensions) {

+ 7 - 0
test/models.json

@@ -2330,6 +2330,13 @@
     "format": "Lasagne",
     "link":   "https://github.com/Aabglov/LasaganeTest"
   },
+  {
+    "type":   "lightgbm",
+    "target": "simple_example.pkl",
+    "source": "https://github.com/lutzroeder/netron/files/5978325/simple_example.pkl.zip[simple_example.pkl]",
+    "format": "LightGBM Pickle v2",
+    "link":   "https://github.com/lutzroeder/netron/issues/669"
+  },
   {
     "type":   "lightgbm",
     "target": "simple_example.txt",