ソースを参照

Update NumPy dtype support

Lutz Roeder 3 年 前
コミット
cf045bf586
7 ファイル変更116 行追加42 行削除
  1. 1 1
      source/flax.js
  2. 1 1
      source/keras.js
  3. 1 1
      source/lasagne.js
  4. 1 1
      source/numpy.js
  5. 1 1
      source/paddle.js
  6. 89 29
      source/python.js
  7. 22 8
      source/sklearn.js

+ 1 - 1
source/flax.js

@@ -207,7 +207,7 @@ flax.TensorShape = class {
 flax.Tensor = class {
 
     constructor(array) {
-        this._type = new flax.TensorType(array.dtype.name, new flax.TensorShape(array.shape));
+        this._type = new flax.TensorType(array.dtype.__name__, new flax.TensorShape(array.shape));
         this._data = array.tobytes();
         this._byteorder = array.dtype.byteorder;
         this._itemsize = array.dtype.itemsize;

+ 1 - 1
source/keras.js

@@ -339,7 +339,7 @@ keras.ModelFactory = class {
                                     const buffer = layer_weights[weight_name];
                                     const unpickler = python.Unpickler.open(buffer);
                                     const variable = unpickler.load((name, args) => execution.invoke(name, args));
-                                    const tensor = new keras.Tensor(weight_name, variable.shape, variable.dtype.name, null, true, variable.data);
+                                    const tensor = new keras.Tensor(weight_name, variable.shape, variable.dtype.__name__, null, true, variable.data);
                                     weights.add(layer_name, tensor);
                                 }
                             }

+ 1 - 1
source/lasagne.js

@@ -320,7 +320,7 @@ lasagne.TensorShape = class {
 lasagne.Tensor = class {
 
     constructor(storage) {
-        this._type = new lasagne.TensorType(storage.dtype.name, new lasagne.TensorShape(storage.shape));
+        this._type = new lasagne.TensorType(storage.dtype.__name__, new lasagne.TensorShape(storage.shape));
     }
 
     get type() {

+ 1 - 1
source/numpy.js

@@ -273,7 +273,7 @@ numpy.Node = class {
 numpy.Tensor = class  {
 
     constructor(array) {
-        this._type = new numpy.TensorType(array.dtype.name, new numpy.TensorShape(array.shape));
+        this._type = new numpy.TensorType(array.dtype.__name__, new numpy.TensorShape(array.shape));
         this._data = array.tobytes();
         this._byteorder = array.dtype.byteorder;
         this._itemsize = array.dtype.itemsize;

+ 1 - 1
source/paddle.js

@@ -808,7 +808,7 @@ paddle.Utility = class {
             const value = entry[1];
             if (value && !Array.isArray(value) && value.__class__ && value.__class__.__module__ === 'numpy' && value.__class__.__name__ === 'ndarray') {
                 const name = map ? map[key] : key;
-                const type = new paddle.TensorType(value.dtype.name, new paddle.TensorShape(value.shape));
+                const type = new paddle.TensorType(value.dtype.__name__, new paddle.TensorShape(value.shape));
                 const data = value.data;
                 const tensor = new paddle.Tensor(type, data, 'NumPy Array');
                 weights.set(name, tensor);

+ 89 - 29
source/python.js

@@ -1729,45 +1729,37 @@ python.Execution = class {
         this.registerType('numpy.dtype', class {
             constructor(obj, align, copy) {
                 switch (obj) {
-                    case 'b1': case 'bool': this.name = 'bool'; this.itemsize = 1; this.kind = 'b'; break;
-                    case 'i1': case 'int8': this.name = 'int8'; this.itemsize = 1; this.kind = 'i'; break;
-                    case 'i2': case 'int16': this.name = 'int16'; this.itemsize = 2; this.kind = 'i'; break;
-                    case 'i4': case 'int32': this.name = 'int32'; this.itemsize = 4; this.kind = 'i'; break;
-                    case 'i8': case 'int64': case 'int': this.name = 'int64'; this.itemsize = 8; this.kind = 'i'; break;
-                    case 'u1': case 'uint8': this.name = 'uint8'; this.itemsize = 1; this.kind = 'u'; break;
-                    case 'u2': case 'uint16': this.name = 'uint16'; this.itemsize = 2; this.kind = 'u'; break;
-                    case 'u4': case 'uint32': this.name = 'uint32'; this.itemsize = 4; this.kind = 'u'; break;
-                    case 'u8': case 'uint64': case 'uint': this.name = 'uint64'; this.itemsize = 8; this.kind = 'u'; break;
-                    case 'f2': case 'float16': this.name = 'float16'; this.itemsize = 2; this.kind = 'f'; break;
-                    case 'f4': case 'float32': this.name = 'float32'; this.itemsize = 4; this.kind = 'f'; break;
-                    case 'f8': case 'float64': case 'float': this.name = 'float64'; this.itemsize = 8; this.kind = 'f'; break;
-                    case 'c8': case 'complex64': this.name = 'complex64'; this.itemsize = 8; this.kind = 'c'; break;
-                    case 'c16': case 'complex128': case 'complex': this.name = 'complex128'; this.itemsize = 16; this.kind = 'c'; break;
+                    case 'b1': case 'bool': this.itemsize = 1; this.kind = 'b'; break;
+                    case 'i1': case 'int8': this.itemsize = 1; this.kind = 'i'; break;
+                    case 'i2': case 'int16': this.itemsize = 2; this.kind = 'i'; break;
+                    case 'i4': case 'int32': this.itemsize = 4; this.kind = 'i'; break;
+                    case 'i8': case 'int64': case 'int': this.itemsize = 8; this.kind = 'i'; break;
+                    case 'u1': case 'uint8': this.itemsize = 1; this.kind = 'u'; break;
+                    case 'u2': case 'uint16': this.itemsize = 2; this.kind = 'u'; break;
+                    case 'u4': case 'uint32': this.itemsize = 4; this.kind = 'u'; break;
+                    case 'u8': case 'uint64': case 'uint': this.itemsize = 8; this.kind = 'u'; break;
+                    case 'f2': case 'float16': this.itemsize = 2; this.kind = 'f'; break;
+                    case 'f4': case 'float32': this.itemsize = 4; this.kind = 'f'; break;
+                    case 'f8': case 'float64': case 'float': this.itemsize = 8; this.kind = 'f'; break;
+                    case 'c8': case 'complex64': this.itemsize = 8; this.kind = 'c'; break;
+                    case 'c16': case 'complex128': case 'complex': this.itemsize = 16; this.kind = 'c'; break;
+                    case 'M': this.itemsize = 8; this.kind = 'M'; break;
                     default:
                         if (obj.startsWith('V')) {
                             this.itemsize = parseInt(obj.substring(1), 10);
                             this.kind = 'V';
-                            this.name = 'void' + (this.itemsize * 8).toString();
                         }
                         else if (obj.startsWith('O')) {
                             this.itemsize = parseInt(obj.substring(1), 10);
                             this.kind = 'O';
-                            this.name = 'object';
                         }
                         else if (obj.startsWith('S')) {
                             this.itemsize = parseInt(obj.substring(1), 10);
                             this.kind = 'S';
-                            this.name = 'string';
                         }
                         else if (obj.startsWith('U')) { // Unicode string
-                            this.itemsize = 4 * parseInt(obj.substring(1), 10);
                             this.kind = 'U';
-                            this.name = 'string'; // 'str' + (8 * this.itemsize)
-                        }
-                        else if (obj.startsWith('M')) { // datetime
-                            this.itemsize = parseInt(obj.substring(1), 10);
-                            this.kind = 'M';
-                            this.name = 'datetime';
+                            this.itemsize = 4 * parseInt(obj.substring(1), 10);
                         }
                         else {
                             throw new python.Error("Unsupported dtype '" + obj.toString() + "'.");
@@ -1785,6 +1777,15 @@ python.Execution = class {
             get str() {
                 return (this.byteorder === '=' ? '<' : this.byteorder) + this.kind + this.itemsize.toString();
             }
+            get name() { // NumPy-style dtype name; void/bytes/str kinds get a bit-size suffix (itemsize * 8), omitted when itemsize is 0.
+                switch (this.kind) {
+                    case 'V': return 'void' + (this.itemsize === 0 ? '' : (this.itemsize * 8).toString());
+                    case 'S': return 'bytes' + (this.itemsize === 0 ? '' : (this.itemsize * 8).toString());
+                    case 'U': return 'str' + (this.itemsize === 0 ? '' : (this.itemsize * 8).toString());
+                    case 'M': return 'datetime64';
+                    default: return this.__name__; // delegate to __name__; reading this.name here would re-enter this getter and recurse without end.
+                }
+            }
             __setstate__(state) {
                 switch (state.length) {
                     case 8:
@@ -1812,6 +1813,55 @@ python.Execution = class {
                         throw new python.Error("Unsupported numpy.dtype setstate length '" + state.length.toString() + "'.");
                 }
             }
+            get __name__() { // Type name derived from this.kind and this.itemsize; throws python.Error for unsupported combinations.
+                switch (this.kind) {
+                    case 'b':
+                        switch (this.itemsize) {
+                            case 1: return 'boolean';
+                            default: throw new python.Error("Unsupported boolean itemsize '" + this.itemsize + "'.");
+                        }
+                    case 'i':
+                        switch (this.itemsize) {
+                            case 1: return 'int8';
+                            case 2: return 'int16';
+                            case 4: return 'int32';
+                            case 8: return 'int64';
+                            default: throw new python.Error("Unsupported int itemsize '" + this.itemsize + "'.");
+                        }
+                    case 'u':
+                        switch (this.itemsize) {
+                            case 1: return 'uint8';
+                            case 2: return 'uint16';
+                            case 4: return 'uint32';
+                            case 8: return 'uint64';
+                            default: throw new python.Error("Unsupported uint itemsize '" + this.itemsize + "'.");
+                        }
+                    case 'f':
+                        switch (this.itemsize) {
+                            case 2: return 'float16';
+                            case 4: return 'float32';
+                            case 8: return 'float64';
+                            default: throw new python.Error("Unsupported float itemsize '" + this.itemsize + "'.");
+                        }
+                    case 'c':
+                        switch (this.itemsize) {
+                            case 8: return 'complex64';
+                            case 16: return 'complex128';
+                            default: throw new python.Error("Unsupported complex itemsize '" + this.itemsize + "'.");
+                        }
+                    case 'S':
+                    case 'U':
+                        return 'string'; // both the 'S' and 'U' kinds report the same 'string' name, regardless of itemsize.
+                    case 'M':
+                        return 'datetime';
+                    case 'O':
+                        return 'object';
+                    case 'V':
+                        return 'void';
+                    default:
+                        throw new python.Error("Unsupported dtype kind '" + this.kind + "'.");
+                }
+            }
         });
         this.registerType('gensim.models.doc2vec.Doctag', class {});
         this.registerType('gensim.models.doc2vec.Doc2Vec', class {});
@@ -1843,7 +1893,7 @@ python.Execution = class {
                 this.allow_mmap = state.allow_mmap;
             }
             __read__(unpickler) {
-                if (this.dtype.name == 'object') {
+                if (this.dtype.__name__ == 'object') {
                     return unpickler.load((name, args) => self.invoke(name, args), null);
                 }
                 else {
@@ -2044,12 +2094,14 @@ python.Execution = class {
         this.registerType('sklearn.calibration.CalibratedClassifierCV', class {});
         this.registerType('sklearn.compose._column_transformer.ColumnTransformer', class {});
         this.registerType('sklearn.compose._target.TransformedTargetRegressor', class {});
+        this.registerType('sklearn.cluster._agglomerative.FeatureAgglomeration', class {});
         this.registerType('sklearn.cluster._dbscan.DBSCAN', class {});
         this.registerType('sklearn.cluster._kmeans.KMeans', class {});
+        this.registerType('sklearn.decomposition._fastica.FastICA', class {});
         this.registerType('sklearn.decomposition._pca.PCA', class {});
+        this.registerType('sklearn.decomposition._truncated_svd.TruncatedSVD', class {});
         this.registerType('sklearn.decomposition.PCA', class {});
         this.registerType('sklearn.decomposition.pca.PCA', class {});
-        this.registerType('sklearn.decomposition._truncated_svd.TruncatedSVD', class {});
         this.registerType('sklearn.decomposition.truncated_svd.TruncatedSVD', class {});
         this.registerType('sklearn.discriminant_analysis.LinearDiscriminantAnalysis', class {});
         this.registerType('sklearn.discriminant_analysis.QuadraticDiscriminantAnalysis', class {});
@@ -2066,7 +2118,7 @@ python.Execution = class {
                 this.allow_mmap = state.allow_mmap;
             }
             __read__(unpickler) {
-                if (this.dtype.name == 'object') {
+                if (this.dtype.__name__ == 'object') {
                     return unpickler.load((name, args) => self.invoke(name, args), null);
                 }
                 else {
@@ -2097,9 +2149,14 @@ python.Execution = class {
         this.registerType('sklearn.ensemble._gb_losses.MultinomialDeviance', class {});
         this.registerType('sklearn.ensemble._gb.GradientBoostingClassifier', class {});
         this.registerType('sklearn.ensemble._gb.GradientBoostingRegressor', class {});
+        this.registerType('sklearn.ensemble._hist_gradient_boosting.binning._BinMapper', class {});
+        this.registerType('sklearn.ensemble._hist_gradient_boosting.gradient_boosting.HistGradientBoostingRegressor', class {});
+        this.registerType('sklearn.ensemble._hist_gradient_boosting.loss.LeastSquares', class {});
+        this.registerType('sklearn.ensemble._hist_gradient_boosting.predictor.TreePredictor', class {});
         this.registerType('sklearn.ensemble._iforest.IsolationForest', class {});
         this.registerType('sklearn.ensemble._stacking.StackingClassifier', class {});
         this.registerType('sklearn.ensemble._voting.VotingClassifier', class {});
+        this.registerType('sklearn.ensemble._weight_boosting.AdaBoostClassifier', class {});
         this.registerType('sklearn.ensemble.forest.RandomForestClassifier', class {});
         this.registerType('sklearn.ensemble.forest.RandomForestRegressor', class {});
         this.registerType('sklearn.ensemble.forest.ExtraTreesClassifier', class {});
@@ -2115,6 +2172,7 @@ python.Execution = class {
         this.registerType('sklearn.feature_extraction.text.TfidfTransformer', class {});
         this.registerType('sklearn.feature_extraction.text.TfidfVectorizer', class {});
         this.registerType('sklearn.feature_selection._from_model.SelectFromModel', class {});
+        this.registerType('sklearn.feature_selection._univariate_selection.GenericUnivariateSelect', class {});
         this.registerType('sklearn.feature_selection._univariate_selection.SelectKBest', class {});
         this.registerType('sklearn.feature_selection._univariate_selection.SelectPercentile', class {});
         this.registerType('sklearn.feature_selection._variance_threshold.VarianceThreshold', class {});
@@ -2135,6 +2193,7 @@ python.Execution = class {
         this.registerType('sklearn.linear_model._coordinate_descent.ElasticNet', class {});
         this.registerType('sklearn.linear_model._logistic.LogisticRegression', class {});
         this.registerType('sklearn.linear_model._ridge.Ridge', class {});
+        this.registerType('sklearn.linear_model._ridge.RidgeClassifier', class {});
         this.registerType('sklearn.linear_model._sgd_fast.Hinge', class {});
         this.registerType('sklearn.linear_model._sgd_fast.Log', class {});
         this.registerType('sklearn.linear_model._sgd_fast.ModifiedHuber', class {});
@@ -2189,6 +2248,7 @@ python.Execution = class {
         this.registerType('sklearn.preprocessing._data.MaxAbsScaler', class {});
         this.registerType('sklearn.preprocessing._data.Normalizer', class {});
         this.registerType('sklearn.preprocessing._data.PolynomialFeatures', class {});
+        this.registerType('sklearn.preprocessing._data.PowerTransformer', class {});
         this.registerType('sklearn.preprocessing._data.QuantileTransformer', class {});
         this.registerType('sklearn.preprocessing._data.RobustScaler', class {});
         this.registerType('sklearn.preprocessing._data.StandardScaler', class {});
@@ -2545,7 +2605,7 @@ python.Execution = class {
                 }
             }
             const dataView = new DataView(data.buffer, data.byteOffset, data.byteLength);
-            switch (dtype.name) {
+            switch (dtype.__name__) {
                 case 'uint8':
                     return dataView.getUint8(0);
                 case 'float32':
@@ -2561,7 +2621,7 @@ python.Execution = class {
                 case 'int64':
                     return dataView.getInt64(0, true);
                 default:
-                    throw new python.Error("Unsupported scalar type '" + dtype.name + "'.");
+                    throw new python.Error("Unsupported scalar type '" + dtype.__name__ + "'.");
             }
         });
         this.registerFunction('numpy.load', function(file) {

+ 22 - 8
source/sklearn.js

@@ -336,9 +336,9 @@ sklearn.Tensor = class {
             const type = value.__class__.__module__ + '.' + value.__class__.__name__;
             throw new sklearn.Error("Unsupported tensor type '" + type + "'.");
         }
-        this._type = new sklearn.TensorType(value.dtype.name, new sklearn.TensorShape(value.shape));
+        this._type = new sklearn.TensorType(value.dtype.__name__, new sklearn.TensorShape(value.shape));
         this._data = value.data;
-        if (value.dtype.name === 'string') {
+        if (this._type.dataType === 'string') {
             this._itemsize = value.dtype.itemsize;
         }
     }
@@ -402,8 +402,10 @@ sklearn.Tensor = class {
         switch (context.dataType) {
             case 'float32':
             case 'float64':
-            case 'int32':
             case 'uint32':
+            case 'int8':
+            case 'int16':
+            case 'int32':
             case 'int64':
             case 'uint64':
                 context.view = new DataView(this._data.buffer, this._data.byteOffset, this._data.byteLength);
@@ -446,14 +448,20 @@ sklearn.Tensor = class {
                         context.count++;
                         break;
                     }
-                    case 'int32': {
-                        results.push(context.view.getInt32(context.index, true));
-                        context.index += 4;
+                    case 'int8': {
+                        results.push(context.view.getInt8(context.index, true));
+                        context.index += 1;
                         context.count++;
                         break;
                     }
-                    case 'uint32': {
-                        results.push(context.view.getUint32(context.index, true));
+                    case 'int16': {
+                        results.push(context.view.getInt16(context.index, true));
+                        context.index += 2;
+                        context.count++;
+                        break;
+                    }
+                    case 'int32': {
+                        results.push(context.view.getInt32(context.index, true));
                         context.index += 4;
                         context.count++;
                         break;
@@ -464,6 +472,12 @@ sklearn.Tensor = class {
                         context.count++;
                         break;
                     }
+                    case 'uint32': {
+                        results.push(context.view.getUint32(context.index, true));
+                        context.index += 4;
+                        context.count++;
+                        break;
+                    }
                     case 'uint64': {
                         results.push(context.view.getUint64(context.index, true));
                         context.index += 8;