|
|
@@ -8,7 +8,8 @@ paddle.ModelFactory = class {
|
|
|
match(context) {
|
|
|
const identifier = context.identifier;
|
|
|
const extension = identifier.split('.').pop().toLowerCase();
|
|
|
- if (identifier === '__model__' || extension === 'paddle' || extension === 'pdmodel') {
|
|
|
+ if (identifier === '__model__' || identifier === 'model' ||
|
|
|
+ extension === 'paddle' || extension === 'pdmodel') {
|
|
|
return true;
|
|
|
}
|
|
|
if (extension === 'pbtxt' || extension === 'txt') {
|
|
|
@@ -20,6 +21,10 @@ paddle.ModelFactory = class {
|
|
|
if (paddle.Container.open(context)) {
|
|
|
return true;
|
|
|
}
|
|
|
+ const stream = context.stream;
|
|
|
+ if (stream.length > 16 && stream.peek(16).every((value) => value === 0x00)) {
|
|
|
+ return true;
|
|
|
+ }
|
|
|
return false;
|
|
|
}
|
|
|
|
|
|
@@ -27,24 +32,21 @@ paddle.ModelFactory = class {
|
|
|
return paddle.Metadata.open(context).then((metadata) => {
|
|
|
return context.require('./paddle-proto').then(() => {
|
|
|
paddle.proto = protobuf.get('paddle').paddle.framework.proto;
|
|
|
- const container = paddle.Container.open(context);
|
|
|
- if (container) {
|
|
|
- return new paddle.Model(metadata, container.format, null, container.weights);
|
|
|
- }
|
|
|
- else {
|
|
|
- let programDesc = null;
|
|
|
- let format = 'PaddlePaddle';
|
|
|
- const identifier = context.identifier;
|
|
|
- const parts = identifier.split('.');
|
|
|
- const extension = parts.pop().toLowerCase();
|
|
|
- const base = parts.join('.');
|
|
|
+ const stream = context.stream;
|
|
|
+ const identifier = context.identifier;
|
|
|
+ const parts = identifier.split('.');
|
|
|
+ const extension = parts.pop().toLowerCase();
|
|
|
+ const base = parts.join('.');
|
|
|
+ const openProgram = (stream, extension) => {
|
|
|
+ const program = {};
|
|
|
+ program.format = 'PaddlePaddle';
|
|
|
switch (extension) {
|
|
|
case 'pbtxt':
|
|
|
case 'txt': {
|
|
|
try {
|
|
|
- const buffer = context.stream.peek();
|
|
|
+ const buffer = stream.peek();
|
|
|
const reader = protobuf.TextReader.create(buffer);
|
|
|
- programDesc = paddle.proto.ProgramDesc.decodeText(reader);
|
|
|
+ program.desc = paddle.proto.ProgramDesc.decodeText(reader);
|
|
|
}
|
|
|
catch (error) {
|
|
|
const message = error && error.message ? error.message : error.toString();
|
|
|
@@ -54,9 +56,9 @@ paddle.ModelFactory = class {
|
|
|
}
|
|
|
default: {
|
|
|
try {
|
|
|
- const buffer = context.stream.peek();
|
|
|
+ const buffer = stream.peek();
|
|
|
const reader = protobuf.Reader.create(buffer);
|
|
|
- programDesc = paddle.proto.ProgramDesc.decode(reader);
|
|
|
+ program.desc = paddle.proto.ProgramDesc.decode(reader);
|
|
|
}
|
|
|
catch (error) {
|
|
|
const message = error && error.message ? error.message : error.toString();
|
|
|
@@ -65,6 +67,7 @@ paddle.ModelFactory = class {
|
|
|
break;
|
|
|
}
|
|
|
}
|
|
|
+ const programDesc = program.desc;
|
|
|
if (programDesc.version && programDesc.version.version && programDesc.version.version.toNumber) {
|
|
|
const version = programDesc.version.version.toNumber();
|
|
|
if (version > 0) {
|
|
|
@@ -75,7 +78,7 @@ paddle.ModelFactory = class {
|
|
|
list.pop();
|
|
|
}
|
|
|
}
|
|
|
- format += ' v' + list.map((item) => item.toString()).join('.');
|
|
|
+ program.format += ' v' + list.map((item) => item.toString()).join('.');
|
|
|
}
|
|
|
}
|
|
|
const variables = new Set();
|
|
|
@@ -98,30 +101,56 @@ paddle.ModelFactory = class {
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
- const vars = Array.from(variables).sort();
|
|
|
+ program.vars = Array.from(variables).sort();
|
|
|
+ return program;
|
|
|
+ };
|
|
|
+ const loadParams = (metadata, program, stream) => {
|
|
|
const tensors = new Map();
|
|
|
- const load_entries = (programDesc) => {
|
|
|
- const promises = vars.map((name) => context.request(name, null));
|
|
|
+ while (stream.position < stream.length) {
|
|
|
+ tensors.set(program.vars.shift(), new paddle.Tensor(null, stream));
|
|
|
+ }
|
|
|
+ return new paddle.Model(metadata, program.format, program.desc, tensors);
|
|
|
+ };
|
|
|
+ const container = paddle.Container.open(context);
|
|
|
+ if (container) {
|
|
|
+ return new paddle.Model(metadata, container.format, null, container.weights);
|
|
|
+ }
|
|
|
+ else if (stream.length > 16 && stream.peek(16).every((value) => value === 0x00)) {
|
|
|
+ const file = identifier !== 'params' ? base + '.pdmodel' : 'model';
|
|
|
+ return context.request(file, null).then((stream) => {
|
|
|
+ const program = openProgram(stream, '');
|
|
|
+ return loadParams(metadata, program, context.stream);
|
|
|
+ });
|
|
|
+ }
|
|
|
+ else {
|
|
|
+ const program = openProgram(context.stream, extension);
|
|
|
+ const loadEntries = (context, program) => {
|
|
|
+ const promises = program.vars.map((name) => context.request(name, null));
|
|
|
+ const tensors = new Map();
|
|
|
return Promise.all(promises).then((streams) => {
|
|
|
- for (let i = 0; i < vars.length; i++) {
|
|
|
- tensors.set(vars[i], new paddle.Tensor(null, streams[i]));
|
|
|
+ for (let i = 0; i < program.vars.length; i++) {
|
|
|
+ tensors.set(program.vars[i], new paddle.Tensor(null, streams[i]));
|
|
|
}
|
|
|
- return new paddle.Model(metadata, format, programDesc, tensors);
|
|
|
+ return new paddle.Model(metadata, program.format, program.desc, tensors);
|
|
|
}).catch((/* err */) => {
|
|
|
- return new paddle.Model(metadata, format, programDesc, tensors);
|
|
|
+ return new paddle.Model(metadata, program.format, program.desc, tensors);
|
|
|
});
|
|
|
};
|
|
|
if (extension === 'pdmodel') {
|
|
|
return context.request(base + '.pdiparams', null).then((stream) => {
|
|
|
- while (stream.position < stream.length) {
|
|
|
- tensors.set(vars.shift(), new paddle.Tensor(null, stream));
|
|
|
- }
|
|
|
- return new paddle.Model(metadata, format, programDesc, tensors);
|
|
|
+ return loadParams(metadata, program, stream);
|
|
|
+ }).catch((/* err */) => {
|
|
|
+ return loadEntries(context, program);
|
|
|
+ });
|
|
|
+ }
|
|
|
+ if (identifier === 'model') {
|
|
|
+ return context.request('params', null).then((stream) => {
|
|
|
+ return loadParams(metadata, program, stream);
|
|
|
}).catch((/* err */) => {
|
|
|
- return load_entries(programDesc, null);
|
|
|
+ return loadEntries(context, program);
|
|
|
});
|
|
|
}
|
|
|
- return load_entries(programDesc, null);
|
|
|
+ return loadEntries(context, program);
|
|
|
}
|
|
|
});
|
|
|
});
|
|
|
@@ -498,7 +527,7 @@ paddle.Tensor = class {
|
|
|
constructor(type, data) {
|
|
|
this._type = type;
|
|
|
if (data && !Array.isArray(data)) {
|
|
|
- if (data.__module__ === 'numpy' && data.__name__ === 'ndarray') {
|
|
|
+ if (data.__class__ && data.__class__.__module__ === 'numpy' && data.__class__.__name__ === 'ndarray') {
|
|
|
this._type = new paddle.TensorType(data.dtype.name, new paddle.TensorShape(data.shape));
|
|
|
this._data = data.data;
|
|
|
this._kind = 'NumPy Array';
|
|
|
@@ -797,7 +826,7 @@ paddle.Container = class {
|
|
|
this._weights = new Map();
|
|
|
for (const key of Object.keys(this._data)) {
|
|
|
const value = this._data[key];
|
|
|
- if (value && !Array.isArray(value) && value.__module__ === 'numpy' && value.__name__ === 'ndarray') {
|
|
|
+ if (value && !Array.isArray(value) && value.__class__ && value.__class__.__module__ === 'numpy' && value.__class__.__name__ === 'ndarray') {
|
|
|
const name = map ? map[key] : key;
|
|
|
this._weights.set(name, new paddle.Tensor(null, value));
|
|
|
}
|
|
|
@@ -826,8 +855,8 @@ paddle.Metadata = class {
|
|
|
}
|
|
|
|
|
|
constructor(data) {
|
|
|
- this._map = {};
|
|
|
- this._attributeCache = {};
|
|
|
+ this._map = new Map();
|
|
|
+ this._attributeCache = new Map();
|
|
|
if (data) {
|
|
|
const metadata = JSON.parse(data);
|
|
|
this._map = new Map(metadata.map((item) => [ item.name, item ]));
|
|
|
@@ -835,22 +864,22 @@ paddle.Metadata = class {
|
|
|
}
|
|
|
|
|
|
type(name) {
|
|
|
- return this._map[name] || null;
|
|
|
+ return this._map.get(name) || null;
|
|
|
}
|
|
|
|
|
|
attribute(type, name) {
|
|
|
- let map = this._attributeCache[type];
|
|
|
+ let map = this._attributeCache.get(type);
|
|
|
if (!map) {
|
|
|
- map = {};
|
|
|
- const schema = this.type(type);
|
|
|
- if (schema && schema.attributes && schema.attributes.length > 0) {
|
|
|
- for (const attribute of schema.attributes) {
|
|
|
- map[attribute.name] = attribute;
|
|
|
+ map = new Map();
|
|
|
+ const metadata = this.type(type);
|
|
|
+ if (metadata && metadata.attributes && metadata.attributes.length > 0) {
|
|
|
+ for (const attribute of metadata.attributes) {
|
|
|
+ map.set(attribute.name, attribute);
|
|
|
}
|
|
|
}
|
|
|
- this._attributeCache[type] = map;
|
|
|
+ this._attributeCache.set(type, map);
|
|
|
}
|
|
|
- return map[name] || null;
|
|
|
+ return map.get(name) || null;
|
|
|
}
|
|
|
};
|
|
|
|