|
@@ -86,10 +86,10 @@ caffe2.ModelFactory = class {
|
|
|
throw new caffe2.Error("File text format is not caffe2.NetDef (" + error.message + ") in '" + identifier + "'.");
|
|
throw new caffe2.Error("File text format is not caffe2.NetDef (" + error.message + ") in '" + identifier + "'.");
|
|
|
}
|
|
}
|
|
|
try {
|
|
try {
|
|
|
- if (init) {
|
|
|
|
|
- caffe2.proto = protobuf.roots.caffe2.caffe2;
|
|
|
|
|
- init_net = caffe2.proto.NetDef.decodeText(prototxt.TextReader.create(init));
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ caffe2.proto = protobuf.roots.caffe2.caffe2;
|
|
|
|
|
+ init_net = (typeof init === 'string') ?
|
|
|
|
|
+ caffe2.proto.NetDef.decodeText(prototxt.TextReader.create(init)) :
|
|
|
|
|
+ caffe2.proto.NetDef.decode(init);
|
|
|
}
|
|
}
|
|
|
catch (error) {
|
|
catch (error) {
|
|
|
// continue regardless of error
|
|
// continue regardless of error
|
|
@@ -111,14 +111,22 @@ caffe2.ModelFactory = class {
|
|
|
});
|
|
});
|
|
|
}
|
|
}
|
|
|
else if (base.toLowerCase().endsWith('predict_net') || base.toLowerCase().startsWith('predict_net')) {
|
|
else if (base.toLowerCase().endsWith('predict_net') || base.toLowerCase().startsWith('predict_net')) {
|
|
|
- return context.request(identifier.replace('predict_net', 'init_net'), 'utf-8').then((text) => {
|
|
|
|
|
- return open_text(context.text, text);
|
|
|
|
|
|
|
+ return context.request(identifier.replace('predict_net', 'init_net').replace(/\.pbtxt/, '.pb'), null).then((buffer) => {
|
|
|
|
|
+ return open_text(context.text, buffer);
|
|
|
}).catch(() => {
|
|
}).catch(() => {
|
|
|
- return open_text(context.text, null);
|
|
|
|
|
|
|
+ return context.request(identifier.replace('predict_net', 'init_net'), 'utf-8').then((text) => {
|
|
|
|
|
+ return open_text(context.text, text);
|
|
|
|
|
+ }).catch(() => {
|
|
|
|
|
+ return open_text(context.text, null);
|
|
|
|
|
+ });
|
|
|
});
|
|
});
|
|
|
}
|
|
}
|
|
|
else {
|
|
else {
|
|
|
- return open_text(context.text, null);
|
|
|
|
|
|
|
+ return context.request(base + '_init.pb', null).then((buffer) => {
|
|
|
|
|
+ return open_text(context.text, buffer);
|
|
|
|
|
+ }).catch(() => {
|
|
|
|
|
+ return open_text(context.text, null);
|
|
|
|
|
+ });
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
else {
|
|
else {
|
|
@@ -151,14 +159,28 @@ caffe2.ModelFactory = class {
|
|
|
}
|
|
}
|
|
|
};
|
|
};
|
|
|
if (base.toLowerCase().endsWith('init_net')) {
|
|
if (base.toLowerCase().endsWith('init_net')) {
|
|
|
- return context.request(base.substring(0, base.length - 8) + 'predict_net.' + extension, null).then((buffer) => {
|
|
|
|
|
|
|
+ return context.request(base.replace(/init_net$/, '') + 'predict_net.' + extension, null).then((buffer) => {
|
|
|
return open_binary(buffer, context.buffer);
|
|
return open_binary(buffer, context.buffer);
|
|
|
}).catch(() => {
|
|
}).catch(() => {
|
|
|
return open_binary(context.buffer, null);
|
|
return open_binary(context.buffer, null);
|
|
|
});
|
|
});
|
|
|
}
|
|
}
|
|
|
|
|
+ else if (base.toLowerCase().endsWith('_init')) {
|
|
|
|
|
+ return context.request(base.replace(/_init$/, '') + '.' + extension, null).then((buffer) => {
|
|
|
|
|
+ return open_binary(buffer, context.buffer);
|
|
|
|
|
+ }).catch(() => {
|
|
|
|
|
+ return open_binary(context.buffer, null);
|
|
|
|
|
+ });
|
|
|
|
|
+ }
|
|
|
|
|
+ else if (base.toLowerCase().endsWith('predict_net') || base.toLowerCase().startsWith('predict_net')) {
|
|
|
|
|
+ return context.request(identifier.replace('predict_net', 'init_net'), null).then((buffer) => {
|
|
|
|
|
+ return open_binary(context.buffer, buffer);
|
|
|
|
|
+ }).catch(() => {
|
|
|
|
|
+ return open_binary(context.buffer, null);
|
|
|
|
|
+ });
|
|
|
|
|
+ }
|
|
|
else {
|
|
else {
|
|
|
- return context.request(base.substring(0, base.length - 11) + 'init_net.' + extension, null).then((buffer) => {
|
|
|
|
|
|
|
+ return context.request(base + '_init.' + extension, null).then((buffer) => {
|
|
|
return open_binary(context.buffer, buffer);
|
|
return open_binary(context.buffer, buffer);
|
|
|
}).catch(() => {
|
|
}).catch(() => {
|
|
|
return open_binary(context.buffer, null);
|
|
return open_binary(context.buffer, null);
|