Просмотр исходного кода

Workaround tensorflow/tensorflow#38338

Lutz Roeder 6 лет назад
Родитель
Сommit
067e671849
2 измененных файлов с 49 добавлено и 24 удалено
  1. 38 2
      src/tflite-schema.js
  2. 11 22
      src/tflite.js

+ 38 - 2
src/tflite-schema.js

@@ -11528,11 +11528,43 @@ TFLITE.BatchMatMulOptions.getSizePrefixedRootAsBatchMatMulOptions = function(bb,
   return (obj || new TFLITE.BatchMatMulOptions).__init(bb.readInt32(bb.position()) + bb.position(), bb);
 };
 
+/**
+ * @returns {boolean}
+ */
+TFLITE.BatchMatMulOptions.prototype.adjointLhs = function() {
+  var offset = this.bb.__offset(this.bb_pos, 4);
+  return offset ? !!this.bb.readInt8(this.bb_pos + offset) : false;
+};
+
+/**
+ * @returns {boolean}
+ */
+TFLITE.BatchMatMulOptions.prototype.adjointRhs = function() {
+  var offset = this.bb.__offset(this.bb_pos, 6);
+  return offset ? !!this.bb.readInt8(this.bb_pos + offset) : false;
+};
+
 /**
  * @param {flatbuffers.Builder} builder
  */
 TFLITE.BatchMatMulOptions.startBatchMatMulOptions = function(builder) {
-  builder.startObject(0);
+  builder.startObject(2);
+};
+
+/**
+ * @param {flatbuffers.Builder} builder
+ * @param {boolean} adjointLhs
+ */
+TFLITE.BatchMatMulOptions.addAdjointLhs = function(builder, adjointLhs) {
+  builder.addFieldInt8(0, +adjointLhs, +false);
+};
+
+/**
+ * @param {flatbuffers.Builder} builder
+ * @param {boolean} adjointRhs
+ */
+TFLITE.BatchMatMulOptions.addAdjointRhs = function(builder, adjointRhs) {
+  builder.addFieldInt8(1, +adjointRhs, +false);
 };
 
 /**
@@ -11546,10 +11578,14 @@ TFLITE.BatchMatMulOptions.endBatchMatMulOptions = function(builder) {
 
 /**
  * @param {flatbuffers.Builder} builder
+ * @param {boolean} adjointLhs
+ * @param {boolean} adjointRhs
  * @returns {flatbuffers.Offset}
  */
-TFLITE.BatchMatMulOptions.createBatchMatMulOptions = function(builder) {
+TFLITE.BatchMatMulOptions.createBatchMatMulOptions = function(builder, adjointLhs, adjointRhs) {
   TFLITE.BatchMatMulOptions.startBatchMatMulOptions(builder);
+  TFLITE.BatchMatMulOptions.addAdjointLhs(builder, adjointLhs);
+  TFLITE.BatchMatMulOptions.addAdjointRhs(builder, adjointRhs);
   return TFLITE.BatchMatMulOptions.endBatchMatMulOptions(builder);
 }
 

+ 11 - 22
src/tflite.js

@@ -52,10 +52,15 @@ tflite.Model = class {
         let builtinOperatorMap = {};
         for (const key of Object.keys(tflite.schema.BuiltinOperator)) {
             const upperCase = new Set([ '2D', 'LSH', 'SVDF', 'RNN', 'L2', 'LSTM' ]);
-            const builtinOperatorIndex = tflite.schema.BuiltinOperator[key]; 
-            builtinOperatorMap[builtinOperatorIndex] = key.split('_').map((s) => {
-                return (s.length < 1 || upperCase.has(s)) ? s : s.substring(0, 1) + s.substring(1).toLowerCase();
-            }).join('');
+            const index = tflite.schema.BuiltinOperator[key];
+            switch (key) {
+                case 'BATCH_MATMUL':
+                    builtinOperatorMap[index] = "BatchMatMul";
+                    break;
+                default:
+                    builtinOperatorMap[index] = key.split('_').map((s) => (s.length < 1 || upperCase.has(s)) ? s : s[0] + s.substring(1).toLowerCase()).join('');
+                    break;
+            }
         }
         for (let operatorIndex = 0; operatorIndex < model.operatorCodesLength(); operatorIndex++) {
             const operatorCode = model.operatorCodes(operatorIndex);
@@ -194,7 +199,7 @@ tflite.Node = class {
                 this._outputs.push(new tflite.Parameter(outputName, true, [ argument ]));
             }
             this._attributes = [];
-            if (operator.custom) {
+            if (operator.custom && node.customOptionsLength() > 0) {
                 let custom = [];
                 for (let m = 0; m < node.customOptionsLength(); m++) {
                     custom.push(node.customOptions(m));
@@ -218,7 +223,7 @@ tflite.Node = class {
                     optionsTypeName = 'MaximumMinimumOptions';
                     break;
             }
-            const optionsType = tflite.Node._getType(optionsTypeName);
+            const optionsType = tflite.schema[optionsTypeName] || null;
             if (typeof optionsType === 'function') {
                 let options = Reflect.construct(optionsType, []);
                 options = node.builtinOptions(options);
@@ -307,22 +312,6 @@ tflite.Node = class {
     get attributes() {
         return this._attributes;
     }
-
-    static _getType(name) {
-        const list = name.split('.');
-        let type = tflite.schema;
-        while (list.length > 0) {
-            const item = list.shift();
-            type = type[item];
-            if (!type) {
-                return null;
-            }
-        }
-        if (type == tflite.schema) {
-            return null;
-        }
-        return type;
-    }
 };
 
 tflite.Attribute = class {