Explorar el Código

Fix TensorFlow Lite int16 support (#436)

Lutz Roeder hace 6 años
padre
commit
8042bcd703
Se han modificado 3 ficheros con 55 adiciones y 10 borrados
  1. 46 0
      src/tflite-metadata.json
  2. 9 0
      src/tflite.js
  3. 0 10
      tools/tflite

+ 46 - 0
src/tflite-metadata.json

@@ -664,5 +664,51 @@
         { "name": "output" }
       ]
     }
+  },
+  {
+    "name": "Quantize",
+    "schema": {
+      "inputs": [
+        { "name": "input" }
+      ],
+      "outputs": [
+        { "name": "output" }
+      ]
+    }
+  },
+  {
+    "name": "Dequantize",
+    "schema": {
+      "inputs": [
+        { "name": "input" }
+      ],
+      "outputs": [
+        { "name": "output" }
+      ]
+    }
+  },
+  {
+    "name": "Minimum",
+    "schema": {
+      "inputs": [
+        { "name": "input1" },
+        { "name": "input2" }
+      ],
+      "outputs": [
+        { "name": "output" }
+      ]
+    }
+  },
+  {
+    "name": "Maximum",
+    "schema": {
+      "inputs": [
+        { "name": "input1" },
+        { "name": "input2" }
+      ],
+      "outputs": [
+        { "name": "output" }
+      ]
+    }
   }
 ]

+ 9 - 0
src/tflite.js

@@ -229,6 +229,10 @@ tflite.Node = class {
                 case 'Sum':
                     optionsTypeName = 'ReducerOptions';
                     break;
+                case 'Minimum':
+                case 'Maximum':
+                    optionsTypeName = 'MaximumMinimumOptions';
+                    break;
             }
             const optionsType = tflite.Node._getType(optionsTypeName);
             if (typeof optionsType === 'function') {
@@ -576,6 +580,11 @@ tflite.Tensor = class {
                         context.index += 1;
                         context.count++;
                         break;
+                    case 'int16':
+                        results.push(context.data.getInt16(context.index));
+                        context.index += 2;
+                        context.count++;
+                        break;
                     case 'int32':
                         results.push(context.data.getInt32(context.index, true));
                         context.index += 4;

+ 0 - 10
tools/tflite

@@ -76,15 +76,6 @@ EOT
     fi
 }
 
-visualize() {
-    bold "tflite visualize"
-    venv
-    export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
-    ${python} -m pip install --quiet tensorflow
-    ${python} ./third_party/src/tensorflow/tensorflow/lite/tools/visualize.py $@
-    deactivate
-}
-
 while [ "$#" != 0 ]; do
     command="$1" && shift
     case "${command}" in
@@ -92,6 +83,5 @@ while [ "$#" != 0 ]; do
         "sync") sync;;
         "install") install;;
         "schema") schema;;
-        "visualize") visualize ${1} ${2} && shift && shift;;
     esac
 done