Просмотр исходного кода

scikit-learn weight list support

Lutz Roeder 6 лет назад
Родитель
Сommit
cef9b1a83c
3 измененных файлов с 57 добавлено и 24 удалено
  1. 7 9
      src/pickle.js
  2. 44 15
      src/sklearn.js
  3. 6 0
      test/models.json

+ 7 - 9
src/pickle.js

@@ -15,8 +15,7 @@ pickle.Unpickler = class {
         let stack = [];
         let memo = new Map();
         while (reader.position < reader.length) {
-            let opcode = reader.byte();
-            // console.log(reader.position.toString() + ': ' + opcode.toString());
+            const opcode = reader.byte();
             switch (opcode) {
                 case pickle.OpCode.PROTO: {
                     const version = reader.byte();
@@ -37,7 +36,7 @@ pickle.Unpickler = class {
                     break;
                 }
                 case pickle.OpCode.OBJ: {
-                    let items = stack;
+                    const items = stack;
                     stack = marker.pop();
                     stack.push(function_call(items.pop(), items));
                     break;
@@ -533,20 +532,19 @@ pickle.Reader = class {
     }
 
     string(size, encoding) {
-        let data = this.bytes(size);
-        let text = (encoding == 'utf-8') ?
+        const data = this.bytes(size);
+        return (encoding == 'utf-8') ?
             pickle.Reader._utf8Decoder.decode(data) :
             pickle.Reader._asciiDecoder.decode(data);
-        return text;
     }
 
     line() {
-        let index = this._buffer.indexOf(0x0A, this._position);
+        const index = this._buffer.indexOf(0x0A, this._position);
         if (index == -1) {
             throw new pickle.Error("Could not find end of line.");
         }
-        let size = index - this._position;
-        let text = this.string(size, 'ascii');
+        const size = index - this._position;
+        const text = this.string(size, 'ascii');
         this.seek(1);
         return text;
     }

+ 44 - 15
src/sklearn.js

@@ -336,8 +336,15 @@ sklearn.ModelFactory = class {
                     }
                     return obj;
                 };
-                functionTable['__builtin__.bytearray'] = function(data, encoding) {
-                    return { data: data, encoding: encoding };
+                functionTable['__builtin__.bytearray'] = function(source, encoding /*, errors */) {
+                    if (encoding === 'latin-1') {
+                        let array = new Uint8Array(source.length);
+                        for (let i = 0; i < source.length; i++) {
+                            array[i] = source.charCodeAt(i);
+                        }
+                        return array;
+                    }
+                    throw new sklearn.Error("Unsupported bytearray encoding '" + JSON.stringify(encoding) + "'.");
                 };
                 functionTable['builtins.bytearray'] = function(data) {
                     return { data: data };
@@ -375,40 +382,62 @@ sklearn.ModelFactory = class {
                 };
 
                 obj = unpickler.load(function_call, null);
-                if (obj && Array.isArray(obj)) {
-                    throw new sklearn.Error('Array is not a valid root object.');
-                }
 
-                let find_weight_dict = function(dicts) {
+                const find_weights = function(objs) {
 
-                    for (const dict of dicts) {
+                    for (const dict of objs) {
                         if (dict && !Array.isArray(dict)) {
-                            let list = [];
+                            let weights = [];
                             for (const key in dict) {
                                 const value = dict[key]
                                 if (key != 'weight_order' && key != 'lr') {
                                     if (!key ||
                                         !value.__type__ || !value.__type__ == 'numpy.ndarray') {
-                                        list = null;
+                                        weights = null;
                                         break;
                                     }
-                                    list.push({ key: key, value: value });
+                                    weights.push({ key: key, value: value });
+                                }
+                            }
+                            if (weights) {
+                                return weights;
+                            }
+                        }
+                    }
+
+                    for (const list of objs) {
+                        if (list && Array.isArray(list)) {
+                            let weights = [];
+                            for (let i = 0; i < list.length; i++) {
+                                const value = list[i];
+                                if (!value.__type__ || !value.__type__ == 'numpy.ndarray') {
+                                    weights = null;
+                                    break;
                                 }
+                                weights.push({ key: i.toString(), value: value });
                             }
-                            if (list) {
-                                return list;
+                            if (weights) {
+                                return weights;
                             }
                         }
                     }
                     return null;
                 }
 
-                weights = find_weight_dict([ obj, obj.blobs ]);
+                weights = find_weights([ obj, obj.blobs ]);
                 if (weights) {
                     obj = null;
                 }
-                if (!weights && (!obj || !obj.__type__)) {
-                    throw new sklearn.Error('Root object has no type.');
+                if (!weights) {
+                    if (!obj) {
+                        throw new sklearn.Error('No root object.');
+                    }
+                    if (Array.isArray(obj)) {
+                        throw new sklearn.Error('Root is nullArray is not a valid root object.');
+                    }                    
+                    if (!obj.__type__) {
+                        throw new sklearn.Error('Root object has no type.');
+                    }
                 }
             }
             catch (error) {

+ 6 - 0
test/models.json

@@ -4526,6 +4526,12 @@
     "format": "scikit-learn 0.20.2",
     "link":   "https://github.com/lutzroeder/netron/issues/182"
   },
+  {
+    "type":   "sklearn",
+    "target": "curvrank.localization.dorsal.weights.pkl",
+    "source": "https://lev.cs.rpi.edu/public/models/curvrank.localization.dorsal.weights.pkl",
+    "format": "scikit-learn"
+  },
   {
     "type":   "sklearn",
     "target": "LDA_model.pkl",