Просмотр исходного кода

Sidebar attribute function support

Lutz Roeder 4 лет назад
Родитель
Сommit
18dadd4335
4 измененных файлов с 88 добавлено и 51 удалено
  1. 31 37
      source/tf.js
  2. 2 0
      source/view-sidebar.css
  3. 52 14
      source/view-sidebar.js
  4. 3 0
      source/view.js

+ 31 - 37
source/tf.js

@@ -496,20 +496,11 @@ tf.Model = class {
         this._producer = producer || '';
         this._graphs = [];
         if (model) {
-            const graphs = [];
             for (let i = 0; i < model.meta_graphs.length; i++) {
                 const meta_graph = model.meta_graphs[i];
                 const name = (meta_graph.meta_info_def && meta_graph.meta_info_def.any_info) ? meta_graph.meta_info_def.any_info.toString() : ((model.meta_graphs.length > 1) ? i.toString() : '-');
                 const graph = new tf.Graph(metadata, meta_graph, name, bundle);
-                graphs.push(graph);
-            }
-            // Recursively add all subgraphs.
-            while (graphs.length > 0) {
-                const graph = graphs.shift();
                 this._graphs.push(graph);
-                for (const func of graph.functions || []) {
-                    graphs.push(func);
-                }
             }
         }
         else {
@@ -543,10 +534,8 @@ tf.Graph = class {
         this._inputs = [];
         this._outputs = [];
         this._nodes = [];
-        this._functions = [];
 
         if (meta_graph && meta_graph.graph_def) {
-            metadata = new tf.GraphMetadata(metadata, meta_graph.meta_info_def);
             const graph = meta_graph.graph_def;
             if (graph.versions) {
                 this._version = 'v' + graph.versions.producer.toString();
@@ -561,6 +550,8 @@ tf.Graph = class {
                 this._tags = meta_graph.meta_info_def.tags.join(', ');
             }
 
+            metadata = new tf.GraphMetadata(metadata, graph.library);
+
             const nodes = graph.node;
             if (nodes) {
                 const node_map = new Map();
@@ -911,15 +902,6 @@ tf.Graph = class {
                     this._nodes.push(new tf.Node(metadata, namespaces, node, node.op, node.name, initializers, null));
                 }
             }
-
-            if (graph.library) {
-                const funcs = graph.library.function;
-                for (const func of funcs) {
-                    const value = new tf.Function(this, func, metadata);
-                    metadata.add(value);
-                    this._functions.push(value);
-                }
-            }
         }
         else if (bundle) {
             const nodeNames = [];
@@ -986,10 +968,6 @@ tf.Graph = class {
     get metadata() {
         return this._metadata;
     }
-
-    get functions() {
-        return this._functions;
-    }
 };
 
 tf.Parameter = class {
@@ -1040,8 +1018,8 @@ tf.Argument = class {
 };
 
 tf.Function = class {
+    constructor(metadata, func) {
 
-    constructor(graph, func, metadata) {
         this._name = func.signature.name;
         this._version = null;
         this._tags = null;
@@ -1195,6 +1173,10 @@ tf.Function = class {
         }
     }
 
+    get type() {
+        return 'function';
+    }
+
     get name() {
         return this._name;
     }
@@ -1257,9 +1239,7 @@ tf.Node = class {
             if (node.attr) {
                 this._attributes = Object.keys(node.attr).map((name) => {
                     const value = node.attr[name];
-                    const schema = value && value.metadata ? value.metadata : metadata.attribute(op, name);
-                    const visible = metadata.visible(this._type, name);
-                    return new tf.Attribute(schema, name, value, visible);
+                    return new tf.Attribute(metadata, op, name, value);
                 });
             }
             let inputIndex = 0;
@@ -1370,10 +1350,12 @@ tf.Node = class {
 
 tf.Attribute = class {
 
-    constructor(schema, name, value, visible) {
+    constructor(metadata, op, name, value) {
         this._name = name;
         this._value = null;
         this._type = null;
+        const schema = value && value.metadata ? value.metadata : metadata.attribute(op, name);
+        const visible = metadata.visible(op, name);
         if (Object.prototype.hasOwnProperty.call(value, 'tensor')) {
             this._type = 'tensor';
             this._value = new tf.Tensor(value.tensor);
@@ -1403,9 +1385,8 @@ tf.Attribute = class {
                 this._value = tf.Utility.decodeText(value.s);
                 break;
             case 'func': {
-                const func = value.func;
                 this._type = 'function';
-                this._value = func.name;
+                this._value = metadata.type(value.func.name);
                 break;
             }
             case 'list': {
@@ -1427,6 +1408,10 @@ tf.Attribute = class {
                     this._type = 'shape[]';
                     this._value = list.shape.map((shape) => new tf.TensorShape(shape));
                 }
+                else if (list.func && list.func.length > 0) {
+                    this._type = 'function[]';
+                    this._value = list.func.map((func) => metadata.type(func.name));
+                }
                 else {
                     this._value = [];
                 }
@@ -2243,22 +2228,31 @@ tf.EventFileReader = class {
 
 tf.GraphMetadata = class {
 
-    constructor(metadata) {
+    constructor(metadata, library) {
         this._metadata = metadata;
         this._functions = new Map();
         this._attributes = new Map();
         this._visibleCache = new Map();
-    }
 
-    add(func) {
-        if (this._functions.has(func.name)) {
-            throw new tf.Error("Duplicate function name '" + func.name + "'.");
+        if (library && Array.isArray(library.function)) {
+            for (const func of library.function) {
+                const name = func.signature.name;
+                if (this._functions.has(func.name)) {
+                    throw new tf.Error("Duplicate function name '" + func.name + "'.");
+                }
+                this._functions.set(name, func);
+            }
         }
-        this._functions.set(func.name, func);
+
     }
 
     type(name) {
         if (this._functions.has(name)) {
+            const func = this._functions.get(name);
+            if (func instanceof tf.Function) {
+                return func;
+            }
+            this._functions.set(name, new tf.Function(this, func));
             return this._functions.get(name);
         }
         return this._metadata.type(name);

+ 2 - 0
source/view-sidebar.css

@@ -19,6 +19,8 @@
 .sidebar-view-item-value code { font-family: 'SFMono-Regular', Consolas, 'Liberation Mono', Menlo, Courier, monospace; overflow: auto; white-space: pre-wrap; word-wrap: break-word; }
 .sidebar-view-item-value pre { font-family: 'SFMono-Regular', Consolas, 'Liberation Mono', Menlo, Courier, monospace; margin: 0; overflow: auto; white-space: pre; word-wrap: normal; display: block; }
 .sidebar-view-item-value-line { padding: 4px 6px 4px 6px; }
+.sidebar-view-item-value-line-link { padding: 4px 6px 4px 6px; cursor: default; }
+.sidebar-view-item-value-line-link:hover { text-decoration: underline; }
 .sidebar-view-item-value-line-border { padding: 4px 6px 4px 6px; border-top: 1px solid rgba(27, 31, 35, 0.05); }
 .sidebar-view-item-value-line-content { white-space: pre; word-wrap: normal; overflow: auto; display: block; }
 .sidebar-view-item-value-expander { font-family: 'SFMono-Regular', Consolas, 'Liberation Mono', Menlo, Courier, monospace; float: right; color: #aaa; cursor: hand; user-select: none; -webkit-user-select: none; -moz-user-select: none; padding: 4px 6px 4px 6px; }

+ 52 - 14
source/view-sidebar.js

@@ -218,9 +218,13 @@ sidebar.NodeSidebar = class {
     }
 
     _addAttribute(name, attribute) {
-        const item = new sidebar.NameValueView(this._host, name, new NodeAttributeView(this._host, attribute));
-        this._attributes.push(item);
-        this._elements.push(item.render());
+        const item = new NodeAttributeView(this._host, attribute);
+        item.on('show-graph', (sender, graph) => {
+            this._raise('show-graph', graph);
+        });
+        const view = new sidebar.NameValueView(this._host, name, item);
+        this._attributes.push(view);
+        this._elements.push(view.render());
     }
 
     _addInput(name, input) {
@@ -295,6 +299,10 @@ sidebar.NodeSidebar = class {
                     return value.toString();
                 }
                 return '[...]';
+            case 'function':
+                return value.name;
+            case 'function[]':
+                return value ? value.map((item) => item.name).join(', ') : '(null)';
         }
         if (typeof value === 'string' && (!type || type != 'string')) {
             return quote ? '"' + value + '"' : value;
@@ -492,7 +500,8 @@ class NodeAttributeView {
         this._element = this._host.document.createElement('div');
         this._element.className = 'sidebar-view-item-value';
 
-        if (attribute.type) {
+        const type = this._attribute.type;
+        if (type) {
             this._expander = this._host.document.createElement('div');
             this._expander.className = 'sidebar-view-item-value-expander';
             this._expander.innerText = '+';
@@ -501,17 +510,32 @@ class NodeAttributeView {
             });
             this._element.appendChild(this._expander);
         }
-        let value = sidebar.NodeSidebar.formatAttributeValue(this._attribute.value, this._attribute.type);
-        if (value && value.length > 1000) {
-            value = value.substring(0, 1000) + '\u2026';
-        }
-        if (value && typeof value === 'string') {
-            value = value.split('<').join('&lt;').split('>').join('&gt;');
+        const value = this._attribute.value;
+        switch (type) {
+            case 'function': {
+                const line = this._host.document.createElement('div');
+                line.className = 'sidebar-view-item-value-line-link';
+                line.innerHTML = value.name;
+                line.addEventListener('click', () => {
+                    this._raise('show-graph', value);
+                });
+                this._element.appendChild(line);
+                break;
+            }
+            default: {
+                let text = sidebar.NodeSidebar.formatAttributeValue(value, type);
+                if (text && text.length > 1000) {
+                    text = text.substring(0, 1000) + '\u2026';
+                }
+                if (text && typeof text === 'string') {
+                    text = text.split('<').join('&lt;').split('>').join('&gt;');
+                }
+                const line = this._host.document.createElement('div');
+                line.className = 'sidebar-view-item-value-line';
+                line.innerHTML = (text ? text : '&nbsp;');
+                this._element.appendChild(line);
+            }
         }
-        const valueLine = this._host.document.createElement('div');
-        valueLine.className = 'sidebar-view-item-value-line';
-        valueLine.innerHTML = (value ? value : '&nbsp;');
-        this._element.appendChild(valueLine);
     }
 
     render() {
@@ -560,6 +584,20 @@ class NodeAttributeView {
             }
         }
     }
+
+    on(event, callback) {
+        this._events = this._events || {};
+        this._events[event] = this._events[event] || [];
+        this._events[event].push(callback);
+    }
+
+    _raise(event, data) {
+        if (this._events && this._events[event]) {
+            for (const callback of this._events[event]) {
+                callback(this, data);
+            }
+        }
+    }
 }
 
 sidebar.ParameterView = class {

+ 3 - 0
source/view.js

@@ -895,6 +895,9 @@ view.View = class {
             nodeSidebar.on('show-documentation', (/* sender, e */) => {
                 this.showNodeDocumentation(node);
             });
+            nodeSidebar.on('show-graph', (sender, graph) => {
+                this.pushGraph(graph);
+            });
             nodeSidebar.on('export-tensor', (sender, tensor) => {
                 this._host.require('./numpy').then((numpy) => {
                     const defaultPath = tensor.name ? tensor.name.split('/').join('_').split(':').join('_').split('.').join('_') : 'tensor';