|
|
@@ -52,10 +52,15 @@ tflite.Model = class {
|
|
|
let builtinOperatorMap = {};
|
|
|
for (const key of Object.keys(tflite.schema.BuiltinOperator)) {
|
|
|
const upperCase = new Set([ '2D', 'LSH', 'SVDF', 'RNN', 'L2', 'LSTM' ]);
|
|
|
- const builtinOperatorIndex = tflite.schema.BuiltinOperator[key];
|
|
|
- builtinOperatorMap[builtinOperatorIndex] = key.split('_').map((s) => {
|
|
|
- return (s.length < 1 || upperCase.has(s)) ? s : s.substring(0, 1) + s.substring(1).toLowerCase();
|
|
|
- }).join('');
|
|
|
+ const index = tflite.schema.BuiltinOperator[key];
|
|
|
+ switch (key) {
|
|
|
+ case 'BATCH_MATMUL':
|
|
|
+ builtinOperatorMap[index] = "BatchMatMul";
|
|
|
+ break;
|
|
|
+ default:
|
|
|
+ builtinOperatorMap[index] = key.split('_').map((s) => (s.length < 1 || upperCase.has(s)) ? s : s[0] + s.substring(1).toLowerCase()).join('');
|
|
|
+ break;
|
|
|
+ }
|
|
|
}
|
|
|
for (let operatorIndex = 0; operatorIndex < model.operatorCodesLength(); operatorIndex++) {
|
|
|
const operatorCode = model.operatorCodes(operatorIndex);
|
|
|
@@ -194,7 +199,7 @@ tflite.Node = class {
|
|
|
this._outputs.push(new tflite.Parameter(outputName, true, [ argument ]));
|
|
|
}
|
|
|
this._attributes = [];
|
|
|
- if (operator.custom) {
|
|
|
+ if (operator.custom && node.customOptionsLength() > 0) {
|
|
|
let custom = [];
|
|
|
for (let m = 0; m < node.customOptionsLength(); m++) {
|
|
|
custom.push(node.customOptions(m));
|
|
|
@@ -218,7 +223,7 @@ tflite.Node = class {
|
|
|
optionsTypeName = 'MaximumMinimumOptions';
|
|
|
break;
|
|
|
}
|
|
|
- const optionsType = tflite.Node._getType(optionsTypeName);
|
|
|
+ const optionsType = tflite.schema[optionsTypeName] || null;
|
|
|
if (typeof optionsType === 'function') {
|
|
|
let options = Reflect.construct(optionsType, []);
|
|
|
options = node.builtinOptions(options);
|
|
|
@@ -307,22 +312,6 @@ tflite.Node = class {
|
|
|
get attributes() {
|
|
|
return this._attributes;
|
|
|
}
|
|
|
-
|
|
|
- static _getType(name) {
|
|
|
- const list = name.split('.');
|
|
|
- let type = tflite.schema;
|
|
|
- while (list.length > 0) {
|
|
|
- const item = list.shift();
|
|
|
- type = type[item];
|
|
|
- if (!type) {
|
|
|
- return null;
|
|
|
- }
|
|
|
- }
|
|
|
- if (type == tflite.schema) {
|
|
|
- return null;
|
|
|
- }
|
|
|
- return type;
|
|
|
- }
|
|
|
};
|
|
|
|
|
|
tflite.Attribute = class {
|