|
|
@@ -11,16 +11,10 @@ caffe.ModelFactory = class {
|
|
|
if (extension == 'caffemodel') {
|
|
|
return true;
|
|
|
}
|
|
|
- if (extension == 'pbtxt' || extension == 'prototxt') {
|
|
|
- if (identifier == 'saved_model.pbtxt' || identifier == 'saved_model.prototxt' ||
|
|
|
- identifier.endsWith('predict_net.pbtxt') || identifier.endsWith('predict_net.prototxt') ||
|
|
|
- identifier.endsWith('init_net.pbtxt') || identifier.endsWith('init_net.prototxt')) {
|
|
|
- return false;
|
|
|
- }
|
|
|
- const tags = context.tags('pbtxt');
|
|
|
- if (tags.has('layer') || tags.has('layers') || tags.has('net') || tags.has('train_net') || tags.has('net_param')) {
|
|
|
- return true;
|
|
|
- }
|
|
|
+ if (identifier == 'saved_model.pbtxt' || identifier == 'saved_model.prototxt' ||
|
|
|
+ identifier.endsWith('predict_net.pbtxt') || identifier.endsWith('predict_net.prototxt') ||
|
|
|
+ identifier.endsWith('init_net.pbtxt') || identifier.endsWith('init_net.prototxt')) {
|
|
|
+ return false;
|
|
|
}
|
|
|
if (extension == 'pt') {
|
|
|
// Reject PyTorch models
|
|
|
@@ -33,10 +27,10 @@ caffe.ModelFactory = class {
|
|
|
if (buffer && buffer.length > 2 && buffer[0] == 0x50 && buffer[1] == 0x4B) {
|
|
|
return false;
|
|
|
}
|
|
|
- const tags = context.tags('pbtxt');
|
|
|
- if (tags.has('layer') || tags.has('layers') || tags.has('net') || tags.has('train_net') || tags.has('net_param')) {
|
|
|
- return true;
|
|
|
- }
|
|
|
+ }
|
|
|
+ const tags = context.tags('pbtxt');
|
|
|
+ if (tags.has('layer') || tags.has('layers') || tags.has('net') || tags.has('train_net') || tags.has('net_param')) {
|
|
|
+ return true;
|
|
|
}
|
|
|
return false;
|
|
|
}
|
|
|
@@ -46,39 +40,39 @@ caffe.ModelFactory = class {
|
|
|
caffe.proto = protobuf.get('caffe').caffe;
|
|
|
return caffe.Metadata.open(host).then((metadata) => {
|
|
|
const extension = context.identifier.split('.').pop();
|
|
|
- if (extension == 'pbtxt' || extension == 'prototxt' || extension == 'pt') {
|
|
|
- const tags = context.tags('pbtxt');
|
|
|
- if (tags.has('net') || tags.has('train_net') || tags.has('net_param')) {
|
|
|
- try {
|
|
|
- const reader = protobuf.TextReader.create(context.buffer);
|
|
|
- reader.field = function(tag, message) {
|
|
|
- if (message instanceof caffe.proto.SolverParameter) {
|
|
|
- message[tag] = this.read();
|
|
|
- return;
|
|
|
- }
|
|
|
- throw new Error("Unknown field '" + tag + "'" + this.location());
|
|
|
- };
|
|
|
- const solver = caffe.proto.SolverParameter.decodeText(reader);
|
|
|
- if (solver.net_param) {
|
|
|
- return new caffe.Model(metadata, solver.net_param);
|
|
|
- }
|
|
|
- else if (solver.net || solver.train_net) {
|
|
|
- let file = solver.net || solver.train_net;
|
|
|
- file = file.split('/').pop();
|
|
|
- return context.request(file, null).then((buffer) => {
|
|
|
- return this._openNetParameterText(metadata, context.identifier, buffer, host);
|
|
|
- }).catch((error) => {
|
|
|
- if (error) {
|
|
|
- const message = error && error.message ? error.message : error.toString();
|
|
|
- throw new caffe.Error("Failed to load '" + file + "' (" + message.replace(/\.$/, '') + ').');
|
|
|
- }
|
|
|
- });
|
|
|
+ const tags = context.tags('pbtxt');
|
|
|
+ if (tags.has('net') || tags.has('train_net') || tags.has('net_param')) {
|
|
|
+ try {
|
|
|
+ const reader = protobuf.TextReader.create(context.buffer);
|
|
|
+ reader.field = function(tag, message) {
|
|
|
+ if (message instanceof caffe.proto.SolverParameter) {
|
|
|
+ message[tag] = this.read();
|
|
|
+ return;
|
|
|
}
|
|
|
+ throw new Error("Unknown field '" + tag + "'" + this.location());
|
|
|
+ };
|
|
|
+ const solver = caffe.proto.SolverParameter.decodeText(reader);
|
|
|
+ if (solver.net_param) {
|
|
|
+ return new caffe.Model(metadata, solver.net_param);
|
|
|
}
|
|
|
- catch (error) {
|
|
|
- // continue regardless of error
|
|
|
+ else if (solver.net || solver.train_net) {
|
|
|
+ let file = solver.net || solver.train_net;
|
|
|
+ file = file.split('/').pop();
|
|
|
+ return context.request(file, null).then((buffer) => {
|
|
|
+ return this._openNetParameterText(metadata, context.identifier, buffer, host);
|
|
|
+ }).catch((error) => {
|
|
|
+ if (error) {
|
|
|
+ const message = error && error.message ? error.message : error.toString();
|
|
|
+ throw new caffe.Error("Failed to load '" + file + "' (" + message.replace(/\.$/, '') + ').');
|
|
|
+ }
|
|
|
+ });
|
|
|
}
|
|
|
}
|
|
|
+ catch (error) {
|
|
|
+ // continue regardless of error
|
|
|
+ }
|
|
|
+ }
|
|
|
+ else if (tags.has('layer') || tags.has('layers')) {
|
|
|
return this._openNetParameterText(metadata, context.identifier, context.buffer, host);
|
|
|
}
|
|
|
else {
|