2
0
Эх сурвалжийг харах

Fix TensorFlow control dependencies (#617)

Lutz Roeder 4 жил өмнө
parent
commit
522469d99a
3 өөрчлөгдсөн 15 нэмэгдсэн , 7 устгасан
  1. 4 1
      source/om.js
  2. 1 1
      source/tf.js
  3. 10 5
      source/view.js

+ 4 - 1
source/om.js

@@ -112,7 +112,7 @@ om.Node = class {
             const name = pos === 0 ? 'internal_unnamed' : op.input[i].slice(0, pos);
             const src_index = op.input[i].slice(pos + 1);
             if (src_index === '-1') {
-                this._controlDependencies.push(name);
+                this._controlDependencies.push(new om.Argument(name));
                 continue;
             }
             const parameterName = this._type.inputs && i < this._type.inputs.length ? this._type.inputs[i].name : 'input' + (i === 0 ? '' : i.toString());
@@ -328,6 +328,9 @@ om.Parameter = class {
 om.Argument = class {
 
     constructor(name, type, initializer) {
+        if (typeof name !== 'string') {
+            throw new om.Error("Invalid argument identifier '" + JSON.stringify(name) + "'.");
+        }
         this._name = name;
         this._type = type || null;
         this._initializer = initializer || null;

+ 1 - 1
source/tf.js

@@ -901,7 +901,7 @@ tf.Node = class {
                     new tf.Argument(output.name ? output.name : '-', null, null)
                 ]);
             }));
-            this._controlDependencies = node.controlDependencies.map((input) => input.name);
+            this._controlDependencies = node.controlDependencies.map((input) => new tf.Argument(input.name));
         }
         else if (tensors) {
             for (const tensor of tensors) {

+ 10 - 5
source/view.js

@@ -628,8 +628,8 @@ view.View = class {
                     }
 
                     if (node.controlDependencies && node.controlDependencies.length > 0) {
-                        for (const name of node.controlDependencies) {
-                            viewGraph.createArgument({ name: name, controlDependency: true }).to(viewNode);
+                        for (const argument of node.controlDependencies) {
+                            viewGraph.createArgument(argument).to(viewNode, true);
                         }
                     }
 
@@ -1209,15 +1209,20 @@ view.Argument = class {
         this._from = node;
     }
 
-    to(node) {
+    to(node, controlDependency) {
         this._to = this._to || [];
+        if (controlDependency) {
+            this._controlDependencies = this._controlDependencies || new Set();
+            this._controlDependencies.add(this._to.length);
+        }
         this._to.push(node);
     }
 
     build() {
         this._edges = this._edges || [];
         if (this._from && this._to) {
-            for (const to of this._to) {
+            for (let i = 0; i < this._to.length; i++) {
+                const to = this._to[i];
                 let text = '';
                 const type = this._argument.type;
                 if (type && type.shape && type.shape.dimensions && type.shape.dimensions.length > 0) {
@@ -1231,7 +1236,7 @@ view.Argument = class {
                 edge.w = to.name;
                 edge.label = text;
                 edge.id = 'edge-' + this._argument.name;
-                if (this._argument.controlDependency) {
+                if (this._controlDependencies && this._controlDependencies.has(i)) {
                     edge.class = 'edge-path-control-dependency';
                 }
                 this.context.setEdge(edge);