|
|
@@ -128,9 +128,7 @@ class KerasGraph {
|
|
|
}
|
|
|
|
|
|
loadModel(root) {
|
|
|
-
|
|
|
if (root.layers) {
|
|
|
-
|
|
|
var nodeMap = {};
|
|
|
root.layers.forEach((layer) => {
|
|
|
if (layer.name) {
|
|
|
@@ -144,86 +142,140 @@ class KerasGraph {
|
|
|
root.layers.forEach((layer) => {
|
|
|
if (layer.inbound_nodes) {
|
|
|
layer.inbound_nodes.forEach((inbound_node) => {
|
|
|
- var input = { connections: [] };
|
|
|
inbound_node.forEach((inbound_connection) => {
|
|
|
+ var input = { connections: [] };
|
|
|
var inputName = inbound_connection[0];
|
|
|
input.connections.push({ id: inputName });
|
|
|
var inputNode = nodeMap[inputName];
|
|
|
if (inputNode) {
|
|
|
- inputNode._outputs.push(inputNode.name);
|
|
|
+ inputNode._outputs.push({
|
|
|
+ connections: [ { id: inputNode.name } ]
|
|
|
+ });
|
|
|
}
|
|
|
+ layer._inputs.push(input);
|
|
|
});
|
|
|
- layer._inputs.push(input);
|
|
|
});
|
|
|
}
|
|
|
});
|
|
|
}
|
|
|
-
|
|
|
- /*
|
|
|
if (root.input_layers) {
|
|
|
root.input_layers.forEach((input_layer) => {
|
|
|
- this._inputs.push({ id: input_layer[0], name: input_layer[0] });
|
|
|
- });
|
|
|
+ var name = input_layer[0];
|
|
|
+ var input = {
|
|
|
+ id: name,
|
|
|
+ name: name
|
|
|
+ };
|
|
|
+ var node = nodeMap[name];
|
|
|
+ if (node && node.class_name == 'InputLayer') {
|
|
|
+ this.translateInput(node, input);
|
|
|
+ delete nodeMap[name];
|
|
|
+ }
|
|
|
+ this._inputs.push(input);
|
|
|
+ });
|
|
|
}
|
|
|
- */
|
|
|
-
|
|
|
if (root.output_layers) {
|
|
|
root.output_layers.forEach((output_layer) => {
|
|
|
var inputName = output_layer[0];
|
|
|
var inputNode = nodeMap[inputName];
|
|
|
if (inputNode) {
|
|
|
- inputNode._outputs.push(inputName);
|
|
|
+ inputNode._outputs.push({
|
|
|
+ connections: [ { id: inputName } ]
|
|
|
+ });
|
|
|
}
|
|
|
- this._outputs.push({ id: inputName, name: inputName, type: '?' });
|
|
|
+ var output = {
|
|
|
+ id: inputName,
|
|
|
+ name: inputName,
|
|
|
+ type: '?'
|
|
|
+ };
|
|
|
+ this._outputs.push(output);
|
|
|
});
|
|
|
}
|
|
|
-
|
|
|
if (root.layers) {
|
|
|
root.layers.forEach((layer) => {
|
|
|
- var node = new KerasNode(layer.class_name, layer.name, layer.config, layer._inputs, layer._outputs);
|
|
|
- this._nodes.push(node);
|
|
|
+ if (nodeMap[layer.name]) {
|
|
|
+ this.translateNode(layer.name, layer, layer._inputs, layer._outputs).forEach((node) => {
|
|
|
+ this._nodes.push(node);
|
|
|
+ });
|
|
|
+ }
|
|
|
});
|
|
|
}
|
|
|
}
|
|
|
|
|
|
loadSequential(root) {
|
|
|
- var output = 'input';
|
|
|
-
|
|
|
- this._inputs.push({
|
|
|
- name: output,
|
|
|
- id: output,
|
|
|
- type: '?'
|
|
|
- });
|
|
|
-
|
|
|
+ var connection = 'input';
|
|
|
+ var input = {
|
|
|
+ id: connection,
|
|
|
+ name: connection
|
|
|
+ };
|
|
|
+ this._inputs.push(input);
|
|
|
var id = 0;
|
|
|
root.forEach((layer) => {
|
|
|
- var inputs = [];
|
|
|
- if (output) {
|
|
|
- inputs.push({
|
|
|
- name: '(0)',
|
|
|
- connections: [ { id: output }]
|
|
|
- });
|
|
|
- }
|
|
|
-
|
|
|
+ var inputs = [ {
|
|
|
+ connections: [ { id: connection } ]
|
|
|
+ } ];
|
|
|
var name = id.toString();
|
|
|
- if (layer.config || layer.config.name) {
|
|
|
- name = layer.config.name;
|
|
|
+ if (id == 0) {
|
|
|
+ this.translateInput(layer, input);
|
|
|
}
|
|
|
id++;
|
|
|
- output = name;
|
|
|
-
|
|
|
- var outputs = [ output ];
|
|
|
-
|
|
|
- var node = new KerasNode(layer.class_name, name, layer.config, inputs, outputs);
|
|
|
- this._nodes.push(node);
|
|
|
+ if (layer.config && layer.config.name) {
|
|
|
+ name = layer.config.name;
|
|
|
+ }
|
|
|
+ connection = name;
|
|
|
+ var outputs = [ {
|
|
|
+ connections: [ { id: connection } ]
|
|
|
+ } ];
|
|
|
+ this.translateNode(name, layer, inputs, outputs).forEach((node) => {
|
|
|
+ this._nodes.push(node);
|
|
|
+ });
|
|
|
});
|
|
|
-
|
|
|
this._outputs.push({
|
|
|
name: 'output',
|
|
|
- id: output,
|
|
|
+ id: connection,
|
|
|
type: '?'
|
|
|
});
|
|
|
}
|
|
|
+
|
|
|
+ translateNode(name, layer, inputs, outputs) {
|
|
|
+ var results = [];
|
|
|
+ if (layer.class_name == 'Bidirectional' || layer.class_name == 'TimeDistributed') {
|
|
|
+ if (layer.config.layer) {
|
|
|
+ var subLayer = layer.config.layer;
|
|
|
+ var subConnection = name + '|' + layer;
|
|
|
+ inputs.push({
|
|
|
+ name: 'layer',
|
|
|
+ connections: [ { id: subConnection} ]
|
|
|
+ });
|
|
|
+ var subOutputs = [ {
|
|
|
+ connections: [ { id: subConnection } ]
|
|
|
+ } ];
|
|
|
+ results.push(new KerasNode(subLayer.class_name, subLayer.config.name, subLayer.config, [], subOutputs));
|
|
|
+ delete layer.config.layer;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ var node = new KerasNode(layer.class_name, name, layer.config, inputs, outputs);
|
|
|
+ results.push(node);
|
|
|
+ return results;
|
|
|
+ }
|
|
|
+
|
|
|
+ translateInput(layer, input) {
|
|
|
+ input.type = '';
|
|
|
+ if (layer && layer.config) {
|
|
|
+ var config = layer.config;
|
|
|
+ if (config.dtype) {
|
|
|
+ input.type = config.dtype;
|
|
|
+ delete config.dtype;
|
|
|
+ }
|
|
|
+ if (config.batch_input_shape) {
|
|
|
+ var shape = config.batch_input_shape;
|
|
|
+ if (shape.length > 0 && shape[0] == null) {
|
|
|
+ shape.shift();
|
|
|
+ }
|
|
|
+ input.type = input.type + '[' + shape.toString() + ']';
|
|
|
+ delete config.batch_input_shape;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
class KerasNode {
|
|
|
@@ -251,8 +303,8 @@ class KerasNode {
|
|
|
get inputs() {
|
|
|
var results = [];
|
|
|
this._inputs.forEach((input, index) => {
|
|
|
- results.push({
|
|
|
- name: '(' + index.toString() + ')',
|
|
|
+ results.push({
|
|
|
+ name: input.name ? input.name : '(' + index.toString() + ')',
|
|
|
connections: input.connections
|
|
|
});
|
|
|
});
|
|
|
@@ -263,8 +315,8 @@ class KerasNode {
|
|
|
var results = [];
|
|
|
this._outputs.forEach((output, index) => {
|
|
|
results.push({
|
|
|
- name: '(' + index.toString() + ')',
|
|
|
- connections: [ { id: output }]
|
|
|
+ name: output.name ? output.name : '(' + index.toString() + ')',
|
|
|
+ connections: output.connections
|
|
|
});
|
|
|
});
|
|
|
return results;
|
|
|
@@ -302,12 +354,21 @@ class KerasAttribute {
|
|
|
}
|
|
|
|
|
|
get value() {
|
|
|
- if (this._value == true) {
|
|
|
+ if (this._value === true) {
|
|
|
return 'true';
|
|
|
}
|
|
|
- if (this._value == false) {
|
|
|
+ if (this._value === false) {
|
|
|
return 'false';
|
|
|
}
|
|
|
+ if (this._value === null) {
|
|
|
+ return 'null';
|
|
|
+ }
|
|
|
+ if (typeof this._value == 'object' && this._value.class_name && this._value.config) {
|
|
|
+ return this._value.class_name + '(' + Object.keys(this._value.config).map(key => {
|
|
|
+ var value = this._value.config[key];
|
|
|
+ return key + '=' + JSON.stringify(value);
|
|
|
+ }).join(', ') + ')';
|
|
|
+ }
|
|
|
if (this._value) {
|
|
|
return JSON.stringify(this._value);
|
|
|
}
|
|
|
@@ -346,27 +407,6 @@ class KerasOperatorMetadata {
|
|
|
}
|
|
|
});
|
|
|
}
|
|
|
-
|
|
|
- this._categoryMap = {
|
|
|
- 'Conv1D': 'Layer',
|
|
|
- 'Conv2D': 'Layer',
|
|
|
- 'Conv3D': 'Layer',
|
|
|
- 'Convolution1D': 'Layer',
|
|
|
- 'Convolution2D': 'Layer',
|
|
|
- 'Convolution3D': 'Layer',
|
|
|
- 'DepthwiseConv2D': 'Layer',
|
|
|
- 'Dense': 'Layer',
|
|
|
- 'BatchNormalization': 'Normalization',
|
|
|
- 'Concatenate': 'Tensor',
|
|
|
- 'Activation': 'Activation',
|
|
|
- 'GlobalAveragePooling2D': 'Pool',
|
|
|
- 'AveragePooling2D': 'Pool',
|
|
|
- 'MaxPooling2D': 'Layer',
|
|
|
- 'GlobalMaxPooling2D': 'Layer',
|
|
|
- 'Flatten': 'Shape',
|
|
|
- 'Reshape': 'Shape',
|
|
|
- 'Dropout': 'Dropout'
|
|
|
- };
|
|
|
}
|
|
|
|
|
|
showAttribute(operator, attributeName, attributeValue) {
|
|
|
@@ -397,9 +437,12 @@ class KerasOperatorMetadata {
|
|
|
}
|
|
|
|
|
|
getOperatorCategory(operator) {
|
|
|
- var category = this._categoryMap[operator];
|
|
|
- if (category) {
|
|
|
- return category;
|
|
|
+ var schema = this._map[operator];
|
|
|
+ if (schema) {
|
|
|
+ var category = schema.category;
|
|
|
+ if (category) {
|
|
|
+ return category;
|
|
|
+ }
|
|
|
}
|
|
|
return null;
|
|
|
}
|