Просмотр исходного кода

TensorFlow float tensor rendering

Lutz Roeder 8 лет назад
Родитель
Сommit
e09d586033
4 измененных файлов с 84 добавлено и 6 удалено
  1. 1 1
      src/view-electron.html
  2. 1 1
      src/view-onnx.js
  3. 77 3
      src/view-tf.js
  4. 5 1
      src/view.js

+ 1 - 1
src/view-electron.html

@@ -8,7 +8,7 @@
 <body>
 <div id='welcome' class='background' style='display: block'>
     <img id='logo' class='center' src='logo.svg' width='128' height='128'>
-    <button id='open-file-button' class='center' style='top: 200px; width: 150px; opacity: 0;'>Open Model...</button>
+    <button id='open-file-button' class='center' style='top: 200px; width: 125px; opacity: 0;'>Open Model...</button>
     <div id='spinner' class='spinner' style='display: none'></div>
     <!-- Preload fonts to workaround Chrome SVG layout issue -->
     <div style='font-weight: normal; color: #e6e6e6; user-select: none;'>.</div>

+ 1 - 1
src/view-onnx.js

@@ -456,8 +456,8 @@ class OnnxTensor {
     }
 
     read(dimension) {
-        var size = this._tensor.dims[dimension];
         var results = [];
+        var size = this._tensor.dims[dimension];
         if (dimension == this._tensor.dims.length - 1) {
             for (var i = 0; i < size; i++) {
                 if (this._count > 10000) {

+ 77 - 3
src/view-tf.js

@@ -150,7 +150,7 @@ class TensorFlowGraph {
                 if (tensor) {
                     this._initializerMap[node.input[0]] = "-";
                     tensor._id = node.output[0]; // TODO update tensor id
-                    tensor._title = 'Constant Identity';
+                    tensor._title = 'Identity Constant';
                     this._initializerMap[node.output[0]] = tensor;
                 }
             }
@@ -402,9 +402,15 @@ class TensorFlowAttribute {
                 }
                 return list.type.map((type) => TensorFlowTensor.formatDataType(type)).join(', ');
             }
+            else if (list.shape && list.shape.length > 0) {
+                if (list.shape.length > 65536) {
+                    return "Too large to render.";
+                }
+                return list.shape.map((shape) => TensorFlowTensor.formatTensorShape(shape)).join(', ');
+            }
         }
         debugger;
-        return '?';        
+        return '';        
     }
 
     get hidden() {
@@ -444,7 +450,75 @@ class TensorFlowTensor {
     }
 
     get value() {
-        return '?';        
+        if (!this._tensor.dtype) {
+            return 'Tensor has no data type.';
+        }
+        if (!this._tensor.tensorShape) {
+            return 'Tensor has no dimensions.';
+        }
+
+        switch (this._tensor.dtype) {
+            case tensorflow.DataType.DT_FLOAT:
+                if (this._tensor.tensorContent && this._tensor.tensorContent.length > 0) {
+                    this._rawData = new DataView(this._tensor.tensorContent.buffer, this._tensor.tensorContent.byteOffset, this._tensor.tensorContent.byteLength)
+                }
+                else {
+                    return 'Tensor data is empty.';
+                }
+                break;
+            default:
+                debugger;
+                return 'Tensor data type is not implemented.';
+        }
+
+        this._index = 0;
+        this._count = 0;
+        var result = this.read(0);
+        delete this._index;
+        delete this._count;
+        delete this._data;
+        delete this._rawData;
+
+        return JSON.stringify(result, null, 4);
+    }
+
+    read(dimension) {
+        var results = [];
+        var dim = this._tensor.tensorShape.dim[dimension];
+        var size = dim.size;
+        if (dimension == this._tensor.tensorShape.dim.length - 1) {
+            for (var i = 0; i < size; i++) {
+                if (this._count > 10000) {
+                    results.push('...');
+                    return results;
+                }
+                if (this._data) {
+                    results.push(this._data[this._index++]);
+                }
+                else {
+                    if (this._rawData) {
+                        switch (this._tensor.dtype)
+                        {
+                            case tensorflow.DataType.DT_FLOAT:
+                                results.push(this._rawData.getFloat32(this._index, true));
+                                this._index += 4;
+                                this._count++;
+                                break;
+                        }
+                    }
+                }
+            }
+        }
+        else {
+            for (var j = 0; j < size; j++) {
+                if (this._count > 10000) {
+                    results.push('...');
+                    return results;
+                }
+                results.push(this.read(dimension + 1));
+            }
+        }
+        return results;
     }
 
     static formatTensorType(tensor) {

+ 5 - 1
src/view.js

@@ -104,7 +104,11 @@ function updateGraph(model) {
     var svg = dagreD3.d3.select(svgElement);
 
     var g = new dagreD3.graphlib.Graph();
-    g.setGraph({});
+    // g.setGraph({});
+    // g.setGraph({ ranker: 'network-simplex' });
+    g.setGraph({ ranker: 'tight-tree' });
+    // g.setGraph({ ranker: 'longest-path' });
+    // g.setGraph({ acyclicer: 'greedy' });
     g.setDefaultEdgeLabel(function() { return {}; });
 
     var nodeId = 0;