Lutz Roeder пре 8 година
родитељ
комит
761f75afe3
3 измењених фајлова са 141 додато и 68 уклоњено
  1. 1 1
      README.md
  2. 133 67
      src/keras-model.js
  3. 7 0
      src/view-node.js

+ 1 - 1
README.md

@@ -3,7 +3,7 @@
 
 Netron is a viewer for neural network, deep learning and machine learning models. 
 
-Netron supports **[ONNX](http://onnx.ai)** (`.onnx`, `.pb`), **Keras** (`.h5`, `.keras`), **CoreML** (`.mlmodel`) and **TensorFlow Lite** (`.tflite`). Netron has experimental support for **Caffe** (`.caffemodel`), **Caffe2** (`predict_net.pb`), **MXNet** (`-symbol.json`) and **TensorFlow** (`.pb`, `.meta`).
+Netron supports **[ONNX](http://onnx.ai)** (`.onnx`, `.pb`), **Keras** (`.h5`, `.keras`), **CoreML** (`.mlmodel`) and **TensorFlow Lite** (`.tflite`). Netron has experimental support for **Caffe** (`.caffemodel`), **Caffe2** (`predict_net.pb`), **MXNet** (`-symbol.json`), **TensorFlow.js** (`model.json`, `.pb`) and **TensorFlow** (`.pb`, `.meta`).
 
 <p align='center'><a href='https://www.lutzroeder.com/ai'><img src='media/screenshot.png' width='800'></a></p>
 

+ 133 - 67
src/keras-model.js

@@ -17,26 +17,41 @@ class KerasModel {
 
     static create(buffer, identifier, host, callback) {
         try {
-            var version = null;
-            var backend = null;
-            var json = null;
+            var format = 'Keras';
             var rootGroup = null;
+            var rootJson = null;
+            var model_config = null;
 
             var extension = identifier.split('.').pop();
             if (extension == 'keras' || extension == 'h5') {
                 var file = new hdf5.File(buffer);
                 rootGroup = file.rootGroup;
-                json = rootGroup.attributes.model_config;
-                if (!json) {
+                var modelConfigJson = rootGroup.attributes.model_config;
+                if (!modelConfigJson) {
                     throw new KerasError('HDF5 file does not contain a \'model_config\' graph. Use \'save()\' instead of \'save_weights()\' to save both the graph and weights.');
                 }
+                model_config = JSON.parse(modelConfigJson);
             }
             else if (extension == 'json') {
                 var decoder = new window.TextDecoder('utf-8');
-                json = decoder.decode(buffer);
+                var json = decoder.decode(buffer);
+                model_config = JSON.parse(json);
+                if (model_config && model_config.modelTopology && model_config.modelTopology.model_config) {
+                    format = 'TensorFlow.js ' + format;
+                    rootJson = model_config;
+                    model_config = model_config.modelTopology.model_config;
+                }
             }
 
-            var model = new KerasModel(json, rootGroup);
+            if (!model_config) {
+                throw new KerasError('model_config is not present.');
+            }
+
+            if (!model_config.class_name) {
+                throw new KerasError('class_name is not present.');
+            }
+    
+            var model = new KerasModel(format, model_config, rootGroup, rootJson);
 
             KerasOperatorMetadata.open(host, (err, metadata) => {
                 callback(null, model);
@@ -47,30 +62,58 @@ class KerasModel {
         }
     }
 
-    constructor(json, rootGroup) {
-        var model = JSON.parse(json);
-        if (!model.class_name) {
-            throw new KerasError('class_name is not present.');
-        }
-        if (rootGroup && rootGroup.attributes.keras_version) {
-            this._version = rootGroup.attributes.keras_version;
+    constructor(format, model_config, rootGroup, rootJson) {
+        this._format = format;
+        this._graphs = [];
+
+        var model_weights = null;
+        var weightsManifest = null;
+        if (rootGroup) {
+            if (rootGroup.attributes.keras_version) {
+                this._version = rootGroup.attributes.keras_version;
+            }
+            if (rootGroup.attributes.backend) {
+                this._backend = rootGroup.attributes.backend;
+            }
+            model_weights = rootGroup.group('model_weights');
         }
-        if (rootGroup && rootGroup.attributes.backend) {
-            this._backend = rootGroup.attributes.backend;
+        else if (rootJson) {
+            if (rootJson.modelTopology && rootJson.modelTopology.keras_version) {
+                this._version = rootJson.modelTopology.keras_version;
+            }
+            if (rootJson.modelTopology && rootJson.modelTopology.backend) {
+                this._backend = rootJson.modelTopology.backend;
+            }
+            if (rootJson.weightsManifest) {
+                weightsManifest = {};
+                rootJson.weightsManifest.forEach((manifest) => {
+                    var match = false;
+                    var key = null;
+                    manifest.weights.forEach((weights) => {
+                        var name = weights.name.split('/').shift();
+                        if (key == null) {
+                            key = name;
+                            match = true;
+                        }
+                        else if (key != name) {
+                            match = false;
+                        }
+                    });
+                    if (match) {
+                        weightsManifest[key] = manifest;
+                    }
+                });
+            }
         }
 
-        var model_weights = rootGroup ? rootGroup.group('model_weights') : null;
-        this._activeGraph = new KerasGraph(model, model_weights);
-        this._graphs = [ this._activeGraph ];
+        this._activeGraph = new KerasGraph(model_config, model_weights, weightsManifest);
+        this._graphs.push(this._activeGraph);
     }
 
     get properties() {
         var results = [];
 
-        var format = 'Keras';
-        if (this._version) {
-            format = format + ' v' + this._version;
-        }
+        var format = this._format + (this._version ? (' v' + this._version) : '');
         results.push({ name: 'Format', value: format });
 
         if (this._backend) {
@@ -95,7 +138,7 @@ class KerasModel {
 
 class KerasGraph {
 
-    constructor(model, model_weights) {
+    constructor(model, model_weights, weightsManifest) {
         if (model.name) {
             this._name = model.name;
         }
@@ -109,10 +152,10 @@ class KerasGraph {
 
         switch (model.class_name) {
             case 'Sequential':
-                this.loadSequential(model.config, model_weights, '');
+                this.loadSequential(model.config, model_weights, weightsManifest, '');
                 break;
             case 'Model':
-                this.loadModel(model.config, model_weights, '', null, null);
+                this.loadModel(model.config, model_weights, weightsManifest, '', null, null);
                 break;
             default:
                 throw new KerasError('\'' + model.class_name + '\' is not supported.');
@@ -139,7 +182,7 @@ class KerasGraph {
         return this._nodes;
     }
 
-    loadModel(config, model_weights, group, inputs, outputs) {
+    loadModel(config, model_weights, weightsManifest, group, inputs, outputs) {
         if (group) {
             this._groups = true;
         }
@@ -238,13 +281,13 @@ class KerasGraph {
         if (config.layers) {
             config.layers.forEach((layer) => {
                 if (nodeMap[layer.name]) {
-                    this.loadNode(layer, layer._inputs, layer._outputs, model_weights, group);
+                    this.loadNode(layer, layer._inputs, layer._outputs, model_weights, weightsManifest, group);
                 }
             });
         }
     }
 
-    loadSequential(config, model_weights, group) {
+    loadSequential(config, model_weights, weightsManifest, group) {
         if (group) {
             this._groups = true;
         }
@@ -267,7 +310,7 @@ class KerasGraph {
             }
             connection = name;
             var outputs = [ connection ];
-            this.loadNode(layer, inputs, outputs, model_weights, group);
+            this.loadNode(layer, inputs, outputs, model_weights, weightsManifest, group);
         });
         this._outputs.push({ 
             id: connection,
@@ -276,24 +319,15 @@ class KerasGraph {
         });
     }
 
-    loadNode(layer, inputs, outputs, model_weights, group) {
+    loadNode(layer, inputs, outputs, model_weights, weightsManifest, group) {
         var class_name = layer.class_name;
         switch (class_name) {
             case 'Model':
-                this.loadModel(layer.config, model_weights, layer.name, inputs, outputs);
+                this.loadModel(layer.config, model_weights, weightsManifest, layer.name, inputs, outputs);
                 break;
             default:
                 var config = layer.config;
-                var weights = null;
-                if (model_weights) {
-                    if (group) {
-                        weights = model_weights.group(group);
-                    }
-                    else if (config) {
-                        weights = model_weights.group(config.name);
-                    }
-                }
-                this._nodes.push(new KerasNode(class_name, config, inputs, outputs, group, weights));
+                this._nodes.push(new KerasNode(class_name, config, inputs, outputs, group, model_weights, weightsManifest));
                 break;
         }
     }
@@ -318,7 +352,7 @@ class KerasGraph {
 
 class KerasNode {
 
-    constructor(operator, config, inputs, outputs, group, weights) {
+    constructor(operator, config, inputs, outputs, group, model_weights, weightsManifest) {
         if (group) {
             this._group = group;
         }
@@ -336,23 +370,44 @@ class KerasNode {
 
         var name = this.name;
         this._initializers = {};
-        if (weights) {
-            var weight_names = weights.attributes.weight_names;
-            if (weight_names) {
-                if (group) {
-                    weight_names = weight_names.filter(weight => weight.startsWith(name + '/'));
-                }
-                weight_names.forEach((weight_name) => {
-                    var weight_variable = weights.group(weight_name);
-                    if (weight_variable) {
-                        var variable = weight_variable.value;
-                        if (variable) {
-                            this._inputs.push(weight_name);
-                            this._initializers[weight_name] = new KerasTensor(variable);
+
+        if (model_weights) {
+            var weights = null;
+            if (group) {
+                weights = model_weights.group(group);
+            }
+            else if (config) {
+                weights = model_weights.group(config.name);
+            }
+            if (weights) {
+                var weight_names = weights.attributes.weight_names;
+                if (weight_names) {
+                    if (group) {
+                        weight_names = weight_names.filter(weight => weight.startsWith(name + '/'));
+                    }
+                    weight_names.forEach((weight_name) => {
+                        var weight_variable = weights.group(weight_name);
+                        if (weight_variable) {
+                            var variable = weight_variable.value;
+                            if (variable) {
+                                this._inputs.push(weight_name);
+                                this._initializers[weight_name] = new KerasTensor(variable.type, variable.shape, variable.rawData, '');
+                            }
                         }
+                    });
+                }
+            }
+        }
+        else if (weightsManifest) {
+            var manifest = weightsManifest[name];
+            if (manifest) {
+                manifest.weights.forEach((weights) => {
+                    if (weights.name) {
+                        this._inputs.push(weights.name);
+                        this._initializers[weights.name] = new KerasTensor(weights.dtype, weights.shape, null, manifest.paths.join(';'));
                     }
                 });
-            }
+            } 
         }
     }
 
@@ -472,22 +527,35 @@ class KerasAttribute {
 
 class KerasTensor {
 
-    constructor(variable) {
-        this._variable = variable;
+    constructor(type, shape, data, reference) {
+        this._type = type;
+        this._shape = shape;
+        this._data = data;
+        this._reference = reference;
     }
 
     get kind() {
-        return 'Initializer';
+        return 'Weights';
+    }
+
+    get name() {
+        return this._name;
     }
 
     get type() {
-        return this._variable.type + JSON.stringify(this._variable.shape);
+        return this._type + JSON.stringify(this._shape);
+    }
+
+    get reference() {
+        return this._reference;
     }
 
     get value() {
-        var rawData = this._variable.rawData;
-        if (rawData) {
-            switch (this._variable.type) {
+        if (this._reference) { 
+            return null;
+        }
+        if (this._data) {
+            switch (this._type) {
                 case 'float16':
                     this._precision = 16;
                     break;
@@ -500,15 +568,13 @@ class KerasTensor {
                 default:
                     return 'Tensor data type is not supported.';
             }
-            this._shape = this._variable.shape;
-            this._rawData = new DataView(rawData.buffer, rawData.byteOffset, rawData.byteLength);
+            this._rawData = new DataView(this._data.buffer, this._data.byteOffset, this._data.byteLength);
             this._index = 0;
             this._count = 0;
             var result = this.read(0);
             delete this._index;
             delete this._count;
             delete this._rawData;
-            delete this._shape;
             delete this._precision;
             return JSON.stringify(result, null, 4);
         }

+ 7 - 0
src/view-node.js

@@ -323,6 +323,13 @@ class NodeViewItemConnection {
                         quantizationLine.innerHTML = 'quantization: ' + '<code><b>' + quantization + '</b></code>';
                         this._element.appendChild(quantizationLine);   
                     }
+                    var reference = initializer.reference;
+                    if (reference) {
+                        var referenceLine = document.createElement('div');
+                        referenceLine.className = 'node-view-item-value-line-border';
+                        referenceLine.innerHTML = 'reference: ' + '<b>' + reference + '</b>';
+                        this._element.appendChild(referenceLine);   
+                    }
                     var value = initializer.value;
                     if (value) {
                         var valueLine = document.createElement('div');