|
|
@@ -11,7 +11,8 @@ caffe2.ModelFactory = class {
|
|
|
const identifier = context.identifier.toLowerCase();
|
|
|
const extension = identifier.split('.').pop().toLowerCase();
|
|
|
if (extension == 'pb') {
|
|
|
- if (identifier.endsWith('predict_net.pb') || identifier.endsWith('init_net.pb')) {
|
|
|
+ if (identifier.endsWith('predict_net.pb') || identifier.endsWith('init_net.pb') ||
|
|
|
+ identifier.startsWith('predict_net') || identifier.startsWith('init_net')) {
|
|
|
return true;
|
|
|
}
|
|
|
const tags = context.tags('pb');
|
|
|
@@ -44,7 +45,7 @@ caffe2.ModelFactory = class {
|
|
|
}
|
|
|
}
|
|
|
if (extension == 'pbtxt' || extension == 'prototxt') {
|
|
|
- if (identifier.endsWith('predict_net.pbtxt') || identifier.endsWith('predict_net.prototxt')) {
|
|
|
+ if (identifier.endsWith('predict_net')) {
|
|
|
return true;
|
|
|
}
|
|
|
const tags = context.tags('pbtxt');
|
|
|
@@ -61,8 +62,10 @@ caffe2.ModelFactory = class {
|
|
|
open(context, host) {
|
|
|
return host.require('./caffe2-proto').then(() => {
|
|
|
return caffe2.Metadata.open(host).then((metadata) => {
|
|
|
- const identifier = context.identifier;
|
|
|
- const extension = identifier.split('.').pop().toLowerCase();
|
|
|
+ const identifier = context.identifier;
|
|
|
+ const parts = identifier.split('.');
|
|
|
+ const extension = parts.pop().toLowerCase();
|
|
|
+ const base = parts.join('.');
|
|
|
if (extension == 'pbtxt' || extension == 'prototxt') {
|
|
|
const open_text = (predict, init) => {
|
|
|
let predict_net = null;
|
|
|
@@ -83,8 +86,10 @@ caffe2.ModelFactory = class {
|
|
|
throw new caffe2.Error("File text format is not caffe2.NetDef (" + error.message + ") in '" + identifier + "'.");
|
|
|
}
|
|
|
try {
|
|
|
- caffe2.proto = protobuf.roots.caffe2.caffe2;
|
|
|
- init_net = caffe2.proto.NetDef.decodeText(prototxt.TextReader.create(init));
|
|
|
+ if (init) {
|
|
|
+ caffe2.proto = protobuf.roots.caffe2.caffe2;
|
|
|
+ init_net = caffe2.proto.NetDef.decodeText(prototxt.TextReader.create(init));
|
|
|
+ }
|
|
|
}
|
|
|
catch (error) {
|
|
|
// continue regardless of error
|
|
|
@@ -99,20 +104,23 @@ caffe2.ModelFactory = class {
|
|
|
throw new caffe2.Error(message + " in '" + identifier + "'.");
|
|
|
}
|
|
|
};
|
|
|
- if (identifier.toLowerCase().startsWith('init_net.')) {
|
|
|
- return context.request('predict_net.' + extension, 'utf-8').then((text) => {
|
|
|
+ if (base.toLowerCase().endsWith('init_net') || base.toLowerCase().startsWith('init_net')) {
|
|
|
+ return context.request(identifier.replace('init_net', 'predict_net'), 'utf-8').then((text) => {
|
|
|
return open_text(text, context.text);
|
|
|
}).catch(() => {
|
|
|
return open_text(context.text, null);
|
|
|
});
|
|
|
}
|
|
|
- else {
|
|
|
- return context.request('init_net.' + extension, 'utf-8').then((text) => {
|
|
|
+ else if (base.toLowerCase().endsWith('predict_net') || base.toLowerCase().startsWith('predict_net')) {
|
|
|
+ return context.request(identifier.replace('predict_net', 'init_net'), 'utf-8').then((text) => {
|
|
|
return open_text(context.text, text);
|
|
|
}).catch(() => {
|
|
|
return open_text(context.text, null);
|
|
|
});
|
|
|
}
|
|
|
+ else {
|
|
|
+ return open_text(context.text, null);
|
|
|
+ }
|
|
|
}
|
|
|
else {
|
|
|
const open_binary = (predict, init) => {
|
|
|
@@ -126,8 +134,10 @@ caffe2.ModelFactory = class {
|
|
|
throw new caffe2.Error("File format is not caffe2.NetDef (" + error.message + ") in '" + identifier + "'.");
|
|
|
}
|
|
|
try {
|
|
|
- caffe2.proto = protobuf.roots.caffe2.caffe2;
|
|
|
- init_net = caffe2.proto.NetDef.decode(init);
|
|
|
+ if (init) {
|
|
|
+ caffe2.proto = protobuf.roots.caffe2.caffe2;
|
|
|
+ init_net = caffe2.proto.NetDef.decode(init);
|
|
|
+ }
|
|
|
}
|
|
|
catch (error) {
|
|
|
// continue regardless of error
|
|
|
@@ -142,15 +152,15 @@ caffe2.ModelFactory = class {
|
|
|
throw new caffe2.Error(message + " in '" + identifier + "'.");
|
|
|
}
|
|
|
};
|
|
|
- if (identifier.toLowerCase().startsWith('init_net.')) {
|
|
|
- return context.request('predict_net.' + extension, null).then((buffer) => {
|
|
|
+ if (base.toLowerCase().endsWith('init_net')) {
|
|
|
+ return context.request(base.substring(0, base.length - 8) + 'predict_net.' + extension, null).then((buffer) => {
|
|
|
return open_binary(buffer, context.buffer);
|
|
|
}).catch(() => {
|
|
|
return open_binary(context.buffer, null);
|
|
|
});
|
|
|
}
|
|
|
else {
|
|
|
- return context.request('init_net.' + extension, null).then((buffer) => {
|
|
|
+ return context.request(base.substring(0, base.length - 11) + 'init_net.' + extension, null).then((buffer) => {
|
|
|
return open_binary(context.buffer, buffer);
|
|
|
}).catch(() => {
|
|
|
return open_binary(context.buffer, null);
|
|
|
@@ -190,47 +200,59 @@ caffe2.Graph = class {
|
|
|
this._type = netDef.type || '';
|
|
|
this._nodes = [];
|
|
|
|
|
|
- let initializers = {};
|
|
|
- for (let external_input of netDef.external_input) {
|
|
|
- initializers[external_input] = {};
|
|
|
+ let inputs = new Map();
|
|
|
+ for (const input of netDef.external_input) {
|
|
|
+ inputs.set(input, {});
|
|
|
}
|
|
|
if (init) {
|
|
|
- for (let op of init.op) {
|
|
|
+ for (const op of init.op) {
|
|
|
if (op.output && op.output.length == 1) {
|
|
|
const name = op.output[0];
|
|
|
- let dataType = null;
|
|
|
+ if (!inputs.has(name)) {
|
|
|
+ inputs.set(name, {});
|
|
|
+ }
|
|
|
+ let initializer = inputs.get(name);
|
|
|
+ for (const arg of op.arg) {
|
|
|
+ initializer[arg.name] = arg;
|
|
|
+ }
|
|
|
switch (op.type) {
|
|
|
case 'GivenTensorFill':
|
|
|
- dataType = 'float32';
|
|
|
+ initializer.dataType = 'float32';
|
|
|
+ break;
|
|
|
+ case 'GivenTensorDoubleFill':
|
|
|
+ initializer.dataType = 'float64';
|
|
|
break;
|
|
|
case 'GivenTensorBoolFill':
|
|
|
- dataType = 'boolean';
|
|
|
+ initializer.dataType = 'boolean';
|
|
|
break;
|
|
|
case 'GivenTensorByteStringToUInt8Fill':
|
|
|
- dataType = 'uint8';
|
|
|
+ initializer.dataType = 'uint8';
|
|
|
break;
|
|
|
case 'GivenTensorIntFill':
|
|
|
- dataType = 'int32';
|
|
|
+ initializer.dataType = 'int32';
|
|
|
break;
|
|
|
case 'GivenTensorInt64Fill':
|
|
|
- dataType = 'int64';
|
|
|
+ initializer.dataType = 'int64';
|
|
|
break;
|
|
|
case 'GivenTensorStringFill':
|
|
|
- dataType = 'string';
|
|
|
+ initializer.dataType = 'string';
|
|
|
break;
|
|
|
case 'Int8GivenIntTensorFill':
|
|
|
- dataType = 'int32';
|
|
|
+ initializer.dataType = 'int32';
|
|
|
break;
|
|
|
case 'Int8GivenTensorFill':
|
|
|
- dataType = 'int8';
|
|
|
+ initializer.dataType = 'int8';
|
|
|
break;
|
|
|
- default:
|
|
|
+ case 'XavierFill':
|
|
|
break;
|
|
|
+ case 'ConstantFill':
|
|
|
+ break;
|
|
|
+ default:
|
|
|
+ throw new caffe2.Error("Unknown init op '" + op.type + "'.");
|
|
|
+ }
|
|
|
+ if (initializer.values && (initializer.values.floats.length !== 1 || initializer.values.floats[0] !== 0)) {
|
|
|
+ initializer.input = false;
|
|
|
}
|
|
|
- if (dataType) {
|
|
|
- op.dataType = dataType;
|
|
|
- initializers[name] = op;
|
|
|
- }
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
@@ -241,7 +263,7 @@ caffe2.Graph = class {
|
|
|
op.input = op.input.map((input) => scope[input] ? scope[input] : input);
|
|
|
op.output = op.output.map((output) => {
|
|
|
if (scope[output]) {
|
|
|
- let next = output + '\n' + index.toString(); // custom argument id
|
|
|
+ const next = output + '\n' + index.toString(); // custom argument id
|
|
|
scope[output] = next;
|
|
|
return next;
|
|
|
}
|
|
|
@@ -254,7 +276,7 @@ caffe2.Graph = class {
|
|
|
let lastNode = null;
|
|
|
let lastOutput = null;
|
|
|
for (let op of netDef.op) {
|
|
|
- let node = new caffe2.Node(metadata, op, initializers);
|
|
|
+ let node = new caffe2.Node(metadata, op, inputs);
|
|
|
if (op.input.length == 1 &&
|
|
|
op.output.length >= 1 &&
|
|
|
op.input[0].split('\n').shift() == op.output[0].split('\n').shift() &&
|
|
|
@@ -274,11 +296,14 @@ caffe2.Graph = class {
|
|
|
}
|
|
|
|
|
|
this._inputs = [];
|
|
|
- let inputs = Object.keys(initializers);
|
|
|
- for (let input of inputs) {
|
|
|
- if (inputs.length == 1 || !input.startsWith('caffe.')) {
|
|
|
- this._inputs.push(new caffe2.Parameter(input, [ new caffe2.Argument(input, null, null) ]));
|
|
|
+ for (let input of netDef.external_input) {
|
|
|
+ if (netDef.external_input.length > 1) {
|
|
|
+ const initializer = inputs.get(input);
|
|
|
+ if (initializer && initializer.input === false) {
|
|
|
+ continue;
|
|
|
+ }
|
|
|
}
|
|
|
+ this._inputs.push(new caffe2.Parameter(input, [ new caffe2.Argument(input, null, null) ]));
|
|
|
}
|
|
|
|
|
|
this._outputs = [];
|
|
|
@@ -379,16 +404,25 @@ caffe2.Node = class {
|
|
|
|
|
|
const schema = metadata.type(this._operator);
|
|
|
|
|
|
- let inputs = op.input;
|
|
|
+ const inputs = op.input;
|
|
|
+ const outputs = op.output;
|
|
|
+
|
|
|
let tensors = {};
|
|
|
let index = 0;
|
|
|
- for (let input of inputs) {
|
|
|
- if (index > 0 && initializers[input]) {
|
|
|
- tensors[input] = new caffe2.Tensor(input, initializers[input], 'Initializer');
|
|
|
- delete initializers[input];
|
|
|
+ for (const input of inputs) {
|
|
|
+ if (index > 0 && initializers.has(input)) {
|
|
|
+ const initializer = initializers.get(input);
|
|
|
+ tensors[input] = new caffe2.Tensor(input, initializer);
|
|
|
+ initializer.input = false;
|
|
|
}
|
|
|
index++;
|
|
|
}
|
|
|
+ for (const output of outputs) {
|
|
|
+ if (initializers.has(output)) {
|
|
|
+ const initializer = initializers.get(output);
|
|
|
+ initializer.input = false;
|
|
|
+ }
|
|
|
+ }
|
|
|
this._inputs = [];
|
|
|
let inputIndex = 0;
|
|
|
if (schema && schema.inputs) {
|
|
|
@@ -412,7 +446,6 @@ caffe2.Node = class {
|
|
|
}));
|
|
|
}
|
|
|
|
|
|
- let outputs = op.output;
|
|
|
this._outputs = [];
|
|
|
let outputIndex = 0;
|
|
|
if (schema && schema.outputs) {
|
|
|
@@ -540,26 +573,13 @@ caffe2.Attribute = class {
|
|
|
|
|
|
caffe2.Tensor = class {
|
|
|
|
|
|
- constructor(name, tensor, kind) {
|
|
|
+ constructor(name, tensor) {
|
|
|
this._name = name;
|
|
|
- this._kind = kind;
|
|
|
-
|
|
|
- let args = {};
|
|
|
- if (tensor && tensor.arg) {
|
|
|
- for (let arg of tensor.arg) {
|
|
|
- args[arg.name] = arg;
|
|
|
- }
|
|
|
- }
|
|
|
- let shape = null;
|
|
|
- if (args.shape && args.shape.ints) {
|
|
|
- shape = args.shape.ints;
|
|
|
- }
|
|
|
- if (args.values) {
|
|
|
- this._values = args.values;
|
|
|
- }
|
|
|
- this._scale = Object.prototype.hasOwnProperty.call(args, 'Y_scale') ? args.Y_scale.f : 0;
|
|
|
- this._zeroPoint = Object.prototype.hasOwnProperty.call(args, 'Y_zero_point') ? args.Y_zero_point.i : 0;
|
|
|
+ const shape = tensor.shape && tensor.shape.ints ? tensor.shape.ints : null;
|
|
|
this._type = new caffe2.TensorType(tensor.dataType, new caffe2.TensorShape(shape));
|
|
|
+ this._values = tensor.values || null;
|
|
|
+ this._scale = tensor.Y_scale ? tensor.Y_scale.f : 0;
|
|
|
+ this._zeroPoint = tensor.Y_zero_point ? tensor.Y_zero_point.i : 0;
|
|
|
}
|
|
|
|
|
|
get name() {
|
|
|
@@ -571,7 +591,7 @@ caffe2.Tensor = class {
|
|
|
}
|
|
|
|
|
|
get kind() {
|
|
|
- return this._kind;
|
|
|
+ return 'Initializer';
|
|
|
}
|
|
|
|
|
|
get quantization() {
|