فهرست منبع

Add TensorFlow test file (#895)

Lutz Roeder 3 سال پیش
والد
کامیت
79db8298f8
4فایلهای تغییر یافته به همراه58 افزوده شده و 54 حذف شده
  1. 42 46
      source/tf.js
  2. 4 4
      source/view-sidebar.js
  3. 2 2
      source/view.js
  4. 10 2
      test/models.json

+ 42 - 46
source/tf.js

@@ -704,7 +704,6 @@ tf.Graph = class {
         this._outputs = [];
         this._nodes = [];
         this._version = null;
-
         if (meta_graph && meta_graph.graph_def) {
             const graph = meta_graph.graph_def;
             if (graph.versions) {
@@ -727,8 +726,7 @@ tf.Graph = class {
             this._outputs = context.outputs;
         }
         else if (bundle) {
-            const nodeNames = [];
-            const nodeMap = new Map();
+            const nodes = new Map();
             for (const tensor of bundle.tensors) {
                 const parts = tensor.name.split('/');
                 if (bundle.format === 2) {
@@ -745,17 +743,17 @@ tf.Graph = class {
                     }
                 }
                 const tensorName = parts.pop();
-                const nodeName = parts.join('/');
-                if (!nodeMap.has(nodeName)) {
-                    nodeNames.push(nodeName);
-                    nodeMap.set(nodeName, []);
+                const name = parts.join('/');
+                if (!nodes.has(name)) {
+                    nodes.set(name, []);
                 }
-                nodeMap.get(nodeName).push({ name: tensorName, value: tensor });
+                nodes.get(name).push({ name: tensorName, value: tensor });
             }
             const namespaces = new Set();
-            for (const name of nodeNames) {
-                this._nodes.push(new tf.Node(metadata, namespaces, null, 'Node', name, null, nodeMap.get(name)));
-            }
+            this._nodes = Array.from(nodes).map((entry) => {
+                const node = { op: 'Node', name: entry[0] };
+                return new tf.Node(metadata, node, namespaces, null, entry[1]);
+            });
         }
     }
 
@@ -924,41 +922,45 @@ tf.Function = class {
 
 tf.Node = class {
 
-    constructor(metadata, namespaces, node, op, name, initializers, tensors) {
-        this._type = Object.assign({}, node && node.metadata ? node.metadata : metadata.type(op) || { name: op });
-        this._type.identifier = this._type.name;
-        this._type.name = op;
-        this._name = name;
+    constructor(metadata, node, namespaces, initializers, tensors) {
+        this._type = node.metadata || metadata.type(node.op) || { name: node.op };
+        this._name = node.name;
         this._attributes = [];
         this._inputs = [];
         this._outputs = [];
-
         this._group = '';
-        if (namespaces.has(name)) {
-            this._group = name;
-        }
-        else {
-            const lastIndex = name.lastIndexOf('/');
-            if (lastIndex != -1) {
-                const namespace = name.substring(0, lastIndex);
-                if (namespaces.has(namespace)) {
-                    this._group = namespace;
+        if (node.name) {
+            if (namespaces.has(node.name)) {
+                this._group = node.name;
+            }
+            else {
+                const lastIndex = node.name.lastIndexOf('/');
+                if (lastIndex != -1) {
+                    const namespace = node.name.substring(0, lastIndex);
+                    if (namespaces.has(namespace)) {
+                        this._group = namespace;
+                    }
                 }
             }
         }
-
-        if (node) {
+        if (tensors) {
+            for (const tensor of tensors) {
+                this._inputs.push(new tf.Parameter(tensor.name, [
+                    new tf.Argument(tensor.value.name, null, tensor.value)
+                ]));
+            }
+        }
+        else {
             if (node.device !== undefined) {
                 this._device = node.device;
             }
             if (node.attr) {
-                this._attributes = Object.keys(node.attr).map((name) => {
-                    const value = node.attr[name];
-                    return new tf.Attribute(metadata, op, name, value);
+                this._attributes = Object.entries(node.attr).map((entry) => {
+                    return new tf.Attribute(metadata, node.op, entry[0], entry[1]);
                 });
             }
             let inputIndex = 0;
-            const inputs = node.input.filter((input) => !input.name.startsWith('^'));
+            const inputs = (node.input || []).filter((input) => !input.name.startsWith('^'));
             if (this._type && this._type.inputs) {
                 for (const input of this._type.inputs) {
                     let inputCount = 1;
@@ -987,7 +989,7 @@ tf.Node = class {
                 ]);
             }));
             let outputIndex = 0;
-            const outputs = node.output;
+            const outputs = node.output || [];
             if (this._type && this._type.outputs) {
                 for (const output of this._type.outputs) {
                     let outputCount = 1;
@@ -1015,14 +1017,8 @@ tf.Node = class {
                     new tf.Argument(output.name ? output.name : '-', null, null)
                 ]);
             }));
-            this._controlDependencies = node.controlDependencies.map((input) => new tf.Argument(input.name));
-        }
-        else if (tensors) {
-            for (const tensor of tensors) {
-                this._inputs.push(new tf.Parameter(tensor.name, [
-                    new tf.Argument(tensor.value.name, null, tensor.value)
-                ]));
-            }
+            const controlDependencies = node.controlDependencies || [];
+            this._controlDependencies = controlDependencies.map((input) => new tf.Argument(input.name));
         }
     }
 
@@ -1101,9 +1097,8 @@ tf.Attribute = class {
                 break;
             }
             case 'func': {
-                const name = value.func.name;
                 this._type = 'function';
-                this._value = metadata.type(name);
+                this._value = new tf.Node(metadata, { op: value.func.name, attr: value.func.attr });
                 break;
             }
             case 'list': {
@@ -1127,7 +1122,7 @@ tf.Attribute = class {
                 }
                 else if (list.func && list.func.length > 0) {
                     this._type = 'function[]';
-                    this._value = list.func.map((func) => metadata.type(func.name));
+                    this._value = list.func.map((func) => new tf.Node(metadata, { op: func.name, attr: func.attr }));
                 }
                 else {
                     this._value = [];
@@ -2426,7 +2421,8 @@ tf.Utility = class {
                             }
                         }
                         if (match) {
-                            node.metadata = metadata;
+                            node.metadata = Object.assign({}, metadata);
+                            node.metadata.name = node.op;
                             break;
                         }
                         else {
@@ -2473,7 +2469,7 @@ tf.Utility = class {
             context.inputs.push(input);
         }
         for (const node of node_map.values()) {
-            context.nodes.push(new tf.Node(metadata, namespaces, node, node.op, node.name, initializers, null));
+            context.nodes.push(new tf.Node(metadata, node, namespaces, initializers));
         }
         return context;
     }

+ 4 - 4
source/view-sidebar.js

@@ -298,9 +298,9 @@ sidebar.NodeSidebar = class {
                 }
                 return '[...]';
             case 'function':
-                return value.name;
+                return value.type.name;
             case 'function[]':
-                return value ? value.map((item) => item.name).join(', ') : '(null)';
+                return value ? value.map((item) => item.type.name).join(', ') : '(null)';
         }
         if (typeof value === 'string' && (!type || type != 'string')) {
             return quote ? '"' + value + '"' : value;
@@ -515,9 +515,9 @@ class NodeAttributeView {
             case 'function': {
                 const line = this._host.document.createElement('div');
                 line.className = 'sidebar-view-item-value-line-link';
-                line.innerHTML = value.name;
+                line.innerHTML = value.type.name;
                 line.addEventListener('click', () => {
-                    this._raise('show-graph', value);
+                    this._raise('show-graph', value.type);
                 });
                 this._element.appendChild(line);
                 break;

+ 2 - 2
source/view.js

@@ -805,7 +805,7 @@ view.View = class {
 
     showDocumentation(type) {
         if (type && (type.description || type.inputs || type.outputs || type.attributes)) {
-            if (type.nodes) {
+            if (type.nodes && type.nodes.length > 0) {
                 this.pushGraph(type);
             }
             const documentationSidebar = new sidebar.DocumentationSidebar(this._host, type);
@@ -1013,7 +1013,7 @@ view.Node = class extends grapher.Node {
         const tooltip = this.context.view.options.names && (node.name || node.location) ? type.name : (node.name || node.location);
         const title = header.add(null, styles, content, tooltip);
         title.on('click', () => this.context.view.showNodeProperties(node, null));
-        if (node.type.nodes) {
+        if (node.type.nodes && node.type.nodes.length > 0) {
             const definition = header.add(null, styles, '\u0192', 'Show Function Definition');
             definition.on('click', () => this.context.view.pushGraph(node.type));
         }

+ 10 - 2
test/models.json

@@ -3235,7 +3235,7 @@
     "target":   "denotation_Add_ImageNet1920WithImageMetadataBgr8_SRGB_0_255.onnx",
     "source":   "https://github.com/lutzroeder/netron/files/2587943/onnx_denotation_models.zip[denotation_Add_ImageNet1920WithImageMetadataBgr8_SRGB_0_255.onnx]",
     "format":   "ONNX v3",
-    "assert": [ "model.graphs[0].nodes[0].outputs[0].arguments[0].type.denotation = Image(Bgr8,SRGB,NominalRange_0_255)" ],
+    "assert":   [ "model.graphs[0].nodes[0].outputs[0].arguments[0].type.denotation = Image(Bgr8,SRGB,NominalRange_0_255)" ],
     "link":     "https://github.com/lutzroeder/netron/issues/183"
   },
   {
@@ -5561,6 +5561,7 @@
     "target":   "events.out.tfevents.1606692323.b5c8f88cc7ee.58.2",
     "source":   "https://github.com/lutzroeder/netron/files/5613448/events.out.tfevents.1606692323.b5c8f88cc7ee.58.2.zip[events.out.tfevents.1606692323.b5c8f88cc7ee.58.2]",
     "format":   "TensorFlow Event File v2",
+    "assert":   [ "model.graphs[0].nodes[0].type.name=aten::_convolution" ],
     "producer": "PyTorch",
     "link":     "https://github.com/lutzroeder/netron/issues/638"
   },
@@ -5837,6 +5838,13 @@
     "format":   "TensorFlow Graph",
     "link":     "https://github.com/lutzroeder/netron/issues/847"
   },
+  {
+    "type":     "tf",
+    "target":   "netron_issue_895.pbtxt",
+    "source":   "https://github.com/lutzroeder/netron/files/8459475/netron_issue_895.pbtxt.zip[netron_issue_895.pbtxt]",
+    "format":   "TensorFlow Graph",
+    "link":     "https://github.com/lutzroeder/netron/issues/895"
+  },
   {
     "type":     "tf",
     "target":   "pose_estimation_for_mobile.pb",
@@ -6047,7 +6055,7 @@
     "target":   "densenet.tflite",
     "source":   "https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/densenet_2018_04_27.tgz[densenet/densenet.tflite]",
     "format":   "TensorFlow Lite v3",
-    "assert": [ "model.graphs[0].nodes[0].type.name=Conv2D" ],
+    "assert":   [ "model.graphs[0].nodes[0].type.name=Conv2D" ],
     "link":     "https://www.tensorflow.org/lite/guide/hosted_models"
   },
   {