|
|
@@ -163,10 +163,11 @@ mxnet.ModelFactory = class {
|
|
|
const parameters = new Map();
|
|
|
if (params) {
|
|
|
try {
|
|
|
- const stream = new ndarray.Stream(params);
|
|
|
- for (const key of Object.keys(stream.arrays)) {
|
|
|
+ for (const entry of mxnet.ndarray.load(params)) {
|
|
|
+ const key = entry[0];
|
|
|
+ const array = entry[1];
|
|
|
const name = (key.startsWith('arg:') || key.startsWith('aux:')) ? key.substring(4) : key;
|
|
|
- parameters.set(name, stream.arrays[key]);
|
|
|
+ parameters.set(name, array);
|
|
|
}
|
|
|
}
|
|
|
catch (error) {
|
|
|
@@ -315,7 +316,7 @@ mxnet.Graph = class {
|
|
|
for (const pair of params) {
|
|
|
const key = pair[0];
|
|
|
const value = pair[1];
|
|
|
- tensors.set(key, new mxnet.Tensor('Initializer', key, new mxnet.TensorType(value.dataType, new mxnet.TensorShape(value.shape.dimensions)), value.data));
|
|
|
+ tensors.set(key, new mxnet.Tensor('Initializer', key, new mxnet.TensorType(value.dtype, new mxnet.TensorShape(value.shape)), value.data));
|
|
|
}
|
|
|
}
|
|
|
|
|
|
@@ -958,196 +959,103 @@ mxnet.TensorShape = class {
|
|
|
}
|
|
|
};
|
|
|
|
|
|
-mxnet.Metadata = class {
|
|
|
+mxnet.ndarray = class {
|
|
|
|
|
|
- static open(context) {
|
|
|
- if (mxnet.Metadata._metadata) {
|
|
|
- return Promise.resolve(mxnet.Metadata._metadata);
|
|
|
+ static load(buffer) {
|
|
|
+ // NDArray::Load(dmlc::Stream* fi, std::vector<NDArray>* data, std::vector<std::string>* keys)
|
|
|
+ const map = new Map();
|
|
|
+ const reader = new mxnet.ndarray.BinaryReader(buffer);
|
|
|
+ if (reader.uint64() !== 0x112) { // kMXAPINDArrayListMagic
|
|
|
+ throw new mxnet.Error('Invalid signature.');
|
|
|
}
|
|
|
- return context.request('mxnet-metadata.json', 'utf-8', null).then((data) => {
|
|
|
- mxnet.Metadata._metadata = new mxnet.Metadata(data);
|
|
|
- return mxnet.Metadata._metadata;
|
|
|
- }).catch(() => {
|
|
|
- mxnet.Metadata._metadata = new mxnet.Metadata(null);
|
|
|
- return mxnet.Metadata._metadata;
|
|
|
- });
|
|
|
- }
|
|
|
-
|
|
|
- constructor(data) {
|
|
|
- this._map = new Map();
|
|
|
- this._attributeCache = {};
|
|
|
- if (data) {
|
|
|
- const metadata = JSON.parse(data);
|
|
|
- this._map = new Map(metadata.map((item) => [ item.name, item ]));
|
|
|
+ if (reader.uint64() !== 0) {
|
|
|
+ throw new mxnet.Error('Invalid reserved block.');
|
|
|
}
|
|
|
- }
|
|
|
-
|
|
|
- type(name) {
|
|
|
- return this._map.get(name);
|
|
|
- }
|
|
|
-
|
|
|
- attribute(type, name) {
|
|
|
- let map = this._attributeCache[type];
|
|
|
- if (!map) {
|
|
|
- map = {};
|
|
|
- const schema = this.type(type);
|
|
|
- if (schema && schema.attributes) {
|
|
|
- for (const attribute of schema.attributes) {
|
|
|
- map[attribute.name] = attribute;
|
|
|
- }
|
|
|
- }
|
|
|
- this._attributeCache[type] = map;
|
|
|
+ const data = new Array(reader.uint64());
|
|
|
+ for (let i = 0; i < data.length; i++) {
|
|
|
+ data[i] = new mxnet.ndarray.NDArray(reader);
|
|
|
}
|
|
|
- return map[name] || null;
|
|
|
- }
|
|
|
-};
|
|
|
-
|
|
|
-mxnet.Error = class extends Error {
|
|
|
-
|
|
|
- constructor(message) {
|
|
|
- super(message);
|
|
|
- this.name = 'Error loading MXNet model.';
|
|
|
- }
|
|
|
-};
|
|
|
-
|
|
|
-ndarray.Stream = class {
|
|
|
-
|
|
|
- constructor(buffer) {
|
|
|
-
|
|
|
- this._arrays = {};
|
|
|
-
|
|
|
- const reader = new ndarray.Reader(buffer);
|
|
|
- if (!reader.checkSignature([ 0x12, 1, 0, 0, 0, 0, 0, 0 ])) {
|
|
|
- throw new ndarray.Error('Invalid signature.');
|
|
|
- }
|
|
|
- if (!reader.checkSignature([ 0, 0, 0, 0, 0, 0, 0, 0 ])) {
|
|
|
- throw new ndarray.Error('Invalid reserved block.');
|
|
|
- }
|
|
|
-
|
|
|
- const data = [];
|
|
|
- for (let dataSize = reader.uint64(); dataSize > 0; dataSize--) {
|
|
|
- data.push(new ndarray.Array(reader));
|
|
|
- }
|
|
|
-
|
|
|
const decoder = new TextDecoder('ascii');
|
|
|
- const names = [];
|
|
|
- for (let namesSize = reader.uint64(); namesSize > 0; namesSize--) {
|
|
|
- const name = decoder.decode(reader.read(reader.uint64()));
|
|
|
- names.push(name);
|
|
|
+ const names = new Array(reader.uint64());
|
|
|
+ for (let i = 0; i < names.length; i++) {
|
|
|
+ names[i] = decoder.decode(reader.read(reader.uint64()));
|
|
|
}
|
|
|
-
|
|
|
if (names.length != data.length) {
|
|
|
- throw new ndarray.Error('Label count mismatch.');
|
|
|
+ throw new mxnet.Error('Label count mismatch.');
|
|
|
}
|
|
|
-
|
|
|
for (let i = 0; i < names.length; i++) {
|
|
|
- this._arrays[names[i]] = data[i];
|
|
|
+ map.set(names[i], data[i]);
|
|
|
}
|
|
|
+ return map;
|
|
|
}
|
|
|
-
|
|
|
- get arrays() {
|
|
|
- return this._arrays;
|
|
|
- }
|
|
|
-
|
|
|
};
|
|
|
|
|
|
-ndarray.Array = class {
|
|
|
-
|
|
|
- constructor(reader) {
|
|
|
|
|
|
- ndarray.Array._dataTypeSizeTable = [ 4, 8, 2, 1, 4, 1, 8 ];
|
|
|
+mxnet.ndarray.NDArray = class {
|
|
|
|
|
|
- if (reader.checkSignature([ 0xc9, 0xfa, 0x93, 0xF9 ])) {
|
|
|
- this._loadV2(reader);
|
|
|
- }
|
|
|
- else if (reader.checkSignature([ 0xc8, 0xfa, 0x93, 0xF9 ])) {
|
|
|
- this._loadV1(reader);
|
|
|
- }
|
|
|
- else {
|
|
|
- this._loadV0(reader);
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- _loadV2(reader) {
|
|
|
- const stype = reader.uint32();
|
|
|
- let num_aux_data = 0;
|
|
|
- switch (stype) {
|
|
|
- case 0: num_aux_data = 0; break; // kDefaultStorage
|
|
|
- case 1: num_aux_data = 1; break; // kRowSparseStorage
|
|
|
- case 2: num_aux_data = 2; break; // kCSRStorage
|
|
|
- }
|
|
|
- this.sshape = null;
|
|
|
- if (num_aux_data > 0) {
|
|
|
- this.sshape = new ndarray.Shape(reader, true);
|
|
|
- }
|
|
|
- this._shape = new ndarray.Shape(reader, true);
|
|
|
- if (this._shape.dimensions.length == 0) {
|
|
|
- return;
|
|
|
- }
|
|
|
- this._context = new ndarray.Context(reader);
|
|
|
- this._dataType = reader.uint32();
|
|
|
- if (num_aux_data > 0) {
|
|
|
- throw new ndarray.Error('Not implemented.');
|
|
|
- }
|
|
|
- const dataTypeSize = (this._dataType < ndarray.Array._dataTypeSizeTable.length) ? ndarray.Array._dataTypeSizeTable[this._dataType] : 0;
|
|
|
- const size = dataTypeSize * this._shape.size();
|
|
|
- this._data = reader.read(size);
|
|
|
- }
|
|
|
-
|
|
|
- _loadV1(reader) {
|
|
|
- this._shape = new ndarray.Shape(reader, true);
|
|
|
- if (this._shape.dimensions.length == 0) {
|
|
|
- return;
|
|
|
+ constructor(reader) {
|
|
|
+ mxnet.ndarray.NDArray._dataTypeSizeTable = [ 4, 8, 2, 1, 4, 1, 8 ];
|
|
|
+ switch (reader.uint32()) {
|
|
|
+ case 0xF993faca: { // NDARRAY_V3_MAGIC
|
|
|
+ throw new mxnet.Array('mxnet.ndarray.NDArray v3 not supported.');
|
|
|
+ }
|
|
|
+ case 0xf993fac9: { // NDARRAY_V2_MAGIC
|
|
|
+ const stype = reader.uint32();
|
|
|
+ let num_aux_data = 0;
|
|
|
+ switch (stype) {
|
|
|
+ case 0: num_aux_data = 0; break; // kDefaultStorage
|
|
|
+ case 1: num_aux_data = 1; break; // kRowSparseStorage
|
|
|
+ case 2: num_aux_data = 2; break; // kCSRStorage
|
|
|
+ }
|
|
|
+ this.sshape = null;
|
|
|
+ if (num_aux_data > 0) {
|
|
|
+ this.sshape = reader.uint64s();
|
|
|
+ }
|
|
|
+ this.shape = reader.uint64s();
|
|
|
+ if (this.shape.length !== 0) {
|
|
|
+ this.context = new mxnet.context.Context(reader);
|
|
|
+ this.dtype = reader.uint32();
|
|
|
+ if (num_aux_data > 0) {
|
|
|
+ throw new mxnet.Error('Not implemented.');
|
|
|
+ }
|
|
|
+ const dataTypeSize = (this.dtype < mxnet.ndarray.NDArray._dataTypeSizeTable.length) ? mxnet.ndarray.NDArray._dataTypeSizeTable[this.dtype] : 0;
|
|
|
+ const size = dataTypeSize * this.size;
|
|
|
+ this.data = reader.read(size);
|
|
|
+ }
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ case 0xf993fac8: { // NDARRAY_V1_MAGIC
|
|
|
+ this.shape = reader.uint64s();
|
|
|
+ if (this.shape.length !== 0) {
|
|
|
+ this.context = new mxnet.context.Context(reader);
|
|
|
+ this.dtype = reader.uint32();
|
|
|
+ const itemsize = (this.dtype < mxnet.ndarray.NDArray._dataTypeSizeTable.length) ? mxnet.ndarray.NDArray._dataTypeSizeTable[this.dtype] : 0;
|
|
|
+ const size = itemsize * this.size;
|
|
|
+ this.data = reader.read(size);
|
|
|
+ }
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ default: {
|
|
|
+ reader.skip(-4);
|
|
|
+ this.shape = reader.uint32s();
|
|
|
+ this.context = new mxnet.context.Context(reader);
|
|
|
+ this.dtype = reader.uint32();
|
|
|
+ const itemsize = (this.dtype < mxnet.ndarray.NDArray._dataTypeSizeTable.length) ? mxnet.ndarray.NDArray._dataTypeSizeTable[this.dtype] : 0;
|
|
|
+ const size = itemsize * this.size;
|
|
|
+ this.data = reader.read(size);
|
|
|
+ break;
|
|
|
+ }
|
|
|
}
|
|
|
- this._context = new ndarray.Context(reader);
|
|
|
- this._dataType = reader.uint32();
|
|
|
- const dataTypeSize = (this._dataType < ndarray.Array._dataTypeSizeTable.length) ? ndarray.Array._dataTypeSizeTable[this._dataType] : 0;
|
|
|
- const size = dataTypeSize * this._shape.size();
|
|
|
- this._data = reader.read(size);
|
|
|
}
|
|
|
|
|
|
- _loadV0(reader) {
|
|
|
- this._shape = new ndarray.Shape(reader, false);
|
|
|
- this._context = new ndarray.Context(reader);
|
|
|
- this._dataType = reader.uint32();
|
|
|
- const dataTypeSize = (this._dataType < ndarray.Array._dataTypeSizeTable.length) ? ndarray.Array._dataTypeSizeTable[this._dataType] : 0;
|
|
|
- const size = dataTypeSize * this._shape.size();
|
|
|
- this._data = reader.read(size);
|
|
|
- }
|
|
|
-
|
|
|
- get dataType() {
|
|
|
- return this._dataType;
|
|
|
- }
|
|
|
-
|
|
|
- get shape() {
|
|
|
- return this._shape;
|
|
|
- }
|
|
|
-
|
|
|
- get data() {
|
|
|
- return this._data;
|
|
|
+ get size() {
|
|
|
+ return this.shape.reduce((a, b) => a * b, 1);
|
|
|
}
|
|
|
};
|
|
|
|
|
|
-ndarray.Shape = class {
|
|
|
-
|
|
|
- constructor(reader, uint64) {
|
|
|
- const ndim = reader.uint32();
|
|
|
- this._dimensions = [];
|
|
|
- for (let i = 0; i < ndim; i++) {
|
|
|
- this._dimensions.push(uint64 ? reader.uint64() : reader.uint32());
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- get dimensions() {
|
|
|
- return this._dimensions;
|
|
|
- }
|
|
|
-
|
|
|
- size() {
|
|
|
- return this._dimensions.reduce((a, b) => a * b, 1);
|
|
|
- }
|
|
|
-};
|
|
|
+mxnet.context = {};
|
|
|
|
|
|
-ndarray.Context = class {
|
|
|
+mxnet.context.Context = class {
|
|
|
|
|
|
constructor(reader) {
|
|
|
this._deviceType = reader.uint32();
|
|
|
@@ -1155,7 +1063,7 @@ ndarray.Context = class {
|
|
|
}
|
|
|
};
|
|
|
|
|
|
-ndarray.Reader = class {
|
|
|
+mxnet.ndarray.BinaryReader = class {
|
|
|
|
|
|
constructor(buffer) {
|
|
|
this._buffer = buffer;
|
|
|
@@ -1163,34 +1071,23 @@ ndarray.Reader = class {
|
|
|
this._end = buffer.length;
|
|
|
}
|
|
|
|
|
|
- checkSignature(signature) {
|
|
|
- if (this._position + signature.length <= this._end) {
|
|
|
- for (let i = 0; i < signature.length; i++) {
|
|
|
- if (this._buffer[this._position + i] != signature[i]) {
|
|
|
- return false;
|
|
|
- }
|
|
|
- }
|
|
|
+ skip(offset) {
|
|
|
+ this._position += offset;
|
|
|
+ if (this._position > this._end) {
|
|
|
+ throw new mxnet.Error('Data not available.');
|
|
|
}
|
|
|
- this._position += signature.length;
|
|
|
- return true;
|
|
|
}
|
|
|
|
|
|
read(size) {
|
|
|
- if (this._position + size > this._end) {
|
|
|
- throw new ndarray.Error('Data not available.');
|
|
|
- }
|
|
|
- const data = this._buffer.subarray(this._position, this._position + size);
|
|
|
- this._position += size;
|
|
|
- return data;
|
|
|
+ const position = this._position;
|
|
|
+ this.skip(size);
|
|
|
+ return this._buffer.subarray(position, this._position);
|
|
|
}
|
|
|
|
|
|
uint16() {
|
|
|
- if (this._position + 2 > this._end) {
|
|
|
- throw new ndarray.Error('Data not available.');
|
|
|
- }
|
|
|
- const value = this._buffer[this._position] | (this._buffer[this._position + 1] << 8);
|
|
|
- this._position += 2;
|
|
|
- return value;
|
|
|
+ const position = this._position;
|
|
|
+ this.skip(2);
|
|
|
+ return this._buffer[position] | (this._buffer[position + 1] << 8);
|
|
|
}
|
|
|
|
|
|
uint32() {
|
|
|
@@ -1200,16 +1097,77 @@ ndarray.Reader = class {
|
|
|
uint64() {
|
|
|
const value = this.uint32();
|
|
|
if (this.uint32() != 0) {
|
|
|
- throw new ndarray.Error('Large int64 value.');
|
|
|
+ throw new mxnet.Error('Large int64 value.');
|
|
|
}
|
|
|
return value;
|
|
|
}
|
|
|
+
|
|
|
+ uint32s() {
|
|
|
+ const array = new Array(this.uint32());
|
|
|
+ for (let i = 0; i < array.length; i++) {
|
|
|
+ array[i] = this.uint32();
|
|
|
+ }
|
|
|
+ return array;
|
|
|
+ }
|
|
|
+
|
|
|
+ uint64s() {
|
|
|
+ const array = new Array(this.uint32());
|
|
|
+ for (let i = 0; i < array.length; i++) {
|
|
|
+ array[i] = this.uint64();
|
|
|
+ }
|
|
|
+ return array;
|
|
|
+ }
|
|
|
+};
|
|
|
+
|
|
|
+mxnet.Metadata = class {
|
|
|
+
|
|
|
+ static open(context) {
|
|
|
+ if (mxnet.Metadata._metadata) {
|
|
|
+ return Promise.resolve(mxnet.Metadata._metadata);
|
|
|
+ }
|
|
|
+ return context.request('mxnet-metadata.json', 'utf-8', null).then((data) => {
|
|
|
+ mxnet.Metadata._metadata = new mxnet.Metadata(data);
|
|
|
+ return mxnet.Metadata._metadata;
|
|
|
+ }).catch(() => {
|
|
|
+ mxnet.Metadata._metadata = new mxnet.Metadata(null);
|
|
|
+ return mxnet.Metadata._metadata;
|
|
|
+ });
|
|
|
+ }
|
|
|
+
|
|
|
+ constructor(data) {
|
|
|
+ this._map = new Map();
|
|
|
+ this._attributeCache = {};
|
|
|
+ if (data) {
|
|
|
+ const metadata = JSON.parse(data);
|
|
|
+ this._map = new Map(metadata.map((item) => [ item.name, item ]));
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ type(name) {
|
|
|
+ return this._map.get(name);
|
|
|
+ }
|
|
|
+
|
|
|
+ attribute(type, name) {
|
|
|
+ let map = this._attributeCache[type];
|
|
|
+ if (!map) {
|
|
|
+ map = {};
|
|
|
+ const schema = this.type(type);
|
|
|
+ if (schema && schema.attributes) {
|
|
|
+ for (const attribute of schema.attributes) {
|
|
|
+ map[attribute.name] = attribute;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ this._attributeCache[type] = map;
|
|
|
+ }
|
|
|
+ return map[name] || null;
|
|
|
+ }
|
|
|
};
|
|
|
|
|
|
-ndarray.Error = class extends Error {
|
|
|
+mxnet.Error = class extends Error {
|
|
|
+
|
|
|
constructor(message) {
|
|
|
super(message);
|
|
|
- this.name = 'NDArray Error';
|
|
|
+ this.name = 'Error loading MXNet model.';
|
|
|
}
|
|
|
};
|
|
|
|