|
|
@@ -127,11 +127,19 @@ paddle.ModelFactory = class {
|
|
|
const createModel = (metadata, format, desc, tensors) => {
|
|
|
return new paddle.Model(metadata, format, desc, tensors);
|
|
|
};
|
|
|
- const loadParams = (metadata, program, stream) => {
|
|
|
- const weights = new Map();
|
|
|
+ const loadParams = (stream) => {
|
|
|
+ const params = [];
|
|
|
while (stream.position < stream.length) {
|
|
|
const tensor = paddle.Utility.openTensorDesc(stream);
|
|
|
- weights.set(program.vars.shift(), tensor);
|
|
|
+ params.push(tensor);
|
|
|
+ }
|
|
|
+ return params;
|
|
|
+ };
|
|
|
+ const mapParams = (params, program) => {
|
|
|
+ const weights = new Map();
|
|
|
+ const vars = program.vars.slice();
|
|
|
+ for (const param of params) {
|
|
|
+ weights.set(vars.shift(), param);
|
|
|
}
|
|
|
return weights;
|
|
|
};
|
|
|
@@ -146,25 +154,24 @@ paddle.ModelFactory = class {
|
|
|
}
|
|
|
case 'paddle.params': {
|
|
|
const file = identifier !== 'params' ? base + '.pdmodel' : 'model';
|
|
|
+ const params = loadParams(context.stream);
|
|
|
return context.request(file, null).then((stream) => {
|
|
|
const program = openProgram(stream, 'paddle.pb');
|
|
|
- const tensors = loadParams(metadata, program, context.stream);
|
|
|
- return createModel(metadata, program.format, program.desc, tensors);
|
|
|
+ const weights = mapParams(params, program);
|
|
|
+ return createModel(metadata, program.format, program.desc, weights);
|
|
|
+ }).catch(() => {
|
|
|
+ const weights = new Map(params.map((param, index) => [ index.toString(), param ]));
|
|
|
+ return createModel(metadata, 'PaddlePaddle Inference Weights', null, weights);
|
|
|
});
|
|
|
}
|
|
|
case 'paddle.pb':
|
|
|
case 'paddle.pbtxt': {
|
|
|
const loadEntries = (context, program) => {
|
|
|
- const tensors = new Map();
|
|
|
- const promises = program.vars.map((name) => context.request(name, null));
|
|
|
+ const promises = program.vars.map((name) => context.request(name, null).then((stream) => stream).catch(() => null));
|
|
|
return Promise.all(promises).then((streams) => {
|
|
|
- for (let i = 0; i < program.vars.length; i++) {
|
|
|
- const tensor = paddle.Utility.openTensorDesc(streams[i]);
|
|
|
- tensors.set(program.vars[i], tensor);
|
|
|
- }
|
|
|
- return createModel(metadata, program.format, program.desc, tensors);
|
|
|
- }).catch((/* err */) => {
|
|
|
- return createModel(metadata, program.format, program.desc, tensors);
|
|
|
+ const params = streams.map((stream) => stream ? paddle.Utility.openTensorDesc(stream) : null);
|
|
|
+ const weights = mapParams(params, program);
|
|
|
+ return createModel(metadata, program.format, program.desc, weights);
|
|
|
});
|
|
|
};
|
|
|
const openNumPyArrayPickle = (stream, weights) => {
|
|
|
@@ -176,7 +183,8 @@ paddle.ModelFactory = class {
|
|
|
const program = openProgram(context.stream, match);
|
|
|
if (extension === 'pdmodel') {
|
|
|
return context.request(base + '.pdiparams', null).then((stream) => {
|
|
|
- const weights = loadParams(metadata, program, stream);
|
|
|
+ const params = loadParams(stream);
|
|
|
+ const weights = mapParams(params, program);
|
|
|
return createModel(metadata, program.format, program.desc, weights);
|
|
|
}).catch((/* err */) => {
|
|
|
const weights = new Map();
|
|
|
@@ -200,7 +208,8 @@ paddle.ModelFactory = class {
|
|
|
}
|
|
|
if (identifier === 'model') {
|
|
|
return context.request('params', null).then((stream) => {
|
|
|
- const weights = loadParams(metadata, program, stream);
|
|
|
+ const params = loadParams(stream);
|
|
|
+ const weights = mapParams(params, program);
|
|
|
return createModel(metadata, program.format, program.desc, weights);
|
|
|
}).catch((/* err */) => {
|
|
|
return loadEntries(context, program);
|