Browse Source

Fix Keras weights reader (#428)

Lutz Roeder 6 năm trước cách đây
mục cha
commit
25890d970a
4 tập tin đã thay đổi với 114 bổ sung và 85 xóa
  1. 2 1
      src/keras-metadata.json
  2. 97 82
      src/keras.js
  3. 15 1
      test/models.json
  4. 0 1
      tools/keras-script.py

+ 2 - 1
src/keras-metadata.json

@@ -3777,7 +3777,8 @@
       "inputs": [
         {
           "description": "\nArbitrary. Use the keyword argument input_shape\n(tuple of integers, does not include the samples axis)\nwhen using this layer as the first layer in a model.\n",
-          "name": "input"
+          "name": "inputs",
+          "option": "variadic"
         }
       ],
       "outputs": [

+ 97 - 82
src/keras.js

@@ -132,7 +132,7 @@ keras.Model = class {
         this._producer = producer;
         this._graphs = [];
 
-        let weights = {};
+        let weights = new keras.Weights();
         if (rootGroup) {
             let model_weights_group = rootGroup.group('model_weights');
             if (!model_weights_group && rootGroup.attribute('layer_names')) {
@@ -140,51 +140,24 @@ keras.Model = class {
             }
             if (model_weights_group) {
                 model_weights_group = new keras.Group(model_weights_group);
-                let layer_names = model_weights_group.attribute('layer_names');
-                let layer_names_map = new Set();
-                for (const layer_name of layer_names) {
-                    layer_names_map.add(layer_name);
-                }
-                for (const layer_name of layer_names) {
-                    let layer_weights = model_weights_group.group(layer_name);
+                for (const layer_name of model_weights_group.attribute('layer_names')) {
+                    const layer_weights = model_weights_group.group(layer_name);
                     if (layer_weights) {
-                        let weight_names = layer_weights.attribute('weight_names');
-                        if (layer_weights && weight_names && weight_names.length > 0) {
+                        const weight_names = layer_weights.attribute('weight_names');
+                        if (weight_names && weight_names.length > 0) {
                             for (let weight_name of weight_names) {
-                                let group = layer_weights.group(weight_name);
-                                if (group) {
-                                    let variable = group.value;
-                                    if (variable) {
-                                        if (model_config) {
-                                            let initializer = new keras.Tensor(weight_name, variable.type, variable.shape, variable.littleEndian, variable.data, '');
-                                            let parts = weight_name.split('/');
-                                            parts.pop();
-                                            let match = false;
-                                            while (parts.length > 0) {
-                                                let name = parts.join('/');
-                                                if (layer_names_map.has(name)) {
-                                                    match = true;
-                                                }
-                                                weights[name] = weights[name] || [];
-                                                weights[name].push(initializer);
-                                                parts.shift();
-                                            }
-                                            if (!match) {
-                                                weights[layer_name] = weights[layer_name] || [];
-                                                weights[layer_name].push(initializer);
-                                            }
-                                        }
-                                        else {
-                                            if (!weight_name.startsWith(layer_name + '/')) {
-                                                weight_name = layer_name + '/' + weight_name; 
-                                            }
-                                            let initializer = new keras.Tensor(weight_name, variable.type, variable.shape, variable.littleEndian, variable.data, '');
-                                            let parts = weight_name.split('/');
-                                            parts.pop();
-                                            let name = parts.join('/');
-                                            weights[name] = weights[name] || [];
-                                            weights[name].push(initializer);
-                                        }
+                                const weight = layer_weights.group(weight_name);
+                                if (weight && weight.value) {
+                                    const variable = weight.value;
+                                    const tensor = new keras.Tensor(weight_name, variable.type, variable.shape, variable.littleEndian, variable.data, '');
+                                    if (model_config) {
+                                        weights.add(layer_name, tensor);
+                                    }
+                                    else {
+                                        const components = weight_name.split('/');
+                                        components.pop();
+                                        const name = (components.length == 0 || components[0] !== layer_name) ? [ layer_name ].concat(components).join('/') : components.join('/');
+                                        weights.add(name, tensor);
                                     }
                                 }
                             }
@@ -196,15 +169,8 @@ keras.Model = class {
         else if (weightsManifest) {
             for (const manifest of weightsManifest) {
                 for (const weight of manifest.weights) {
-                    let p = weight.name.split('/');
-                    p.pop();
-                    let initializer = new keras.Tensor(weight.name, weight.dtype, weight.shape, false, null, manifest.paths.join(';'));
-                    while (p.length > 0) {
-                        let weightName = p.join('/');
-                        weights[weightName] = weights[weightName] || [];
-                        weights[weightName].push(initializer);
-                        p.shift();
-                    }
+                    const tensor = new keras.Tensor(weight.name, weight.dtype, weight.shape, false, null, manifest.paths.join(';'));
+                    weights.add('', tensor);
                 }
             }
         }
@@ -262,9 +228,9 @@ keras.Graph = class {
             }
         }
         else if (weights) {
-            for (const layer of Object.keys(weights)) {
-                if (weights[layer].length <= 6) {
-                    const node = new keras.Node(metadata, 'Weights', { name: layer }, [], [], false, weights);
+            for (const layer of weights.keys()) {
+                if (weights.get('', layer).length <= 6) {
+                    const node = new keras.Node(metadata, 'Weights', { name: layer }, [], [], '', weights);
                     this._nodes.push(node)
                 }
             }
@@ -401,8 +367,7 @@ keras.Graph = class {
         let inputType = null;
         let argument = inputName;
         let index = 0;
-        let layers = config.layers ? config.layers : config;
-
+        const layers = config.layers ? config.layers : config;
         for (const layer of layers) {
             let name = index.toString();
             let nodeInputs = [ argument ];
@@ -438,14 +403,16 @@ keras.Graph = class {
     }
 
     _loadNode(layer, inputs, outputs, weights, group, inputMap) {
-        let class_name = layer.class_name;
+        const class_name = layer.class_name;
         switch (class_name) {
             case 'Sequential': {
-                this._loadSequential(layer.config, weights, layer.name, inputs, outputs);
+                const name = layer.name || (layer.config ? layer.config.name : '')
+                this._loadSequential(layer.config, weights, (group ? group + '/' : '') + name, inputs, outputs);
                 break;
             }
             case 'Model': {
-                this._loadModel(layer.config, weights, layer.name, inputs, outputs);
+                const name = layer.name || (layer.config ? layer.config.name : '')
+                this._loadModel(layer.config, weights, (group ? group + '/' : '') + name, inputs, outputs);
                 break;
             }
             default: {
@@ -524,43 +491,43 @@ keras.Argument = class {
 keras.Node = class {
 
     constructor(metadata, operator, config, inputs, outputs, group, weights) {
-        if (group) {
-            this._group = group;
-        }
+        this._group = group || '';
         this._metadata = metadata;
         this._operator = operator;
-        this._name = (config && config.name) ? config.name : '';
+        const name = config && config.name ? config.name : '';
+        this._name = (this._group ? this._group + '/' : '') + name;
         this._inputs = [];
         this._outputs = [];
         this._attributes = [];
 
-        let names = [ this._name ];
+        let names = [ name ];
         if ((operator == 'Bidirectional' || operator == 'TimeDistributed') && (config && config.layer)) {
             let inner = config.layer;
             delete config.layer;
             this._inner = new keras.Node(this._metadata, inner.class_name, inner.config, [], [], null, null);
             if (operator == 'Bidirectional' && inner.config.name) {
-                names = [ this._name + '/forward_' + inner.config.name, this._name + '/backward_' + inner.config.name ];
+                names = [ name + '/forward_' + inner.config.name, name + '/backward_' + inner.config.name ];
+                if (!group) {
+                    group = name;
+                }
             }
         }
 
         let initializers = {};
         if (weights) {
             for (const name of names) {
-                if (weights[name]) {
-                    for (const initializer of weights[name]) {
-                        inputs.push(initializer.name);
-                        initializers[initializer.name] = initializer;
-                    }
+                for (const initializer of weights.get(group, name)) {
+                    inputs.push(initializer.name);
+                    initializers[initializer.name] = initializer;
                 }
             }
         }
 
         if (config) {
-            for (const attributeName of Object.keys(config)) {
-                const attributeValue = config[attributeName];
-                if (attributeName != 'name' && attributeValue != null) {
-                    this._attributes.push(new keras.Attribute(this._metadata, this.operator, attributeName, attributeValue));
+            for (const name of Object.keys(config)) {
+                const value = config[name];
+                if (name != 'name' && value != null) {
+                    this._attributes.push(new keras.Attribute(this._metadata, this.operator, name, value));
                 }
             }
         }
@@ -612,7 +579,7 @@ keras.Node = class {
                         break;
                 }
             }
-            const input = !variadic ? [ inputs.shift() ] : inputs.slice(0, inputs.length);
+            const input = !variadic ? [ inputs.shift() ] : inputs.splice(0, inputs.length);
             const inputArguments = input.map((id) => {
                 return new keras.Argument(id, null, initializers[id]);
             });
@@ -640,16 +607,16 @@ keras.Node = class {
         return this._operator;
     }
 
+    get metadata() {
+        return this._metadata.type(this._operator);
+    }
+
     get name() {
         return this._name;
     }
 
     get group() {
-        return this._group ? this._group : '';
-    }
-
-    get metadata() {
-        return this._metadata.type(this._operator);
+        return this._group;
     }
 
     get inputs() {
@@ -1278,6 +1245,54 @@ keras.JsonParser = class {
     }
 }
 
+keras.Weights = class { // Collects tensors into lists keyed by the name passed to add().
+
+    constructor() {
+        this._map = new Map(); // key -> array of tensors, appended in add() order
+    }
+
+    add(layer_name, tensor) { // Appends tensor to the list for layer_name, creating the list on first use.
+        if (!this._map.has(layer_name)) {
+            this._map.set(layer_name, []);
+        }
+        this._map.get(layer_name).push(tensor);
+    }
+
+    get(group, name) { // Returns the array of matching tensors, or [] when nothing matches.
+        if (group) {
+            const list = this._map.get(group.split('/').shift()); // looked up by the first '/'-separated component of group
+            if (list) {
+                const match1 = list.filter((tensor) => tensor.name.startsWith(name + '/')); // first try: tensor names prefixed with name + '/'
+                if (match1.length > 0) {
+                    return match1;
+                }
+                const match2 = list.filter((tensor) => tensor.name.startsWith(group + '/' + name + '/')); // second try: names prefixed with group + '/' + name + '/'
+                if (match2.length > 0) {
+                    return match2;
+                }
+            }
+        }
+        else {
+            const match1 = this._map.get(name); // no group: exact key lookup comes first
+            if (match1 && match1.length > 0) {
+                return match1;
+            }
+            const match2 = this._map.get(''); // fallback: tensors that were added under the empty key
+            if (match2 && match2.length > 0) {
+                const match3 = match2.filter((tensor) => tensor.name.startsWith((group ? group + '/' : '') + name + '/')); // NOTE(review): group is falsy in this branch, so the prefix is always name + '/'
+                if (match3.length > 0) {
+                    return match3;
+                }
+            }
+        }
+        return [];
+    }
+
+    keys() { // Returns the underlying Map's key iterator (the names passed to add()).
+        return this._map.keys();
+    }
+}
+
 keras.Error = class extends Error {
 
     constructor(message) {

+ 15 - 1
test/models.json

@@ -2182,6 +2182,20 @@
     "link":   "https://keras.io/applications",
     "script": "./tools/keras sync install zoo"
   },
+  {
+    "type":   "keras",
+    "target": "nested_bidrectional.h5",
+    "source": "https://github.com/lutzroeder/netron/files/4304644/nested_bidrectional.zip[nested_bidrectional.h5]",
+    "format": "Keras v2.3.1",
+    "link":   "https://github.com/lutzroeder/netron/issues/428"
+  },
+  {
+    "type":   "keras",
+    "target": "nested_bidrectional_weights.h5",
+    "source": "https://github.com/lutzroeder/netron/files/4304644/nested_bidrectional.zip[nested_bidrectional_weights.h5]",
+    "format": "Keras v2.3.1",
+    "link":   "https://github.com/lutzroeder/netron/issues/428"
+  },
   {
     "type":   "keras",
     "target": "netron_issue_326.json",
@@ -2267,7 +2281,7 @@
   },
   {
     "type":   "keras",
-    "target": "VGG16.h5",
+    "target": "VGG19.h5",
     "link":   "https://keras.io/applications",
     "script": "./tools/keras sync install zoo"
   },

+ 0 - 1
tools/keras-script.py

@@ -355,7 +355,6 @@ def zoo():
     download_model('keras.applications.mobilenet_v2.MobileNetV2', '${test}/data/keras/MobileNetV2.h5')
     download_model('keras.applications.nasnet.NASNetMobile', '${test}/data/keras/NASNetMobile.h5')
     download_model('keras.applications.resnet50.ResNet50', '${test}/data/keras/ResNet50.h5')
-    download_model('keras.applications.vgg16.VGG16', '${test}/data/keras/VGG16.h5')
     download_model('keras.applications.vgg19.VGG19', '${test}/data/keras/VGG19.h5')
     download_model('keras.applications.xception.Xception', '${test}/data/keras/Xception.h5')