|
|
@@ -130,7 +130,7 @@ paddle.ModelFactory = class {
|
|
|
const loadParams = (metadata, program, stream) => {
|
|
|
const weights = new Map();
|
|
|
while (stream.position < stream.length) {
|
|
|
- const tensor = paddle.Utility.openTensor(stream);
|
|
|
+ const tensor = paddle.Utility.openTensorDesc(stream);
|
|
|
weights.set(program.vars.shift(), tensor);
|
|
|
}
|
|
|
return weights;
|
|
|
@@ -159,7 +159,7 @@ paddle.ModelFactory = class {
|
|
|
const promises = program.vars.map((name) => context.request(name, null));
|
|
|
return Promise.all(promises).then((streams) => {
|
|
|
for (let i = 0; i < program.vars.length; i++) {
|
|
|
- const tensor = paddle.Utility.openTensor(streams[i]);
|
|
|
+ const tensor = paddle.Utility.openTensorDesc(streams[i]);
|
|
|
tensors.set(program.vars[i], tensor);
|
|
|
}
|
|
|
return createModel(metadata, program.format, program.desc, tensors);
|
|
|
@@ -167,11 +167,11 @@ paddle.ModelFactory = class {
|
|
|
return createModel(metadata, program.format, program.desc, tensors);
|
|
|
});
|
|
|
};
|
|
|
- const openPickle = (stream, weights) => {
|
|
|
+ const openNumPyArrayPickle = (stream, weights) => {
|
|
|
const execution = new python.Execution(null);
|
|
|
const unpickler = python.Unpickler.open(stream);
|
|
|
const obj = unpickler.load((name, args) => execution.invoke(name, args));
|
|
|
- paddle.Utility.openPickle(obj, weights);
|
|
|
+ paddle.Utility.openNumPyArrayList(obj, weights);
|
|
|
};
|
|
|
const program = openProgram(context.stream, match);
|
|
|
if (extension === 'pdmodel') {
|
|
|
@@ -181,16 +181,16 @@ paddle.ModelFactory = class {
|
|
|
}).catch((/* err */) => {
|
|
|
const weights = new Map();
|
|
|
return context.request(base + '.pdparams', null).then((stream) => {
|
|
|
- openPickle(stream, weights);
|
|
|
+ openNumPyArrayPickle(stream, weights);
|
|
|
return context.request(base + '.pdopt', null).then((stream) => {
|
|
|
- openPickle(stream, weights);
|
|
|
+ openNumPyArrayPickle(stream, weights);
|
|
|
return createModel(metadata, program.format, program.desc, weights);
|
|
|
}).catch((/* err */) => {
|
|
|
return createModel(metadata, program.format, program.desc, weights);
|
|
|
});
|
|
|
}).catch((/* err */) => {
|
|
|
return context.request(base + '.pdopt', null).then((stream) => {
|
|
|
- openPickle(stream, weights);
|
|
|
+ openNumPyArrayPickle(stream, weights);
|
|
|
return createModel(metadata, program.format, program.desc, weights);
|
|
|
}).catch((/* err */) => {
|
|
|
return loadEntries(context, program);
|
|
|
@@ -781,7 +781,7 @@ paddle.Utility = class {
|
|
|
return new paddle.TensorType(dataType, new paddle.TensorShape(shape));
|
|
|
}
|
|
|
|
|
|
- static openTensor(stream) {
|
|
|
+ static openTensorDesc(stream) {
|
|
|
const signature = stream.read(16);
|
|
|
if (!signature.every((value) => value === 0x00)) {
|
|
|
throw new paddle.Error('Invalid paddle.TensorDesc signature.');
|
|
|
@@ -806,12 +806,12 @@ paddle.Utility = class {
|
|
|
return new paddle.Tensor(type, data);
|
|
|
}
|
|
|
|
|
|
- static openPickle(obj, weights) {
|
|
|
+ static openNumPyArrayList(obj, weights) {
|
|
|
const map = null; // this._data['StructuredToParameterName@@'];
|
|
|
for (const entry of Object.entries(obj)) {
|
|
|
const key = entry[0];
|
|
|
const value = entry[1];
|
|
|
- if (value && !Array.isArray(value) && value.__class__ && value.__class__.__module__ === 'numpy' && value.__class__.__name__ === 'ndarray') {
|
|
|
+ if (paddle.Utility.isNumPyArray(value)) {
|
|
|
const name = map ? map[key] : key;
|
|
|
const type = new paddle.TensorType(value.dtype.__name__, new paddle.TensorShape(value.shape));
|
|
|
const data = value.data;
|
|
|
@@ -820,6 +820,10 @@ paddle.Utility = class {
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
+
|
|
|
+ static isNumPyArray(value) {
|
|
|
+ return value && !Array.isArray(value) && value.__class__ && value.__class__.__module__ === 'numpy' && value.__class__.__name__ === 'ndarray';
|
|
|
+ }
|
|
|
};
|
|
|
|
|
|
paddle.Entries = class {
|
|
|
@@ -869,7 +873,7 @@ paddle.Entries = class {
|
|
|
if (entry[0].startsWith(rootFolder)) {
|
|
|
const name = entry[0].substring(rootFolder.length);
|
|
|
const stream = entry[1];
|
|
|
- const tensor = paddle.Utility.openTensor(stream);
|
|
|
+ const tensor = paddle.Utility.openTensorDesc(stream);
|
|
|
this._weights.set(name, tensor);
|
|
|
}
|
|
|
}
|
|
|
@@ -881,7 +885,8 @@ paddle.Pickle = class {
|
|
|
|
|
|
static open(context) {
|
|
|
const obj = context.open('pkl');
|
|
|
- if (obj && !Array.isArray(obj) && Object(obj) === obj) {
|
|
|
+ if (obj && !Array.isArray(obj) && Object(obj) === obj &&
|
|
|
+ Object.entries(obj).filter((entry) => paddle.Utility.isNumPyArray(entry[1])).length > 0) {
|
|
|
return new paddle.Pickle(obj);
|
|
|
}
|
|
|
return null;
|
|
|
@@ -908,7 +913,7 @@ paddle.Pickle = class {
|
|
|
_initialize() {
|
|
|
if (!this._weights) {
|
|
|
this._weights = new Map();
|
|
|
- paddle.Utility.openPickle(this._data, this._weights);
|
|
|
+ paddle.Utility.openNumPyArrayList(this._data, this._weights);
|
|
|
}
|
|
|
}
|
|
|
};
|