Quellcode durchsuchen

Dashed border for nodes with control dependencies

Lutz Roeder vor 8 Jahren
Ursprung
Commit
67934cfebe
6 geänderte Dateien mit 78 neuen und 71 gelöschten Zeilen
  1. 0 4
      src/view-onnx.js
  2. 6 4
      src/view-render.css
  3. 20 12
      src/view-render.js
  4. 39 27
      src/view-tf.js
  5. 0 4
      src/view-tflite.js
  6. 13 20
      src/view.js

+ 0 - 4
src/view-onnx.js

@@ -233,10 +233,6 @@ class OnnxNode {
         return [];
     }
 
-    get dependencies() {
-        return [];
-    }
-
     get outputs() {
         return OnnxOperatorMetadata.operatorMetadata.getOutputs(this._node);
     }

+ 6 - 4
src/view-render.css

@@ -1,6 +1,9 @@
 
-.node path { stroke: #000; fill: none; stroke-width: 1px; }
-.node line { stroke: #000; fill: none; stroke-width: 1px; }
+.node path { stroke: #333; fill: none; stroke-width: 1px; }
+.node line { stroke: #333; fill: none; stroke-width: 1px; }
+
+.node-control-dependency { stroke-dasharray: 4, 1; } 
+
 .node-item path { stroke-width: 0; stroke: #000; fill: #fff; }
 .node-item text { font-family: 'Open Sans', --apple-system, "Helvetica Neue", Helvetica, Arial, sans-serf; font-size: 10px; font-weight: 600;  text-rendering: geometricPrecision; }
 
@@ -21,6 +24,5 @@
 
 .edge-label text { font-family: 'Open Sans', --apple-system, "Helvetica Neue", Helvetica, Arial, sans-serf; font-size: 10px; }
 .edge-path { stroke: #000; stroke-width: 1px; fill: none; }
-.edge-path-control { stroke-dasharray: 5, 5; }
 
-.cluster rect { stroke: #000; fill: #000; fill-opacity: 0.04; stroke-width: 0; }
+.cluster rect { stroke: #000; fill: #000; fill-opacity: 0.04; stroke-opacity: 0.06; stroke-width: 1px; }

+ 20 - 12
src/view-render.js

@@ -172,8 +172,8 @@ class GraphRenderer {
 class NodeFormatter {
 
     constructor() {
-        this.items = [];
-        this.attributes = [];
+        this._items = [];
+        this._attributes = [];
     }
 
     addItem(content, className, title, handler) {
@@ -190,26 +190,30 @@ class NodeFormatter {
         if (handler) {
             item.handler = handler;
         }
-        this.items.push(item);
+        this._items.push(item);
     }
 
     addAttribute(name, value, title) {
-        this.attributes.push({ name: name, value: value, title: title });
+        this._attributes.push({ name: name, value: value, title: title });
     }
 
     setAttributeHandler(handler) {
-        this.attributeHandler = handler;
+        this._attributeHandler = handler;
+    }
+
+    setControlDependencies() {
+        this._controlDependencies = true;
     }
 
     format(context) {
         var root = d3.select(context).append('g');
-        var hasAttributes = this.attributes && this.attributes.length > 0;
+        var hasAttributes = this._attributes && this._attributes.length > 0;
         var x = 0;
         var y = 0;
         var maxWidth = 0;
         var itemHeight = 0;
         var itemBoxes = [];
-        this.items.forEach((item, index) => {
+        this._items.forEach((item, index) => {
             var yPadding = 4;
             var xPadding = 7;
             var itemGroup = root.append('g').classed('node-item', true);
@@ -259,13 +263,13 @@ class NodeFormatter {
         if (hasAttributes)
         {
             var attributeGroup = root.append('g').classed('node-attribute', true);
-            if (this.attributeHandler) {
-                attributeGroup.on('click', this.attributeHandler);
+            if (this._attributeHandler) {
+                attributeGroup.on('click', this._attributeHandler);
             }
             attributesPath = attributeGroup.append('path');
             attributeGroup.attr('transform', 'translate(' + x + ',' + y + ')');
             attributesHeight += 4;
-            this.attributes.forEach((attribute) => {
+            this._attributes.forEach((attribute) => {
                 var yPadding = 1;
                 var xPadding = 4;
                 var text = attributeGroup.append('text').attr('xml:space', 'preserve');
@@ -287,7 +291,7 @@ class NodeFormatter {
         }
 
         if (maxWidth > itemWidth) {
-            var d = (maxWidth - itemWidth) / this.items.length;
+            var d = (maxWidth - itemWidth) / this._items.length;
             itemBoxes.forEach((itemBox, index) => {
                 itemBox.x = itemBox.x + (index * d);
                 itemBox.width = itemBox.width + d;
@@ -317,7 +321,11 @@ class NodeFormatter {
         if (hasAttributes) {
             root.append('line').classed('node', true).attr('x1', 0).attr('y1', itemHeight).attr('x2', maxWidth).attr('y2', itemHeight);
         }
-        root.append('path').classed('node', true).attr('d', this.roundedRect(0, 0, maxWidth, itemHeight + attributesHeight, true, true, true, true));
+        var border = root.append('path').classed('node', true).attr('d', this.roundedRect(0, 0, maxWidth, itemHeight + attributesHeight, true, true, true, true));
+
+        if (this._controlDependencies) {
+            border.classed('node-control-dependency', true);
+        }
 
         context.innerHTML = '';
         return root.node();

+ 39 - 27
src/view-tf.js

@@ -141,17 +141,15 @@ class TensorFlowGraph {
 
     get nodes() {
         this.update();
-        // graph.graphDef.node.forEach(function (node) {
-        //     console.log(node.name + ' [' + (!node.input ? "" : node.input.map(s => s).join(',')) + ']');
-        // });
         var results = [];
         this._graph.graphDef.node.forEach((node) => {
-             if (node.op != 'NoOp') {
+            if (node.output.filter(output => !output.startsWith('^')) != 0 ||
+                node.input.filter(input => !input.startsWith('^')).length > 0) {
                 var id = node.name + ':0';
                 if (!this._initializerMap[id] && !this._inputMap[id] /* && node.op != 'NoOp' */) {
                     results.push(new TensorFlowNode(this, node));
-                }    
-             }
+                }
+            }
         });
         return results;
     }
@@ -340,16 +338,6 @@ class TensorFlowNode {
         return [];
     }
 
-    get dependencies() {
-        var results = [];
-        this._node.input.forEach((input) => {
-            if (input.startsWith('^')) {
-                results.push(input.substring(1));
-            }
-        });
-        return results;
-    }
-
     get outputs() {
         return this._graph.metadata.getOutputs(this._node);
     }
@@ -475,7 +463,15 @@ class TensorFlowAttribute {
     }
 
     get tensor() {
-        return this._value.hasOwnProperty('tensor');
+        if (this._value.hasOwnProperty('tensor')) {
+            if (this._value.tensor.tensorShape && this._value.tensor.tensorShape.dim) {
+                if (this._value.tensor.tensorShape.dim.length == 0) {
+                    return false;
+                }
+            }
+            return true;
+        }
+        return false;
     }
 }
 
@@ -698,13 +694,16 @@ class TensorFlowTensor {
 
     static formatTensorShape(shape) {
         if (shape.dim) {
+            if (shape.unknownRank) {
+                return '[-]';
+            }
             if (shape.dim.length == 0) {
                 return '';
             }
             if (shape.dim.length == 1 && !shape.dim[0].size) {
                 return '[0]';
             }
-            return '[' + shape.dim.map((dim) => dim.size ? dim.size.toString() : '?').join(',') + ']';
+            return '[' + shape.dim.map((dim) => (dim.size && dim.size != -1) ? dim.size.toString() : '?').join(',') + ']';
         }
         debugger;
         return '?';
@@ -776,6 +775,15 @@ class TensorFlowGraphOperatorMetadata {
                 }
                 var result = {};
                 result.name = inputArg.name;
+                if (inputArg.type) {
+                    result.type = TensorFlowTensor.formatDataType(inputArg.type);
+                }
+                else if (inputArg.typeAttr) {
+                    result.type = inputArg.typeAttr;
+                }
+                else if (inputArg.typeListAttr) {
+                    result.type = inputArg.typeListAttr;
+                }
                 result.connections = node.input.slice(index, index + count).map((id) => {
                     if (id.startsWith('^')) {
                         debugger;
@@ -815,10 +823,16 @@ class TensorFlowGraphOperatorMetadata {
                 }
                 var result = {};
                 result.name = outputArg.name;
+                if (outputArg.type) {
+                    result.type = TensorFlowTensor.formatDataType(outputArg.type);
+                }
+                else if (outputArg.typeAttr) {
+                    result.type = outputArg.typeAttr;
+                }
+                else if (outputArg.typeListAttr) {
+                    result.type = outputArg.typeListAttr;
+                }
                 result.connections = node.output.slice(index, index + count).map((id) => {
-                    if (id.startsWith('^')) {
-                        id = id.substring(1);
-                    }
                     return { id: id };
                 });
                 results.push(result);
@@ -827,12 +841,10 @@ class TensorFlowGraphOperatorMetadata {
         }
         else {
             node.output.slice(index).forEach((output) => {
-                if (!output.startsWith('^')) {
-                    results.push({
-                        name: '(' + index.toString() + ')',
-                        connections: [ { id: output } ]
-                    });
-                }
+                results.push({
+                    name: '(' + index.toString() + ')',
+                    connections: [ { id: output } ]
+                });
                 index++;
             });
         }

+ 0 - 4
src/view-tflite.js

@@ -227,10 +227,6 @@ class TensorFlowLiteNode {
         return results;
     }
 
-    get dependencies() {
-        return [];
-    }
-
     get outputs() {
         var results = [];
         var graph = this._graph._graph;

+ 13 - 20
src/view.js

@@ -188,27 +188,20 @@ function updateGraph(model) {
 
         node.outputs.forEach((output) => {
             output.connections.forEach((connection) => {
-                var tuple = edgeMap[connection.id];
-                if (!tuple) {
-                    tuple = { from: null, to: [] };
-                    edgeMap[connection.id] = tuple;
+                if (connection.id.startsWith('^')) {
+                    formatter.setControlDependencies();
+                }
+                else {
+                    var tuple = edgeMap[connection.id];
+                    if (!tuple) {
+                        tuple = { from: null, to: [] };
+                        edgeMap[connection.id] = tuple;
+                    }
+                    tuple.from = { 
+                        node: nodeId,
+                        name: output.name
+                    };    
                 }
-                tuple.from = { 
-                    node: nodeId,
-                    name: output.name
-                };    
-            });
-        });
-
-        node.dependencies.forEach((dependency) => {
-            var tuple = edgeMap[dependency];
-            if (!tuple) {
-                tuple = { from: null, to: [] };
-                edgeMap[dependency] = tuple;
-            }
-            tuple.to.push({ 
-                node: nodeId, 
-                dependency: true
             });
         });