Browse Source

Update NumPy support (#711)

Lutz Roeder 4 năm trước cách đây
mục cha
commit
a3285b991a
4 tập tin đã thay đổi với 106 bổ sung73 xóa
  1. 84 9
      source/python.js
  2. 1 1
      source/view-sidebar.js
  3. 6 63
      source/view.js
  4. 15 0
      test/models.js

+ 84 - 9
source/python.js

@@ -2572,19 +2572,14 @@ python.Execution = class {
             if (descr[0] !== '<' && descr[0] !== '>') {
                 throw new numpy.Error("Unknown byte order '" + descr + "'.");
             }
-            if (descr.length !== 3 || (descr[1] !== 'f' && descr[1] !== 'i' && descr[1] !== 'u')) {
+            if (descr.length !== 3 || (descr[1] !== 'f' && descr[1] !== 'i' && descr[1] !== 'u' && descr.substring(1) !== 'b1')) {
                 throw new numpy.Error("Unsupported data type '" + descr + "'.");
             }
             let shape = '';
             switch (arr.shape.length) {
-                case 0:
-                    throw new numpy.Error('Invalid shape.');
-                case 1:
-                    shape = '(' + arr.shape[0].toString() + ',)';
-                    break;
-                default:
-                    shape = '(' + arr.shape.map((dimension) => dimension.toString()).join(', ') + ')';
-                    break;
+                case 0: shape = '()'; break;
+                case 1: shape = '(' + arr.shape[0].toString() + ',)'; break;
+                default: shape = '(' + arr.shape.map((dimension) => dimension.toString()).join(', ') + ')'; break;
             }
             const properties = [
                 "'descr': '" + descr + "'",
@@ -2599,6 +2594,86 @@ python.Execution = class {
             file.write(encoder.encode(header));
             file.write(arr.tobytes());
         });
+        this.registerFunction('numpy.asarray', function(a, dtype) {
+            const encode = (context, data, dim) => {
+                const size = context.shape[dim];
+                const littleendian = context.littleendian;
+                if (dim == context.shape.length - 1) {
+                    for (let i = 0; i < size; i++) {
+                        switch (context.dtype) {
+                            case 'f2':
+                                context.view.setFloat16(context.position, data[i], littleendian);
+                                break;
+                            case 'f4':
+                                context.view.setFloat32(context.position, data[i], littleendian);
+                                break;
+                            case 'f8':
+                                context.view.setFloat64(context.position, data[i], littleendian);
+                                break;
+                            case 'i1':
+                                context.view.setInt8(context.position, data[i], littleendian);
+                                break;
+                            case 'i2':
+                                context.view.setInt16(context.position, data[i], littleendian);
+                                break;
+                            case 'i4':
+                                context.view.setInt32(context.position, data[i], littleendian);
+                                break;
+                            case 'i8':
+                                context.view.setInt64(context.position, data[i], littleendian);
+                                break;
+                            case 'u1':
+                                context.view.setUint8(context.position, data[i], littleendian);
+                                break;
+                            case 'u2':
+                                context.view.setUint16(context.position, data[i], littleendian);
+                                break;
+                            case 'u4':
+                                context.view.setUint32(context.position, data[i], littleendian);
+                                break;
+                            case 'u8':
+                                context.view.setUint64(context.position, data[i], littleendian);
+                                break;
+                        }
+                        context.position += context.itemsize;
+                    }
+                }
+                else {
+                    for (let j = 0; j < size; j++) {
+                        encode(context, data[j], dim + 1);
+                    }
+                }
+            };
+            const array_size = (value) => {
+                if (value.every((item) => Array.isArray(item))) {
+                    const dims = value.map((item) => array_size(item));
+                    const dim = dims[0];
+                    for (let i = 1; i < dims.length; i++) {
+                        if (dim.length === dims[i].length) {
+                            if (!dims[i].every((value, i) => value ===dim[i])) {
+                                throw new python.Error('Invalid array shape.');
+                            }
+                        }
+                    }
+                    return [ value.length ].concat(dim);
+                }
+                return [ value.length ];
+            };
+            const shape = Array.isArray(a) ? array_size(a) : [];
+            const size = dtype.itemsize * shape.reduce((a, b) => a * b, 1);
+            const context = {
+                position: 0,
+                itemsize: dtype.itemsize,
+                dtype: dtype.str.substring(1),
+                littleendian: dtype.str[0],
+                shape: shape,
+                data: new Uint8Array(size)
+            };
+            context.view = new DataView(context.data.buffer, context.data.byteOffset, size);
+            encode(context, a, 0);
+            return self.invoke('numpy.ndarray', [ shape, dtype, context.data ]);
+
+        });
         this.registerFunction('numpy.ma.core._mareconstruct', function(subtype, baseclass, baseshape, basetype) {
             const data = self.invoke(baseclass, [ baseshape, basetype ]);
             // = ndarray.__new__(ndarray, baseshape, make_mask_descr(basetype))

+ 1 - 1
source/view-sidebar.js

@@ -765,7 +765,7 @@ sidebar.ArgumentView = class {
                         const state = initializer.state;
                         if (state === null && this._host.save &&
                             initializer.type.dataType && initializer.type.dataType != '?' &&
-                            initializer.type.shape && initializer.type.shape.dimensions && initializer.type.shape.dimensions.length > 0) {
+                            initializer.type.shape && initializer.type.shape.dimensions /*&& initializer.type.shape.dimensions.length > 0*/) {
                             this._saveButton = this._host.document.createElement('div');
                             this._saveButton.className = 'sidebar-view-item-value-expander';
                             this._saveButton.innerHTML = '&#x1F4BE;';

+ 6 - 63
source/view.js

@@ -867,71 +867,14 @@ view.View = class {
                     const defaultPath = tensor.name ? tensor.name.split('/').join('_').split(':').join('_').split('.').join('_') : 'tensor';
                     this._host.save('NumPy Array', 'npy', defaultPath, (file) => {
                         try {
-                            const encode = (context, data, dim) => {
-                                const size = context.shape[dim];
-                                const littleendian = context.littleendian;
-                                if (dim == context.shape.length - 1) {
-                                    for (let i = 0; i < size; i++) {
-                                        switch (context.dtype) {
-                                            case 'f2':
-                                                context.view.setFloat16(context.position, data[i], littleendian);
-                                                break;
-                                            case 'f4':
-                                                context.view.setFloat32(context.position, data[i], littleendian);
-                                                break;
-                                            case 'f8':
-                                                context.view.setFloat64(context.position, data[i], littleendian);
-                                                break;
-                                            case 'i1':
-                                                context.view.setInt8(context.position, data[i], littleendian);
-                                                break;
-                                            case 'i2':
-                                                context.view.setInt16(context.position, data[i], littleendian);
-                                                break;
-                                            case 'i4':
-                                                context.view.setInt32(context.position, data[i], littleendian);
-                                                break;
-                                            case 'i8':
-                                                context.view.setInt64(context.position, data[i], littleendian);
-                                                break;
-                                            case 'u1':
-                                                context.view.setUint8(context.position, data[i], littleendian);
-                                                break;
-                                            case 'u2':
-                                                context.view.setUint16(context.position, data[i], littleendian);
-                                                break;
-                                            case 'u4':
-                                                context.view.setUint32(context.position, data[i], littleendian);
-                                                break;
-                                            case 'u8':
-                                                context.view.setUint64(context.position, data[i], littleendian);
-                                                break;
-                                        }
-                                        context.position += context.itemsize;
-                                    }
-                                }
-                                else {
-                                    for (let j = 0; j < size; j++) {
-                                        encode(context, data[j], dim + 1);
-                                    }
-                                }
-                            };
+                            let data_type = tensor.type.dataType;
+                            switch (data_type) {
+                                case 'boolean': data_type = 'bool'; break;
+                            }
                             const execution = new python.Execution(null);
                             const bytes = execution.invoke('io.BytesIO', []);
-                            const dtype = execution.invoke('numpy.dtype', [ tensor.type.dataType ]);
-                            const shape = tensor.type.shape.dimensions.map((dim) => dim instanceof base.Int64 || dim instanceof base.Uint64 ? dim.toNumber() : dim);
-                            const size = dtype.itemsize * shape.reduce((a, b) => a * b, 1);
-                            const context = {
-                                position: 0,
-                                itemsize: dtype.itemsize,
-                                dtype: dtype.str.substring(1),
-                                littleendian: dtype.str[0],
-                                shape: shape,
-                                data: new Uint8Array(size)
-                            };
-                            context.view = new DataView(context.data.buffer, context.data.byteOffset, size);
-                            encode(context, tensor.value, 0);
-                            const array = execution.invoke('numpy.ndarray', [ tensor.type.shape.dimensions, dtype, context.data ]);
+                            const dtype = execution.invoke('numpy.dtype', [ data_type ]);
+                            const array = execution.invoke('numpy.asarray', [ tensor.value, dtype ]);
                             execution.invoke('numpy.save', [ bytes, array ]);
                             bytes.seek(0);
                             const blob = new Blob([ bytes.read() ], { type: 'application/octet-stream' });

+ 15 - 0
test/models.js

@@ -666,6 +666,21 @@ const loadModel = (target, item) => {
                         if (argument.initializer) {
                             argument.initializer.toString();
                             argument.initializer.type.toString();
+                            /*
+                            const python = require('../source/python');
+                            const tensor = argument.initializer;
+                            if (tensor.type && tensor.type.dataType !== '?') {
+                                let data_type = tensor.type.dataType;
+                                switch (data_type) {
+                                    case 'boolean': data_type = 'bool'; break;
+                                }
+                                const execution = new python.Execution(null);
+                                const bytes = execution.invoke('io.BytesIO', []);
+                                const dtype = execution.invoke('numpy.dtype', [ data_type ]);
+                                const array = execution.invoke('numpy.asarray', [ tensor.value, dtype ]);
+                                execution.invoke('numpy.save', [ bytes, array ]);
+                            }
+                            */
                         }
                     }
                 }