Przeglądaj źródła

Add CatBoost test files (#1263)

Lutz Roeder 1 rok temu
rodzic
commit
0ff2c38e8d
4 zmienionych plików z 79 dodań i 2 usunięć
  1. 36 0
      source/catboost.js
  2. 9 0
      source/python.js
  3. 2 2
      source/view.js
  4. 32 0
      test/models.json

+ 36 - 0
source/catboost.js

@@ -0,0 +1,36 @@
+
+import * as python from './python.js';
+
+const catboost = {};
+
+catboost.ModelFactory = class {
+
+    match(context) {
+        const stream = context.stream;
+        if (stream && stream.length > 4) {
+            const buffer = stream.peek(4);
+            const signature = Array.from(buffer).map((c) => String.fromCharCode(c)).join('');
+            if (signature === 'CBM1') {
+                context.type = 'catboost';
+            }
+        }
+    }
+
+    async open(context) {
+        const stream = context.stream;
+        const execution = new python.Execution();
+        const model = execution.invoke('catboost.CatBoostClassifier', []);
+        model.load_model(stream);
+    }
+};
+
+catboost.Error = class extends Error {
+
+    constructor(message) {
+        super(message);
+        this.name = 'Error loading CatBoost model.';
+    }
+};
+
+export const ModelFactory = catboost.ModelFactory;
+

+ 9 - 0
source/python.js

@@ -1656,6 +1656,7 @@ python.Execution = class {
         const functools = this.register('functools');
         this.registerType('functools.partial', class {});
         const keras = this.register('keras');
+        const catboost = this.register('catboost');
         this.register('lightgbm');
         this.register('nolearn');
         const fastcore = this.register('fastcore');
@@ -1792,6 +1793,14 @@ python.Execution = class {
                 this.args = args;
             }
         });
+        this.registerType('catboost.core._CatBoostBase', class {});
+        this.registerType('catboost.core.CatBoost', class extends catboost.core._CatBoostBase {});
+        this.registerType('catboost.core.CatBoostClassifier', class extends catboost.core.CatBoost {
+            load_model() {
+                throw new python.Error("'catboost.core.CatBoostClassifier.load_model' not implemented.");
+            }
+        });
+        catboost.CatBoostClassifier = catboost.core.CatBoostClassifier;
         this.registerType('collections.deque', class extends Array {
             constructor(iterable) {
                 super();

+ 2 - 2
source/view.js

@@ -5453,6 +5453,7 @@ view.ModelFactoryService = class {
         this.register('./nnc', ['.nnc']);
         this.register('./safetensors', ['.safetensors', '.safetensors.index.json']);
         this.register('./modular', ['.maxviz']);
+        this.register('./catboost', ['.cbm']);
         this.register('./cambricon', ['.cambricon']);
         this.register('./weka', ['.model']);
     }
@@ -5897,8 +5898,7 @@ view.ModelFactoryService = class {
                 { name: 'V8 natives blob', value: /^./, identifier: 'natives_blob.bin' },
                 { name: 'ViSQOL model', value: /^svm_type\s/ },
                 { name: 'SenseTime model', value: /^STEF/ },
-                { name: 'AES Crypt data', value: /^AES[\x01|\x02]\x00/ },
-                { name: 'CatBoost model', value: /^CBM1/ }
+                { name: 'AES Crypt data', value: /^AES[\x01|\x02]\x00/ }
             ];
             /* eslint-enable no-control-regex */
             const buffer = stream.peek(Math.min(4096, stream.length));

+ 32 - 0
test/models.json

@@ -913,6 +913,38 @@
     "error":    "File contains undocumented Cambricon data.",
     "link":     "https://github.com/lutzroeder/netron/issues/917"
   },
+  {
+    "type":     "catboost",
+    "target":   "iris_model.cbm",
+    "source":   "https://github.com/lutzroeder/netron/files/15046507/iris_model.cbm.zip[iris_model.cbm]",
+    "format":   "CatBoost",
+    "error":    "'catboost.core.CatBoostClassifier.load_model' not implemented.",
+    "link":     "https://github.com/lutzroeder/netron/issues/1263"
+  },
+  {
+    "type":     "catboost",
+    "target":   "model.cbm",
+    "source":   "https://github.com/lutzroeder/netron/files/15046506/model.cbm.zip[model.cbm]",
+    "format":   "CatBoost",
+    "error":    "'catboost.core.CatBoostClassifier.load_model' not implemented.",
+    "link":     "https://github.com/lutzroeder/netron/issues/1263"
+  },
+  {
+    "type":     "catboost",
+    "target":   "model.pkl",
+    "source":   "https://github.com/lutzroeder/netron/files/15046504/model.pkl.zip[model.pkl]",
+    "format":   "Pickle",
+    "error":    "Unsupported Pickle type 'autogluon.tabular.models.catboost.catboost_model.CatBoostModel'.",
+    "link":     "https://github.com/lutzroeder/netron/issues/1263"
+  },
+  {
+    "type":     "catboost",
+    "target":   "numeric_only_model.cbm",
+    "source":   "https://github.com/lutzroeder/netron/files/15046509/numeric_only_model.cbm.zip[numeric_only_model.cbm]",
+    "format":   "CatBoost",
+    "error":    "'catboost.core.CatBoostClassifier.load_model' not implemented.",
+    "link":     "https://github.com/lutzroeder/netron/issues/1263"
+  },
   {
     "type":     "circle",
     "target":   "ArgMax_001.circle",