Lutz Roeder 4 rokov pred
rodič
commit
bfd367c0ef
1 zmenil súbory, kde vykonal 155 pridanie a 197 odobranie
  1. 155 197
      source/mxnet.js

+ 155 - 197
source/mxnet.js

@@ -163,10 +163,11 @@ mxnet.ModelFactory = class {
                 const parameters = new Map();
                 if (params) {
                     try {
-                        const stream = new ndarray.Stream(params);
-                        for (const key of Object.keys(stream.arrays)) {
+                        for (const entry of mxnet.ndarray.load(params)) {
+                            const key = entry[0];
+                            const array = entry[1];
                             const name = (key.startsWith('arg:') || key.startsWith('aux:')) ? key.substring(4) : key;
-                            parameters.set(name, stream.arrays[key]);
+                            parameters.set(name, array);
                         }
                     }
                     catch (error) {
@@ -315,7 +316,7 @@ mxnet.Graph = class {
             for (const pair of params) {
                 const key = pair[0];
                 const value = pair[1];
-                tensors.set(key, new mxnet.Tensor('Initializer', key, new mxnet.TensorType(value.dataType, new mxnet.TensorShape(value.shape.dimensions)), value.data));
+                tensors.set(key, new mxnet.Tensor('Initializer', key, new mxnet.TensorType(value.dtype, new mxnet.TensorShape(value.shape)), value.data));
             }
         }
 
@@ -958,196 +959,103 @@ mxnet.TensorShape = class {
     }
 };
 
-mxnet.Metadata = class {
+mxnet.ndarray = class {
 
-    static open(context) {
-        if (mxnet.Metadata._metadata) {
-            return Promise.resolve(mxnet.Metadata._metadata);
+    static load(buffer) {
+        // NDArray::Load(dmlc::Stream* fi, std::vector<NDArray>* data, std::vector<std::string>* keys)
+        const map = new Map();
+        const reader = new mxnet.ndarray.BinaryReader(buffer);
+        if (reader.uint64() !== 0x112) { // kMXAPINDArrayListMagic
+            throw new mxnet.Error('Invalid signature.');
         }
-        return context.request('mxnet-metadata.json', 'utf-8', null).then((data) => {
-            mxnet.Metadata._metadata = new mxnet.Metadata(data);
-            return mxnet.Metadata._metadata;
-        }).catch(() => {
-            mxnet.Metadata._metadata = new mxnet.Metadata(null);
-            return mxnet.Metadata._metadata;
-        });
-    }
-
-    constructor(data) {
-        this._map = new Map();
-        this._attributeCache = {};
-        if (data) {
-            const metadata = JSON.parse(data);
-            this._map = new Map(metadata.map((item) => [ item.name, item ]));
+        if (reader.uint64() !== 0) {
+            throw new mxnet.Error('Invalid reserved block.');
         }
-    }
-
-    type(name) {
-        return this._map.get(name);
-    }
-
-    attribute(type, name) {
-        let map = this._attributeCache[type];
-        if (!map) {
-            map = {};
-            const schema = this.type(type);
-            if (schema && schema.attributes) {
-                for (const attribute of schema.attributes) {
-                    map[attribute.name] = attribute;
-                }
-            }
-            this._attributeCache[type] = map;
+        const data = new Array(reader.uint64());
+        for (let i = 0; i < data.length; i++) {
+            data[i] = new mxnet.ndarray.NDArray(reader);
         }
-        return map[name] || null;
-    }
-};
-
-mxnet.Error = class extends Error {
-
-    constructor(message) {
-        super(message);
-        this.name = 'Error loading MXNet model.';
-    }
-};
-
-ndarray.Stream = class {
-
-    constructor(buffer) {
-
-        this._arrays = {};
-
-        const reader = new ndarray.Reader(buffer);
-        if (!reader.checkSignature([ 0x12, 1, 0, 0, 0, 0, 0, 0 ])) {
-            throw new ndarray.Error('Invalid signature.');
-        }
-        if (!reader.checkSignature([ 0, 0, 0, 0, 0, 0, 0, 0 ])) {
-            throw new ndarray.Error('Invalid reserved block.');
-        }
-
-        const data = [];
-        for (let dataSize = reader.uint64(); dataSize > 0; dataSize--) {
-            data.push(new ndarray.Array(reader));
-        }
-
         const decoder = new TextDecoder('ascii');
-        const names = [];
-        for (let namesSize = reader.uint64(); namesSize > 0; namesSize--) {
-            const name = decoder.decode(reader.read(reader.uint64()));
-            names.push(name);
+        const names = new Array(reader.uint64());
+        for (let i = 0; i < names.length; i++) {
+            names[i] = decoder.decode(reader.read(reader.uint64()));
         }
-
         if (names.length != data.length) {
-            throw new ndarray.Error('Label count mismatch.');
+            throw new mxnet.Error('Label count mismatch.');
         }
-
         for (let i = 0; i < names.length; i++) {
-            this._arrays[names[i]] = data[i];
+            map.set(names[i], data[i]);
         }
+        return map;
     }
-
-    get arrays() {
-        return this._arrays;
-    }
-
 };
 
-ndarray.Array = class {
-
-    constructor(reader) {
 
-        ndarray.Array._dataTypeSizeTable = [ 4, 8, 2, 1, 4, 1, 8 ];
+mxnet.ndarray.NDArray = class {
 
-        if (reader.checkSignature([ 0xc9, 0xfa, 0x93, 0xF9 ])) {
-            this._loadV2(reader);
-        }
-        else if (reader.checkSignature([ 0xc8, 0xfa, 0x93, 0xF9 ])) {
-            this._loadV1(reader);
-        }
-        else {
-            this._loadV0(reader);
-        }
-    }
-
-    _loadV2(reader) {
-        const stype = reader.uint32();
-        let num_aux_data = 0;
-        switch (stype) {
-            case 0: num_aux_data = 0; break; // kDefaultStorage
-            case 1: num_aux_data = 1; break; // kRowSparseStorage
-            case 2: num_aux_data = 2; break; // kCSRStorage
-        }
-        this.sshape = null;
-        if (num_aux_data > 0) {
-            this.sshape = new ndarray.Shape(reader, true);
-        }
-        this._shape = new ndarray.Shape(reader, true);
-        if (this._shape.dimensions.length == 0) {
-            return;
-        }
-        this._context = new ndarray.Context(reader);
-        this._dataType = reader.uint32();
-        if (num_aux_data > 0) {
-            throw new ndarray.Error('Not implemented.');
-        }
-        const dataTypeSize = (this._dataType < ndarray.Array._dataTypeSizeTable.length) ? ndarray.Array._dataTypeSizeTable[this._dataType] : 0;
-        const size = dataTypeSize * this._shape.size();
-        this._data = reader.read(size);
-    }
-
-    _loadV1(reader) {
-        this._shape = new ndarray.Shape(reader, true);
-        if (this._shape.dimensions.length == 0) {
-            return;
+    constructor(reader) {
+        mxnet.ndarray.NDArray._dataTypeSizeTable = [ 4, 8, 2, 1, 4, 1, 8 ];
+        switch (reader.uint32()) {
+            case 0xF993faca: { // NDARRAY_V3_MAGIC
+                throw new mxnet.Array('mxnet.ndarray.NDArray v3 not supported.');
+            }
+            case 0xf993fac9: { // NDARRAY_V2_MAGIC
+                const stype = reader.uint32();
+                let num_aux_data = 0;
+                switch (stype) {
+                    case 0: num_aux_data = 0; break; // kDefaultStorage
+                    case 1: num_aux_data = 1; break; // kRowSparseStorage
+                    case 2: num_aux_data = 2; break; // kCSRStorage
+                }
+                this.sshape = null;
+                if (num_aux_data > 0) {
+                    this.sshape = reader.uint64s();
+                }
+                this.shape = reader.uint64s();
+                if (this.shape.length !== 0) {
+                    this.context = new mxnet.context.Context(reader);
+                    this.dtype = reader.uint32();
+                    if (num_aux_data > 0) {
+                        throw new mxnet.Error('Not implemented.');
+                    }
+                    const dataTypeSize = (this.dtype < mxnet.ndarray.NDArray._dataTypeSizeTable.length) ? mxnet.ndarray.NDArray._dataTypeSizeTable[this.dtype] : 0;
+                    const size = dataTypeSize * this.size;
+                    this.data = reader.read(size);
+                }
+                break;
+            }
+            case 0xf993fac8: { // NDARRAY_V1_MAGIC
+                this.shape = reader.uint64s();
+                if (this.shape.length !== 0) {
+                    this.context = new mxnet.context.Context(reader);
+                    this.dtype = reader.uint32();
+                    const itemsize = (this.dtype < mxnet.ndarray.NDArray._dataTypeSizeTable.length) ? mxnet.ndarray.NDArray._dataTypeSizeTable[this.dtype] : 0;
+                    const size = itemsize * this.size;
+                    this.data = reader.read(size);
+                }
+                break;
+            }
+            default: {
+                reader.skip(-4);
+                this.shape = reader.uint32s();
+                this.context = new mxnet.context.Context(reader);
+                this.dtype = reader.uint32();
+                const itemsize = (this.dtype < mxnet.ndarray.NDArray._dataTypeSizeTable.length) ? mxnet.ndarray.NDArray._dataTypeSizeTable[this.dtype] : 0;
+                const size = itemsize * this.size;
+                this.data = reader.read(size);
+                break;
+            }
         }
-        this._context = new ndarray.Context(reader);
-        this._dataType = reader.uint32();
-        const dataTypeSize = (this._dataType < ndarray.Array._dataTypeSizeTable.length) ? ndarray.Array._dataTypeSizeTable[this._dataType] : 0;
-        const size = dataTypeSize * this._shape.size();
-        this._data = reader.read(size);
     }
 
-    _loadV0(reader) {
-        this._shape = new ndarray.Shape(reader, false);
-        this._context = new ndarray.Context(reader);
-        this._dataType = reader.uint32();
-        const dataTypeSize = (this._dataType < ndarray.Array._dataTypeSizeTable.length) ? ndarray.Array._dataTypeSizeTable[this._dataType] : 0;
-        const size = dataTypeSize * this._shape.size();
-        this._data = reader.read(size);
-    }
-
-    get dataType() {
-        return this._dataType;
-    }
-
-    get shape() {
-        return this._shape;
-    }
-
-    get data() {
-        return this._data;
+    get size() {
+        return this.shape.reduce((a, b) => a * b, 1);
     }
 };
 
-ndarray.Shape = class {
-
-    constructor(reader, uint64) {
-        const ndim = reader.uint32();
-        this._dimensions = [];
-        for (let i = 0; i < ndim; i++) {
-            this._dimensions.push(uint64 ? reader.uint64() : reader.uint32());
-        }
-    }
-
-    get dimensions() {
-        return this._dimensions;
-    }
-
-    size() {
-        return this._dimensions.reduce((a, b) => a * b, 1);
-    }
-};
+mxnet.context = {};
 
-ndarray.Context = class {
+mxnet.context.Context = class {
 
     constructor(reader) {
         this._deviceType = reader.uint32();
@@ -1155,7 +1063,7 @@ ndarray.Context = class {
     }
 };
 
-ndarray.Reader = class {
+mxnet.ndarray.BinaryReader = class {
 
     constructor(buffer) {
         this._buffer = buffer;
@@ -1163,34 +1071,23 @@ ndarray.Reader = class {
         this._end = buffer.length;
     }
 
-    checkSignature(signature) {
-        if (this._position + signature.length <= this._end) {
-            for (let i = 0; i < signature.length; i++) {
-                if (this._buffer[this._position + i] != signature[i]) {
-                    return false;
-                }
-            }
+    skip(offset) {
+        this._position += offset;
+        if (this._position > this._end) {
+            throw new mxnet.Error('Data not available.');
         }
-        this._position += signature.length;
-        return true;
     }
 
     read(size) {
-        if (this._position + size > this._end) {
-            throw new ndarray.Error('Data not available.');
-        }
-        const data = this._buffer.subarray(this._position, this._position + size);
-        this._position += size;
-        return data;
+        const position = this._position;
+        this.skip(size);
+        return this._buffer.subarray(position, this._position);
     }
 
     uint16() {
-        if (this._position + 2 > this._end) {
-            throw new ndarray.Error('Data not available.');
-        }
-        const value = this._buffer[this._position] | (this._buffer[this._position + 1] << 8);
-        this._position += 2;
-        return value;
+        const position = this._position;
+        this.skip(2);
+        return this._buffer[position] | (this._buffer[position + 1] << 8);
     }
 
     uint32() {
@@ -1200,16 +1097,77 @@ ndarray.Reader = class {
     uint64() {
         const value = this.uint32();
         if (this.uint32() != 0) {
-            throw new ndarray.Error('Large int64 value.');
+            throw new mxnet.Error('Large int64 value.');
         }
         return value;
     }
+
+    uint32s() {
+        const array = new Array(this.uint32());
+        for (let i = 0; i < array.length; i++) {
+            array[i] = this.uint32();
+        }
+        return array;
+    }
+
+    uint64s() {
+        const array = new Array(this.uint32());
+        for (let i = 0; i < array.length; i++) {
+            array[i] = this.uint64();
+        }
+        return array;
+    }
+};
+
+mxnet.Metadata = class {
+
+    static open(context) {
+        if (mxnet.Metadata._metadata) {
+            return Promise.resolve(mxnet.Metadata._metadata);
+        }
+        return context.request('mxnet-metadata.json', 'utf-8', null).then((data) => {
+            mxnet.Metadata._metadata = new mxnet.Metadata(data);
+            return mxnet.Metadata._metadata;
+        }).catch(() => {
+            mxnet.Metadata._metadata = new mxnet.Metadata(null);
+            return mxnet.Metadata._metadata;
+        });
+    }
+
+    constructor(data) {
+        this._map = new Map();
+        this._attributeCache = {};
+        if (data) {
+            const metadata = JSON.parse(data);
+            this._map = new Map(metadata.map((item) => [ item.name, item ]));
+        }
+    }
+
+    type(name) {
+        return this._map.get(name);
+    }
+
+    attribute(type, name) {
+        let map = this._attributeCache[type];
+        if (!map) {
+            map = {};
+            const schema = this.type(type);
+            if (schema && schema.attributes) {
+                for (const attribute of schema.attributes) {
+                    map[attribute.name] = attribute;
+                }
+            }
+            this._attributeCache[type] = map;
+        }
+        return map[name] || null;
+    }
 };
 
-ndarray.Error = class extends Error {
+mxnet.Error = class extends Error {
+
     constructor(message) {
         super(message);
-        this.name = 'NDArray Error';
+        this.name = 'Error loading MXNet model.';
     }
 };