소스 검색

Keras inputs, wrappers and categories

Lutz Roeder 8 년 전
부모
커밋
bc3aed9c8c
4개의 변경된 파일254개의 추가작업 그리고 76개의 파일을 삭제
  1. 115 72
      src/keras-model.js
  2. 133 1
      src/keras-operator.json
  3. 6 2
      src/view-render.css
  4. 0 1
      src/view.js

+ 115 - 72
src/keras-model.js

@@ -128,9 +128,7 @@ class KerasGraph {
     }
 
     loadModel(root) {
-
         if (root.layers) {
-
             var nodeMap = {};
             root.layers.forEach((layer) => {
                 if (layer.name) {
@@ -144,86 +142,140 @@ class KerasGraph {
             root.layers.forEach((layer) => {
                 if (layer.inbound_nodes) {
                     layer.inbound_nodes.forEach((inbound_node) => {
-                        var input = { connections: [] };
                         inbound_node.forEach((inbound_connection) => {
+                            var input = { connections: [] };
                             var inputName = inbound_connection[0];
                             input.connections.push({ id: inputName });
                             var inputNode = nodeMap[inputName];
                             if (inputNode) {
-                                inputNode._outputs.push(inputNode.name);
+                                inputNode._outputs.push({
+                                    connections: [ { id: inputNode.name } ]
+                                });
                             }
+                            layer._inputs.push(input);
                         });       
-                        layer._inputs.push(input);
                     });
                 }
             });
         }
-
-        /*
         if (root.input_layers) {
             root.input_layers.forEach((input_layer) => {
-                this._inputs.push({ id: input_layer[0], name: input_layer[0] });
-            });    
+                var name = input_layer[0];
+                var input = {
+                    id: name,
+                    name: name
+                };
+                var node = nodeMap[name];
+                if (node && node.class_name == 'InputLayer') {
+                    this.translateInput(node, input);
+                    delete nodeMap[name];
+                }
+                this._inputs.push(input); 
+            });
         }
-        */
-
         if (root.output_layers) {
             root.output_layers.forEach((output_layer) => {
                 var inputName = output_layer[0];
                 var inputNode = nodeMap[inputName];
                 if (inputNode) {
-                    inputNode._outputs.push(inputName);
+                    inputNode._outputs.push({
+                        connections: [ { id: inputName } ]                        
+                    });
                 }
-                this._outputs.push({ id: inputName, name: inputName, type: '?' });
+                var output = {
+                    id: inputName,
+                    name: inputName,
+                    type: '?'
+                };
+                this._outputs.push(output);
             });
         }
-
         if (root.layers) {
             root.layers.forEach((layer) => {
-                var node = new KerasNode(layer.class_name, layer.name, layer.config, layer._inputs, layer._outputs);
-                this._nodes.push(node);
+                if (nodeMap[layer.name]) {
+                    this.translateNode(layer.name, layer, layer._inputs, layer._outputs).forEach((node) => {
+                        this._nodes.push(node);
+                    });
+                }
             });
         }
     }
 
     loadSequential(root) {
-        var output = 'input';
-
-        this._inputs.push({
-            name: output,
-            id: output,
-            type: '?'
-        });
-
+        var connection = 'input';
+        var input = {
+            id: connection,
+            name: connection
+        };
+        this._inputs.push(input);
         var id = 0;
         root.forEach((layer) => {
-            var inputs = [];
-            if (output) {
-                inputs.push({
-                    name: '(0)',
-                    connections: [ { id: output }]
-                });
-            }
-
+            var inputs = [ {
+                connections: [ { id: connection } ]
+            } ];
             var name = id.toString();
-            if (layer.config || layer.config.name) {
-                name = layer.config.name;
+            if (id == 0) {
+                this.translateInput(layer, input);
             }
             id++;
-            output = name;
-
-            var outputs = [ output ];
-
-            var node = new KerasNode(layer.class_name, name, layer.config, inputs, outputs);
-            this._nodes.push(node);
+            if (layer.config && layer.config.name) {
+                name = layer.config.name;
+            }
+            connection = name;
+            var outputs = [ {
+                connections: [ { id: connection } ]
+            } ];
+            this.translateNode(name, layer, inputs, outputs).forEach((node) => {
+                this._nodes.push(node);
+            });
         });
-
         this._outputs.push({
             name: 'output',
-            id: output,
+            id: connection,
             type: '?'
         });
     }
+
+    translateNode(name, layer, inputs, outputs) {
+        var results = [];
+        if (layer.class_name == 'Bidirectional' || layer.class_name == 'TimeDistributed') {
+            if (layer.config.layer) {
+                var subLayer = layer.config.layer;
+                var subConnection = name + '|' + layer;
+                inputs.push({
+                    name: 'layer',
+                    connections: [ { id: subConnection} ]
+                });
+                var subOutputs = [ {
+                    connections: [ { id: subConnection } ]
+                } ];
+                results.push(new KerasNode(subLayer.class_name, subLayer.config.name, subLayer.config, [], subOutputs));
+                delete layer.config.layer;
+            }
+        }        
+        var node = new KerasNode(layer.class_name, name, layer.config, inputs, outputs);
+        results.push(node);
+        return results;
+    }
+
+    translateInput(layer, input) {
+        input.type = '';
+        if (layer && layer.config) {
+            var config = layer.config;
+            if (config.dtype) {
+                input.type = config.dtype;
+                delete config.dtype;
+            }
+            if (config.batch_input_shape) {
+                var shape = config.batch_input_shape;
+                if (shape.length > 0 && shape[0] == null) {
+                    shape.shift();
+                }
+                input.type = input.type + '[' + shape.toString() + ']';
+                delete config.batch_input_shape;
+            }
+        }
+    }
 }
 
 class KerasNode {
@@ -251,8 +303,8 @@ class KerasNode {
     get inputs() {
         var results = [];
         this._inputs.forEach((input, index) => {
-            results.push({ 
-                name: '(' + index.toString() + ')', 
+            results.push({
+                name: input.name ? input.name : '(' + index.toString() + ')', 
                 connections: input.connections
             });
         });
@@ -263,8 +315,8 @@ class KerasNode {
         var results = [];
         this._outputs.forEach((output, index) => {
             results.push({ 
-                name: '(' + index.toString() + ')', 
-                connections: [ { id: output }]
+                name: output.name ? output.name : '(' + index.toString() + ')', 
+                connections: output.connections
             });
         });
         return results;
@@ -302,12 +354,21 @@ class KerasAttribute {
     }
 
     get value() {
-        if (this._value == true) {
+        if (this._value === true) {
             return 'true';
         }
-        if (this._value == false) {
+        if (this._value === false) {
             return 'false';
         }
+        if (this._value === null) {
+            return 'null';
+        }
+        if (typeof this._value == 'object' && this._value.class_name && this._value.config) {
+            return this._value.class_name + '(' + Object.keys(this._value.config).map(key => {
+                var value = this._value.config[key];
+                return key + '=' + JSON.stringify(value);
+            }).join(', ') + ')';
+        }
         if (this._value) {
             return JSON.stringify(this._value);
         }
@@ -346,27 +407,6 @@ class KerasOperatorMetadata {
                 }
             });
         }
-
-        this._categoryMap = {
-            'Conv1D': 'Layer',
-            'Conv2D': 'Layer',
-            'Conv3D': 'Layer',
-            'Convolution1D': 'Layer',
-            'Convolution2D': 'Layer',
-            'Convolution3D': 'Layer',
-            'DepthwiseConv2D': 'Layer',
-            'Dense': 'Layer',
-            'BatchNormalization': 'Normalization',
-            'Concatenate': 'Tensor',
-            'Activation': 'Activation',
-            'GlobalAveragePooling2D': 'Pool',
-            'AveragePooling2D': 'Pool',
-            'MaxPooling2D': 'Layer',
-            'GlobalMaxPooling2D': 'Layer',
-            'Flatten': 'Shape',
-            'Reshape': 'Shape',
-            'Dropout': 'Dropout'
-        };    
     }
 
     showAttribute(operator, attributeName, attributeValue) {
@@ -397,9 +437,12 @@ class KerasOperatorMetadata {
     }
 
     getOperatorCategory(operator) {
-        var category = this._categoryMap[operator];
-        if (category) {
-            return category;
+        var schema = this._map[operator];
+        if (schema) {
+            var category = schema.category;
+            if (category) {
+                return category;
+            }
         }
         return null;
     }

+ 133 - 1
src/keras-operator.json

@@ -2,14 +2,24 @@
   {
     "name": "Bidirectional",
     "schema": {
+      "category": "Wrapper",
       "attributes": [
         { "name": "merge_mode", "default": "concat" }
       ]
     }
   },
+  {
+    "name": "TimeDistributed",
+    "schema": {
+      "category": "Wrapper",
+      "attributes": [
+      ]
+    }
+  },
   {
     "name": "Activation",
     "schema": {
+      "category": "Activation",
       "attributes": [
       ]
     }
@@ -17,6 +27,7 @@
   {
     "name": "MaxPooling2D",
     "schema": {
+      "category": "Pool",
       "attributes": [
         { "name": "data_format", "default": "channels_last" },
         { "name": "padding", "default": "valid" },
@@ -26,8 +37,9 @@
     }
   },
   {
-    "name": "MaxPooling2D",
+    "name": "UpSampling2D",
     "schema": {
+      "category": "Layer",
       "attributes": [
         { "name": "data_format", "default": "channels_last" }
       ]
@@ -36,6 +48,25 @@
   {
     "name": "GlobalMaxPooling2D",
     "schema": {
+      "category": "Pool",
+      "attributes": [
+        { "name": "data_format", "default": "channels_last" }
+      ]
+    }
+  },
+  {
+    "name": "GlobalAveragePooling2D",
+    "schema": {
+      "category": "Pool",
+      "attributes": [
+        { "name": "data_format", "default": "channels_last" }
+      ]
+    }
+  },
+  {
+    "name": "AveragePooling2D",
+    "schema": {
+      "category": "Pool",
       "attributes": [
         { "name": "data_format", "default": "channels_last" }
       ]
@@ -44,6 +75,7 @@
   {
     "name": "BatchNormalization",
     "schema": {
+      "category": "Normalization",
       "attributes": [
         { "name": "axis", "default": -1 },
         { "name": "epsilon", "default": 1e-3 },
@@ -60,6 +92,7 @@
   {
     "name": "Dense",
     "schema": {
+      "category": "Layer",
       "attributes": [
         { "name": "activation", "default": "linear" },
         { "name": "use_bias", "default": true },
@@ -68,9 +101,83 @@
       ]
     }
   },
+  {
+    "name": "LSTM",
+    "schema": {
+      "category": "Layer",
+      "attributes": [
+        { "name": "return_sequences", "default": false },
+        { "name": "return_state", "default": false },
+        { "name": "activation", "default": "tanh" },
+        { "name": "recurrent_activation", "default": "hard_sigmoid" },
+        { "name": "use_bias", "default": true },
+        { "name": "bias_initializer", "default": { "class_name": "Zeros", "config": {} } },
+        { "name": "unit_forget_bias", "default": true },
+        { "name": "dropout", "default": 0.0 },
+        { "name": "recurrent_dropout", "default": 0.0 },
+        { "name": "implementation", "default": 1 },
+        { "name": "unroll", "default": false },
+        { "name": "stateful", "default": false },
+        { "name": "go_backwards", "default": false },
+        { "name": "kernel_initializer", "default": { "class_name": "VarianceScaling", "config": { "distribution": "uniform", "scale": 1, "seed": null, "mode": "fan_avg" } } },
+        { "name": "recurrent_initializer", "default": { "class_name": "Orthogonal", "config": { "seed": null, "gain": 1 } } }
+      ]
+    }
+  },
+  {
+    "name": "GRU",
+    "schema": {
+      "category": "Layer",
+      "attributes": [
+        { "name": "activation", "default": "tanh" },
+        { "name": "recurrent_activation", "default": "hard_sigmoid" },
+        { "name": "use_bias", "default": true },
+        { "name": "kernel_initializer", "default": { "class_name": "VarianceScaling", "config": { "distribution": "uniform", "scale": 1, "seed": null, "mode": "fan_avg" } } },
+        { "name": "recurrent_initializer", "default": { "class_name": "Orthogonal", "config": { "seed": null, "gain": 1 } } },
+        { "name": "bias_initializer", "default": { "class_name": "Zeros", "config": {} } },
+        { "name": "dropout", "default": 0.0 },
+        { "name": "implementation", "default": 1 },
+        { "name": "return_sequences", "default": false },
+        { "name": "return_state", "default": false },
+        { "name": "go_backwards", "default": false },
+        { "name": "stateful", "default": false },
+        { "name": "unroll", "default": false }
+      ]
+    }
+  },
+  {
+    "name": "RNN",
+    "schema": {
+      "category": "Layer",
+      "attributes": [
+        { "name": "return_sequences", "default": false },
+        { "name": "return_state", "default": false },
+        { "name": "go_backwards", "default": false },
+        { "name": "stateful", "default": false },
+        { "name": "unroll", "default": false }
+      ]
+    }
+  },
   {
     "name": "Conv2D",
     "schema": {
+      "category": "Layer",
+      "attributes": [
+        { "name": "activation", "default": "linear" },
+        { "name": "padding", "default": "valid" },
+        { "name": "use_bias", "default": true },
+        { "name": "data_format", "default": "channels_last" },
+        { "name": "strides", "default": [1, 1] },
+        { "name": "dilation_rate", "default": [1, 1] },
+        { "name": "bias_initializer", "default": { "class_name": "Zeros", "config": {} } },
+        { "name": "kernel_initializer", "default": { "class_name": "VarianceScaling", "config": {"distribution": "uniform", "scale": 1, "seed": null, "mode": "fan_avg" } } }
+      ]
+    }
+  },
+  {
+    "name": "Convolution2D",
+    "schema": {
+      "category": "Layer",
       "attributes": [
         { "name": "activation", "default": "linear" },
         { "name": "padding", "default": "valid" },
@@ -86,6 +193,7 @@
   {
     "name": "DepthwiseConv2D",
     "schema": {
+      "category": "Layer",
       "attributes": [
         { "name": "activation", "default": "linear" },
         { "name": "padding", "default": "valid" },
@@ -97,5 +205,29 @@
         { "name": "depthwise_initializer", "default": { "class_name": "VarianceScaling", "config": {"distribution": "uniform", "scale": 1, "seed": null, "mode": "fan_avg" } } }
       ]
     }
+  },
+  {
+    "name": "Concatenate",
+    "schema": {
+      "category": "Tensor"
+    }
+  },
+  {
+    "name": "Flatten",
+    "schema": {
+      "category": "Shape"
+    }
+  },
+  {
+    "name": "Reshape",
+    "schema": {
+      "category": "Shape"
+    }
+  },
+  {
+    "name": "Dropout",
+    "schema": {
+      "category": "Dropout"
+    }
   }
 ]

+ 6 - 2
src/view-render.css

@@ -16,15 +16,19 @@
 .node-item-operator-constant path { fill: #eee; }
 .node-item-operator-constant text { fill: #000; }
 .node-item-operator-constant:hover path { fill: #fff; }
+
+.node-item-operator-control path { fill: #eee; }
+.node-item-operator-control text { fill: #000; }
+.node-item-operator-control:hover path { fill: #fff; }
+
 .node-item-operator-layer path { fill: #358; }
+.node-item-operator-wrapper path { fill: #457; }
 .node-item-operator-activation path { fill: #4B1B16; }
 .node-item-operator-pool path { fill: #353; }
 .node-item-operator-normalization path { fill: #354; }
 .node-item-operator-dropout path { fill: #454770; }
 .node-item-operator-shape path { fill: #6C4F47; }
 .node-item-operator-tensor path { fill: #59423B; }
-.node-item-operator-control path { fill: #eee; }
-.node-item-operator-control text { fill: #000; }
 
 .node-item-input path { fill: #fff; }
 .node-item-input:hover { cursor: hand; }

+ 0 - 1
src/view.js

@@ -257,7 +257,6 @@ function updateGraph(model) {
         }
         tuple.from = { 
             node: nodeId,
-            // name: valueInfo.name
         };
 
         var formatter = new NodeFormatter();