Просмотр исходного кода

show tflite quantization values

Lutz Roeder 8 лет назад
Родитель
Сommit
90326a3bfb
2 измененных файлов с 97 добавлено и 78 удалено
  1. 3 0
      src/view-template.js
  2. 94 78
      src/view-tflite.js

+ 3 - 0
src/view-template.js

@@ -48,6 +48,9 @@ var itemsTemplate = `
 {{#items}}
 <div class='item'>    
 <b>{{{name}}}{{#if type}}: {{/if}}</b>{{#if type}}<code>{{{type}}}</code>{{/if}}<br>
+{{#if quantization}}
+<pre>{{{quantization}}}</pre>
+{{/if}}
 {{#if doc}}
 {{{doc}}}
 {{/if}}

+ 94 - 78
src/view-tflite.js

@@ -1,7 +1,5 @@
 /*jshint esversion: 6 */
 
-// Experimental
-
 class TensorFlowLiteModel {
     
     constructor(hostService) {
@@ -109,7 +107,7 @@ class TensorFlowLiteGraph {
                 this._inputs.push({ 
                     id: tensorIndex.toString(),
                     name: tensor.name(),
-                    type: this.formatTensorType(tensor) 
+                    type: TensorFlowLiteTensor.formatTensorType(tensor) 
                 });
             }
         }
@@ -126,7 +124,7 @@ class TensorFlowLiteGraph {
                 this._outputs.push({ 
                     id: tensorIndex.toString(),
                     name: tensor.name(),
-                    type: this.formatTensorType(tensor) 
+                    type: TensorFlowLiteTensor.formatTensorType(tensor) 
                 });
             }
         }
@@ -143,9 +141,7 @@ class TensorFlowLiteGraph {
                 var tensor = graph.tensors(i);
                 var buffer = model.buffers(tensor.buffer());
                 if (buffer.dataLength() > 0) {
-                    tensor = this.formatTensor(tensor, buffer);
-                    tensor.id = i.toString();
-                    this._initializers.push(tensor);
+                    this._initializers.push(new TensorFlowLiteTensor(tensor, buffer, i));
                 }
             }    
         }
@@ -172,41 +168,6 @@ class TensorFlowLiteGraph {
         } 
         return results;
     }
-
-    formatTensorType(tensor) {
-        if (!this.tensorTypeMap)
-        {
-            this.tensorTypeMap = {};
-            this.tensorTypeMap[tflite.TensorType.FLOAT32] = 'float';
-            this.tensorTypeMap[tflite.TensorType.FLOAT16] = 'float16';
-            this.tensorTypeMap[tflite.TensorType.INT32] = 'int32';
-            this.tensorTypeMap[tflite.TensorType.UINT8] = 'byte';
-            this.tensorTypeMap[tflite.TensorType.INT64] = 'int64';
-            this.tensorTypeMap[tflite.TensorType.STRING] = 'string';
-        }
-        var result = this.tensorTypeMap[tensor.type()]; 
-        if (!result) {
-            debugger;
-            result = '?';
-        }
-        var shapeLength = tensor.shapeLength();
-        if (shapeLength > 0) {
-            var dimensions = [];
-            for (var i = 0; i < shapeLength; i++) {
-                dimensions.push(tensor.shape(i).toString());
-            }
-            result += '[' + dimensions.join(',') + ']';
-        }
-        return result;
-    }
-
-    formatTensor(tensor, buffer) {
-        var result = {};
-        result.name = tensor.name();
-        result.type = this.formatTensorType(tensor);
-        result.value = function () { return new TensorFlowLiteTensorFormatter(tensor, buffer).toString(); };
-        return result;
-    }
 }
 
 class TensorFlowLiteNode {
@@ -239,7 +200,7 @@ class TensorFlowLiteNode {
                 this._inputs.push({
                     id: tensorIndex.toString(),
                     name: operatorMetadata.getInputName(this.operator, i),
-                    type: this._graph.formatTensorType(tensor)
+                    type: TensorFlowLiteTensor.formatTensorType(tensor)
                 });
             }
         }
@@ -259,7 +220,7 @@ class TensorFlowLiteNode {
                 this._outputs.push({
                     id: tensorIndex.toString(),
                     name: operatorMetadata.getOutputName(this.operator, i),
-                    type: this._graph.formatTensorType(tensor)
+                    type: TensorFlowLiteTensor.formatTensorType(tensor)
                 });
             }
         }
@@ -392,91 +353,117 @@ class TensorFlowLiteNode {
     }
 }
 
-class TensorFlowLiteTensorFormatter {
+class TensorFlowLiteTensor {
+
+    constructor(tensor, buffer, index) {
+        this._index = index;
+        this._tensor = tensor;
+        this._buffer = buffer;
+    }
+
+    get id() {
+        return this._index.toString();
+    }
 
-    constructor(tensor, buffer) {
-        this.tensor = tensor;
-        this.buffer = buffer;
-        if (window.TextDecoder) {
-            this.utf8Decoder = new TextDecoder('utf-8');
+    get name() {
+        return this._tensor.name();
+    }
+
+    get type() {
+        return TensorFlowLiteTensor.formatTensorType(this._tensor);
+    }
+
+    get quantization() {
+        var quantization = this._tensor.quantization();
+        if (quantization && quantization.scaleLength() == 1 && quantization.zeroPointLength() == 1) {
+            var scale = quantization.scale(0);
+            if (scale != 0) {
+                var zeroPoint = quantization.zeroPoint(0).toFloat64();
+                return 'f = ' + scale.toString() + ' * ' + (zeroPoint != 0 ? ('(q - ' + zeroPoint.toString() + ')') : 'q');
+            }
         }
+        return null;
     }
 
-    toString() {
+    get value() {
         var size = 1;
-        for (var i = 0; i < this.tensor.shapeLength(); i++) {
-            size *= this.tensor.shape(i);
+        for (var i = 0; i < this._tensor.shapeLength(); i++) {
+            size *= this._tensor.shape(i);
         }
         if (size > 65536) {
             return 'Tensor is too large to display.';
         }
 
-        if (this.buffer.dataLength() == 0) {
+        if (this._buffer.dataLength() == 0) {
             return 'Tensor data is empty.';
         }
 
-        var array = this.buffer.dataArray();
-        this.data = new DataView(array.buffer, array.byteOffset, array.byteLength);
+        var array = this._buffer.dataArray();
+        this._data = new DataView(array.buffer, array.byteOffset, array.byteLength);
 
-        if (this.tensor.type() == tflite.TensorType.STRING) {
+        if (this._tensor.type() == tflite.TensorType.STRING) {
+            var utf8Decoder = window.TextDecoder ? new TextDecoder('utf-8') : null;
             var offset = 0;
-            var count = this.data.getInt32(0, true);
+            var count = this._data.getInt32(0, true);
             offset += 4;
             var offsetTable = [];
             for (var j = 0; j < count; j++) {
-                offsetTable.push(this.data.getInt32(offset, true));
+                offsetTable.push(this._data.getInt32(offset, true));
                 offset += 4;
             }
             offsetTable.push(array.length);
             var stringTable = [];
             for (var k = 0; k < count; k++) {
                 var textArray = array.subarray(offsetTable[k], offsetTable[k + 1]);
-                if (this.utf8Decoder) {
-                    stringTable.push(this.utf8Decoder.decode(textArray));
+                if (utf8Decoder) {
+                    stringTable.push(utf8Decoder.decode(textArray));
                 }
                 else {
                     stringTable.push(String.fromCharCode.apply(null, textArray));
                 }
             }
-            this.data = stringTable;
+            this._data = stringTable;
         }
 
-        this.index = 0;                
+        this._index = 0;                
         var result = this.read(0);
-        this.data = null;
+
+        delete this._index;        
+        delete this._data;
+        delete this._utf8Decoder;
 
         return JSON.stringify(result, null, 4);
     }
 
     read(dimension) {
-        var size = this.tensor.shape(dimension);
+        var size = this._tensor.shape(dimension);
         var results = [];
-        if (dimension == this.tensor.shapeLength() - 1) {
+        if (dimension == this._tensor.shapeLength() - 1) {
             for (var i = 0; i < size; i++) {
-                switch (this.tensor.type())
+                switch (this._tensor.type())
                 {
                     case tflite.TensorType.FLOAT32:
-                        results.push(this.data.getFloat32(this.index, true));
-                        this.index += 4;
+                        results.push(this._data.getFloat32(this._index, true));
+                        this._index += 4;
                         break;
                     case tflite.TensorType.FLOAT16:
-                        results.push(this.decodeNumberFromFloat16(this.data.getUint16(this.index, true)));
-                        this.index += 2;
+                        results.push(this.decodeNumberFromFloat16(this._data.getUint16(this._index, true)));
+                        this._index += 2;
                         break;
                     case tflite.TensorType.UINT8:
-                        results.push(this.data.getUint8(this.index));
-                        this.index += 4;
+                        results.push(this._data.getUint8(this._index));
+                        this._index += 1;
                         break;
                     case tflite.TensorType.INT32:
-                        results.push(this.data.getInt32(this.index, true));
-                        this.index += 4;
+                        results.push(this._data.getInt32(this._index, true));
+                        this._index += 4;
                         break;
                     case tflite.TensorType.INT64:
-                        results.push(new Int64(this.data.getInt64(this.index, true)));
-                        this.index += 8;
+                        results.push(new Int64(this._data.getInt64(this._index, true)));
+                        this._index += 8;
                         break;
                     case tflite.TensorType.STRING:
-                        results.push(this.data[this.index++]);
+                        results.push(this._data[this._index++]);
                         break;
                     default:
                         debugger;
@@ -504,6 +491,35 @@ class TensorFlowLiteTensorFormatter {
         }
         return (s ? -1 : 1) * Math.pow(2, e-15) * (1 + (f / Math.pow(2, 10)));
     }
+
+    static formatTensorType(tensor) {
+        if (!TensorFlowLiteTensor.tensorTypeMap)
+        {
+            var map = {};
+            map[tflite.TensorType.FLOAT32] = 'float';
+            map[tflite.TensorType.FLOAT16] = 'float16';
+            map[tflite.TensorType.INT32] = 'int32';
+            map[tflite.TensorType.UINT8] = 'byte';
+            map[tflite.TensorType.INT64] = 'int64';
+            map[tflite.TensorType.STRING] = 'string';
+            TensorFlowLiteTensor.tensorTypeMap = map;
+        }
+        var result = TensorFlowLiteTensor.tensorTypeMap[tensor.type()]; 
+        if (!result) {
+            debugger;
+            result = '?';
+        }
+        var shapeLength = tensor.shapeLength();
+        if (shapeLength > 0) {
+            var dimensions = [];
+            for (var i = 0; i < shapeLength; i++) {
+                dimensions.push(tensor.shape(i).toString());
+            }
+            result += '[' + dimensions.join(',') + ']';
+        }
+        return result;
+    }
+
 }
 
 class TensorFlowLiteOperatorMetadata {