Parcourir la source

Fix Keras argument identifier (#540)

Lutz Roeder il y a 5 ans
Parent
commit
0b3e5c3624
2 fichiers modifiés avec 20 ajouts et 3 suppressions
  1. 11 1
      source/keras.js
  2. 9 2
      test/models.json

+ 11 - 1
source/keras.js

@@ -339,7 +339,17 @@ keras.Graph = class {
             for (const layer of config.layers) {
                 if (layer.inbound_nodes) {
                     for (let inbound_node of layer.inbound_nodes) {
-                        inbound_node = inbound_node.every((inbound_connection) => Array.isArray(inbound_connection[0])) ? inbound_node.flat() : inbound_node;
+                        const is_connection = (item) => {
+                            return Array.isArray(item) && (item.length === 3 || item.length === 4) && typeof item[0] === 'string';
+                        };
+                        // wrap
+                        if (is_connection(inbound_node)) {
+                            inbound_node = [ inbound_node ];
+                        }
+                        // unwrap
+                        if (Array.isArray(inbound_node) && inbound_node.every((array) => Array.isArray(array) && array.every((item) => is_connection(item)))) {
+                            inbound_node = inbound_node.flat();
+                        }
                         for (const inbound_connection of inbound_node) {
                             let inputName = inbound_connection[0];
                             const inputNode = nodeMap.get(inputName);

+ 9 - 2
test/models.json

@@ -2312,8 +2312,15 @@
   },
   {
     "type":   "keras",
-    "target": "yolov3-tiny.h5",
-    "source": "https://github.com/lutzroeder/netron/files/5192003/yolov3-tiny.zip[yolov3-tiny.h5]",
+    "target": "netron_issue_540_1.h5",
+    "source": "https://github.com/lutzroeder/netron/files/5748679/netron_issue_540_1.h5.zip[netron_issue_540_1.h5]",
+    "format": "Keras v2.4.0", "runtime": "tensorflow",
+    "link":   "https://github.com/lutzroeder/netron/issues/540"
+  },
+  {
+    "type":   "keras",
+    "target": "netron_issue_540_2.h5",
+    "source": "https://github.com/lutzroeder/netron/files/5748680/netron_issue_540_2.h5.zip[netron_issue_540_2.h5]",
     "format": "Keras v2.4.0", "runtime": "tensorflow",
     "link":   "https://github.com/lutzroeder/netron/issues/540"
   },