Przeglądaj źródła

Update mxnet.js (#889)

Lutz Roeder 4 lat temu
rodzic
commit
a52301ca18
1 zmienionych plików z 24 dodań i 59 usunięć
  1. 24 59
      source/mxnet.js

+ 24 - 59
source/mxnet.js

@@ -3,6 +3,7 @@ var mxnet = mxnet || {};
 var json = json || require('./json');
 var zip = zip || require('./zip');
 var ndarray = ndarray || {};
+var base = base || require('./base');
 
 mxnet.ModelFactory = class {
 
@@ -964,7 +965,28 @@ mxnet.ndarray = class {
     static load(buffer) {
         // NDArray::Load(dmlc::Stream* fi, std::vector<NDArray>* data, std::vector<std::string>* keys)
         const map = new Map();
-        const reader = new mxnet.ndarray.BinaryReader(buffer);
+        const reader = new base.BinaryReader(buffer);
+        reader.uint64 = function() {
+            const value = this.uint32();
+            if (this.uint32() != 0) {
+                throw new mxnet.Error('Large uint64 value.');
+            }
+            return value;
+        };
+        reader.uint32s = function() {
+            const array = new Array(this.uint32());
+            for (let i = 0; i < array.length; i++) {
+                array[i] = this.uint32();
+            }
+            return array;
+        };
+        reader.uint64s = function() {
+            const array = new Array(this.uint32());
+            for (let i = 0; i < array.length; i++) {
+                array[i] = this.uint64();
+            }
+            return array;
+        };
         if (reader.uint64() !== 0x112) { // kMXAPINDArrayListMagic
             throw new mxnet.Error('Invalid signature.');
         }
@@ -990,13 +1012,12 @@ mxnet.ndarray = class {
     }
 };
 
-
 mxnet.ndarray.NDArray = class {
 
     constructor(reader) {
         mxnet.ndarray.NDArray._dataTypeSizeTable = [ 4, 8, 2, 1, 4, 1, 8 ];
         switch (reader.uint32()) {
-            case 0xF993faca: { // NDARRAY_V3_MAGIC
+            case 0xf993faca: { // NDARRAY_V3_MAGIC
                 throw new mxnet.Array('mxnet.ndarray.NDArray v3 not supported.');
             }
             case 0xf993fac9: { // NDARRAY_V2_MAGIC
@@ -1063,62 +1084,6 @@ mxnet.context.Context = class {
     }
 };
 
-mxnet.ndarray.BinaryReader = class {
-
-    constructor(buffer) {
-        this._buffer = buffer;
-        this._position = 0;
-        this._end = buffer.length;
-    }
-
-    skip(offset) {
-        this._position += offset;
-        if (this._position > this._end) {
-            throw new mxnet.Error('Data not available.');
-        }
-    }
-
-    read(size) {
-        const position = this._position;
-        this.skip(size);
-        return this._buffer.subarray(position, this._position);
-    }
-
-    uint16() {
-        const position = this._position;
-        this.skip(2);
-        return this._buffer[position] | (this._buffer[position + 1] << 8);
-    }
-
-    uint32() {
-        return this.uint16() | (this.uint16() << 16);
-    }
-
-    uint64() {
-        const value = this.uint32();
-        if (this.uint32() != 0) {
-            throw new mxnet.Error('Large int64 value.');
-        }
-        return value;
-    }
-
-    uint32s() {
-        const array = new Array(this.uint32());
-        for (let i = 0; i < array.length; i++) {
-            array[i] = this.uint32();
-        }
-        return array;
-    }
-
-    uint64s() {
-        const array = new Array(this.uint32());
-        for (let i = 0; i < array.length; i++) {
-            array[i] = this.uint64();
-        }
-        return array;
-    }
-};
-
 mxnet.Metadata = class {
 
     static open(context) {