|
|
@@ -64,7 +64,7 @@ tflite.Model = class {
|
|
|
this._graphs = [];
|
|
|
this._format = 'TensorFlow Lite v' + model.version().toString();
|
|
|
this._description = model.description() || '';
|
|
|
- let operatorCodeList = [];
|
|
|
+ let operators = [];
|
|
|
let builtinOperatorMap = {};
|
|
|
for (const key of Object.keys(tflite.schema.BuiltinOperator)) {
|
|
|
const upperCase = new Set([ '2D', 'LSH', 'SVDF', 'RNN', 'L2', 'LSTM' ]);
|
|
|
@@ -76,14 +76,14 @@ tflite.Model = class {
|
|
|
for (let operatorIndex = 0; operatorIndex < model.operatorCodesLength(); operatorIndex++) {
|
|
|
const operatorCode = model.operatorCodes(operatorIndex);
|
|
|
const builtinCode = operatorCode.builtinCode();
|
|
|
- operatorCodeList.push(builtinCode === tflite.schema.BuiltinOperator.CUSTOM ?
|
|
|
+ operators.push(builtinCode === tflite.schema.BuiltinOperator.CUSTOM ?
|
|
|
{ name: operatorCode.customCode(), custom: true } :
|
|
|
{ name: builtinOperatorMap[builtinCode] });
|
|
|
}
|
|
|
const subgraphsLength = model.subgraphsLength();
|
|
|
for (let subgraph = 0; subgraph < subgraphsLength; subgraph++) {
|
|
|
const name = (subgraphsLength > 1) ? subgraph.toString() : '';
|
|
|
- this._graphs.push(new tflite.Graph(metadata, model.subgraphs(subgraph), name, operatorCodeList, model));
|
|
|
+ this._graphs.push(new tflite.Graph(metadata, model.subgraphs(subgraph), name, operators, model));
|
|
|
}
|
|
|
}
|
|
|
|
|
|
@@ -102,7 +102,7 @@ tflite.Model = class {
|
|
|
|
|
|
tflite.Graph = class {
|
|
|
|
|
|
- constructor(metadata, graph, name, operatorCodeList, model) {
|
|
|
+ constructor(metadata, graph, name, operators, model) {
|
|
|
this._name = graph.name() || name;
|
|
|
this._nodes = [];
|
|
|
this._inputs = [];
|
|
|
@@ -123,7 +123,7 @@ tflite.Graph = class {
|
|
|
for (let j = 0; j < graph.operatorsLength(); j++) {
|
|
|
const node = graph.operators(j);
|
|
|
const opcodeIndex = node.opcodeIndex();
|
|
|
- const operator = (opcodeIndex < operatorCodeList.length) ? operatorCodeList[opcodeIndex] : { name: '(' + opcodeIndex.toString() + ')' };
|
|
|
+ const operator = (opcodeIndex < operators.length) ? operators[opcodeIndex] : { name: '(' + opcodeIndex.toString() + ')' };
|
|
|
this._nodes.push(new tflite.Node(metadata, node, operator, args));
|
|
|
}
|
|
|
for (let k = 0; k < graph.inputsLength(); k++) {
|
|
|
@@ -219,10 +219,16 @@ tflite.Node = class {
|
|
|
}
|
|
|
let optionsTypeName = this.operator + 'Options';
|
|
|
switch (this.operator) {
|
|
|
- case 'MaxPool2D':
|
|
|
case 'AveragePool2D':
|
|
|
+ case 'MaxPool2D':
|
|
|
optionsTypeName = 'Pool2DOptions';
|
|
|
break;
|
|
|
+ case 'Mean':
|
|
|
+ case 'ReduceMax':
|
|
|
+ case 'ReduceMin':
|
|
|
+ case 'Sum':
|
|
|
+ optionsTypeName = 'ReducerOptions';
|
|
|
+ break;
|
|
|
}
|
|
|
const optionsType = tflite.Node._getType(optionsTypeName);
|
|
|
if (typeof optionsType === 'function') {
|
|
|
@@ -260,19 +266,14 @@ tflite.Node = class {
|
|
|
else {
|
|
|
value = options[attributeName]();
|
|
|
}
|
|
|
- const attribute = new tflite.Attribute(this._metadata, this.operator, attributeName, value);
|
|
|
- if (attribute.name == 'fused_activation_function') {
|
|
|
- value = attribute.value;
|
|
|
- if (attribute.value != 'NONE') {
|
|
|
- const activationFunctionMap = { 'RELU': 'Relu', 'RELU_N1_TO_1': "ReluN1To1", "RELU6": "Relu6", "TANH": "Tanh", "SIGN_BIT": "SignBit" };
|
|
|
- if (activationFunctionMap[value]) {
|
|
|
- value = activationFunctionMap[value];
|
|
|
- }
|
|
|
- this._chain = [];
|
|
|
- this._chain.push(new tflite.Node(metadata, null, { name: value }, []));
|
|
|
+ if (attributeName === 'fusedActivationFunction' && value !== 0) {
|
|
|
+ const activationFunctionMap = { 1: 'Relu', 2: "ReluN1To1", 3: "Relu6", 4: "Tanh", 5: "SignBit" };
|
|
|
+ if (activationFunctionMap[value]) {
|
|
|
+ value = activationFunctionMap[value];
|
|
|
}
|
|
|
+ this._chain = [ new tflite.Node(metadata, null, { name: value }, []) ];
|
|
|
}
|
|
|
- this._attributes.push(attribute);
|
|
|
+ this._attributes.push(new tflite.Attribute(this._metadata, this.operator, attributeName, value));
|
|
|
}
|
|
|
}
|
|
|
}
|