|
|
@@ -3,6 +3,7 @@ var mxnet = mxnet || {};
|
|
|
var json = json || require('./json');
|
|
|
var zip = zip || require('./zip');
|
|
|
var ndarray = ndarray || {};
|
|
|
+var base = base || require('./base');
|
|
|
|
|
|
mxnet.ModelFactory = class {
|
|
|
|
|
|
@@ -964,7 +965,28 @@ mxnet.ndarray = class {
|
|
|
static load(buffer) {
|
|
|
// NDArray::Load(dmlc::Stream* fi, std::vector<NDArray>* data, std::vector<std::string>* keys)
|
|
|
const map = new Map();
|
|
|
- const reader = new mxnet.ndarray.BinaryReader(buffer);
|
|
|
+ const reader = new base.BinaryReader(buffer);
|
|
|
+ reader.uint64 = function() {
|
|
|
+ const value = this.uint32();
|
|
|
+ if (this.uint32() != 0) {
|
|
|
+ throw new mxnet.Error('Large uint64 value.');
|
|
|
+ }
|
|
|
+ return value;
|
|
|
+ };
|
|
|
+ reader.uint32s = function() {
|
|
|
+ const array = new Array(this.uint32());
|
|
|
+ for (let i = 0; i < array.length; i++) {
|
|
|
+ array[i] = this.uint32();
|
|
|
+ }
|
|
|
+ return array;
|
|
|
+ };
|
|
|
+ reader.uint64s = function() {
|
|
|
+ const array = new Array(this.uint32());
|
|
|
+ for (let i = 0; i < array.length; i++) {
|
|
|
+ array[i] = this.uint64();
|
|
|
+ }
|
|
|
+ return array;
|
|
|
+ };
|
|
|
if (reader.uint64() !== 0x112) { // kMXAPINDArrayListMagic
|
|
|
throw new mxnet.Error('Invalid signature.');
|
|
|
}
|
|
|
@@ -990,13 +1012,12 @@ mxnet.ndarray = class {
|
|
|
}
|
|
|
};
|
|
|
|
|
|
-
|
|
|
mxnet.ndarray.NDArray = class {
|
|
|
|
|
|
constructor(reader) {
|
|
|
mxnet.ndarray.NDArray._dataTypeSizeTable = [ 4, 8, 2, 1, 4, 1, 8 ];
|
|
|
switch (reader.uint32()) {
|
|
|
- case 0xF993faca: { // NDARRAY_V3_MAGIC
|
|
|
+ case 0xf993faca: { // NDARRAY_V3_MAGIC
|
|
|
throw new mxnet.Array('mxnet.ndarray.NDArray v3 not supported.');
|
|
|
}
|
|
|
case 0xf993fac9: { // NDARRAY_V2_MAGIC
|
|
|
@@ -1063,62 +1084,6 @@ mxnet.context.Context = class {
|
|
|
}
|
|
|
};
|
|
|
|
|
|
-mxnet.ndarray.BinaryReader = class {
|
|
|
-
|
|
|
- constructor(buffer) {
|
|
|
- this._buffer = buffer;
|
|
|
- this._position = 0;
|
|
|
- this._end = buffer.length;
|
|
|
- }
|
|
|
-
|
|
|
- skip(offset) {
|
|
|
- this._position += offset;
|
|
|
- if (this._position > this._end) {
|
|
|
- throw new mxnet.Error('Data not available.');
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- read(size) {
|
|
|
- const position = this._position;
|
|
|
- this.skip(size);
|
|
|
- return this._buffer.subarray(position, this._position);
|
|
|
- }
|
|
|
-
|
|
|
- uint16() {
|
|
|
- const position = this._position;
|
|
|
- this.skip(2);
|
|
|
- return this._buffer[position] | (this._buffer[position + 1] << 8);
|
|
|
- }
|
|
|
-
|
|
|
- uint32() {
|
|
|
- return this.uint16() | (this.uint16() << 16);
|
|
|
- }
|
|
|
-
|
|
|
- uint64() {
|
|
|
- const value = this.uint32();
|
|
|
- if (this.uint32() != 0) {
|
|
|
- throw new mxnet.Error('Large int64 value.');
|
|
|
- }
|
|
|
- return value;
|
|
|
- }
|
|
|
-
|
|
|
- uint32s() {
|
|
|
- const array = new Array(this.uint32());
|
|
|
- for (let i = 0; i < array.length; i++) {
|
|
|
- array[i] = this.uint32();
|
|
|
- }
|
|
|
- return array;
|
|
|
- }
|
|
|
-
|
|
|
- uint64s() {
|
|
|
- const array = new Array(this.uint32());
|
|
|
- for (let i = 0; i < array.length; i++) {
|
|
|
- array[i] = this.uint64();
|
|
|
- }
|
|
|
- return array;
|
|
|
- }
|
|
|
-};
|
|
|
-
|
|
|
mxnet.Metadata = class {
|
|
|
|
|
|
static open(context) {
|