|
|
@@ -339,7 +339,17 @@ keras.Graph = class {
|
|
|
for (const layer of config.layers) {
|
|
|
if (layer.inbound_nodes) {
|
|
|
for (let inbound_node of layer.inbound_nodes) {
|
|
|
- inbound_node = inbound_node.every((inbound_connection) => Array.isArray(inbound_connection[0])) ? inbound_node.flat() : inbound_node;
|
|
|
+ const is_connection = (item) => {
|
|
|
+ return Array.isArray(item) && (item.length === 3 || item.length === 4) && typeof item[0] === 'string';
|
|
|
+ };
|
|
|
+ // wrap
|
|
|
+ if (is_connection(inbound_node)) {
|
|
|
+ inbound_node = [ inbound_node ];
|
|
|
+ }
|
|
|
+ // unwrap
|
|
|
+ if (Array.isArray(inbound_node) && inbound_node.every((array) => Array.isArray(array) && array.every((item) => is_connection(item)))) {
|
|
|
+ inbound_node = inbound_node.flat();
|
|
|
+ }
|
|
|
for (const inbound_connection of inbound_node) {
|
|
|
let inputName = inbound_connection[0];
|
|
|
const inputNode = nodeMap.get(inputName);
|