Lutz Roeder 4 года назад
Родитель
Commit
1b77542d40
3 измененных файлов с 97 добавлено и 118 удалено
  1. 21 21
      source/hdf5.js
  2. 69 97
      source/keras.js
  3. 7 0
      test/models.json

+ 21 - 21
source/hdf5.js

@@ -7,13 +7,18 @@ var zip = zip || require('./zip');
 
 hdf5.File = class {
 
-    constructor(buffer) {
-        // https://support.hdfgroup.org/HDF5/doc/H5.format.html
+    static open(data) {
+        const buffer = data instanceof Uint8Array ? data : data.peek();
         const reader = new hdf5.Reader(buffer, 0);
-        this._globalHeap = new hdf5.GlobalHeap(reader);
-        if (!reader.match('\x89HDF\r\n\x1A\n')) {
-            throw new hdf5.Error('Not a valid HDF5 file.');
+        if (reader.match('\x89HDF\r\n\x1A\n')) {
+            return new hdf5.File(reader);
         }
+        return null;
+    }
+
+    constructor(reader) {
+        // https://support.hdfgroup.org/HDF5/doc/H5.format.html
+        this._globalHeap = new hdf5.GlobalHeap(reader);
         const version = reader.byte();
         switch (version) {
             case 0:
@@ -94,9 +99,8 @@ hdf5.Group = class {
             }
         }
         else {
-            const group = this._groupMap[path];
-            if (group) {
-                return group;
+            if (this._groups.has(path)) {
+                return this._groups.get(path);
             }
         }
         return null;
@@ -107,11 +111,6 @@ hdf5.Group = class {
         return this._groups;
     }
 
-    attribute(name) {
-        this._decodeDataObject();
-        return this._attributes[name];
-    }
-
     get attributes() {
         this._decodeDataObject();
         return this._attributes;
@@ -127,11 +126,11 @@ hdf5.Group = class {
             this._dataObjectHeader = new hdf5.DataObjectHeader(this._reader.at(this._entry.objectHeaderAddress));
         }
         if (!this._attributes) {
-            this._attributes = {};
+            this._attributes = new Map();
             for (const attribute of this._dataObjectHeader.attributes) {
                 const name = attribute.name;
                 const value = attribute.decodeValue(this._globalHeap);
-                this._attributes[name] = value;
+                this._attributes.set(name, value);
             }
             this._value = null;
             const datatype = this._dataObjectHeader.datatype;
@@ -146,8 +145,7 @@ hdf5.Group = class {
 
     _decodeGroups() {
         if (!this._groups) {
-            this._groupMap = {};
-            this._groups = [];
+            this._groups = new Map();
             if (this._entry) {
                 if (this._entry.treeAddress || this._entry.heapAddress) {
                     const heap = new hdf5.Heap(this._reader.at(this._entry.heapAddress));
@@ -156,8 +154,7 @@ hdf5.Group = class {
                         for (const entry of node.entries) {
                             const name = heap.getString(entry.linkNameOffset);
                             const group = new hdf5.Group(this._reader, entry, null, this._globalHeap, this._path, name);
-                            this._groups.push(group);
-                            this._groupMap[name] = group;
+                            this._groups.set(name, group);
                         }
                     }
                 }
@@ -169,8 +166,7 @@ hdf5.Group = class {
                         const name = link.name;
                         const objectHeader = new hdf5.DataObjectHeader(this._reader.at(link.objectHeaderAddress));
                         const linkGroup = new hdf5.Group(this._reader, null, objectHeader, this._globalHeap, this._path, name);
-                        this._groups.push(linkGroup);
-                        this._groupMap[name] = linkGroup;
+                        this._groups.set(name, linkGroup);
                     }
                 }
             }
@@ -943,6 +939,8 @@ hdf5.Datatype = class {
                 throw new hdf5.Error('Unsupported character encoding.');
             case 5: // opaque
                 return reader.read(this._size);
+            case 8: // enumerated
+                return reader.read(this._size);
             case 9: // variable-length
                 return {
                     length: reader.uint32(),
@@ -962,6 +960,8 @@ hdf5.Datatype = class {
                 return data;
             case 5: // opaque
                 return data;
+            case 8: // enumerated
+                return data;
             case 9: { // variable-length
                 const globalHeapObject = globalHeap.get(data.globalHeapID);
                 if (globalHeapObject != null) {

+ 69 - 97
source/keras.js

@@ -35,38 +35,51 @@ keras.ModelFactory = class {
     open(context) {
         return keras.Metadata.open(context).then((metadata) => {
             let format = 'Keras';
-            let producer = '';
             let backend = '';
-            let model_config = null;
-            let rootGroup = null;
             const weights = new keras.Weights();
-            const manifests = [];
             const stream = context.stream;
             const signature = [ 0x89, 0x48, 0x44, 0x46, 0x0D, 0x0A, 0x1A, 0x0A ];
             if (stream.length > signature.length && stream.peek(signature.length).every((value, index) => value === signature[index])) {
                 return context.require('./hdf5').then((hdf5) => {
-                    const buffer = stream.peek();
-                    const file = new hdf5.File(buffer);
-                    rootGroup = file.rootGroup;
-                    if (rootGroup.attribute('model_config') || rootGroup.attribute('layer_names')) {
-                        const model_config_json = rootGroup.attribute('model_config');
-                        if (model_config_json) {
-                            const reader = json.TextReader.open(model_config_json);
-                            model_config = reader.read();
+                    const file = hdf5.File.open(stream);
+                    let rootGroup = file.rootGroup;
+                    const read_model_config = (group) => {
+                        if (group.attributes.has('model_config')) {
+                            const buffer = group.attributes.get('model_config');
+                            const reader = json.TextReader.open(buffer);
+                            return reader.read();
                         }
-                        backend = rootGroup.attribute('backend') || '';
-                        const version = rootGroup.attribute('keras_version') || '';
-                        format = format + (version ? ' v' + version : '');
-                        let model_weights_group = rootGroup.group('model_weights');
-                        if (!model_weights_group && rootGroup.attribute('layer_names')) {
-                            model_weights_group = rootGroup;
+                        return null;
+                    };
+                    const load_attributes_from_hdf5_group = (group, name) => {
+                        if (group.attributes.has(name)) {
+                            return group.attributes.get(name);
+                        }
+                        if (group.attributes.has(name + '0')) {
+                            let index = 0;
+                            let value = [];
+                            while (group.attributes.has(name + index.toString())) {
+                                const chunk = group.attributes.get(name + index.toString());
+                                value = value.concat(chunk);
+                                index++;
+                            }
+                            return value;
                         }
+                        return null;
+                    };
+                    const model_config = read_model_config(rootGroup);
+                    const layer_names = load_attributes_from_hdf5_group(rootGroup, 'layer_names');
+                    if (model_config || (layer_names && Array.isArray(layer_names))) {
+                        backend = rootGroup.attributes.get('backend') || '';
+                        const version = rootGroup.attributes.get('keras_version') || '';
+                        format = format + (version ? ' v' + version : '');
+                        const model_weights_group = layer_names ? rootGroup : rootGroup.group('model_weights');
                         if (model_weights_group) {
-                            model_weights_group = new keras.Group(model_weights_group);
-                            for (const layer_name of model_weights_group.attribute('layer_names')) {
+                            const layer_names = load_attributes_from_hdf5_group(model_weights_group, 'layer_names');
+                            for (const layer_name of layer_names) {
                                 const layer_weights = model_weights_group.group(layer_name);
                                 if (layer_weights) {
-                                    const weight_names = layer_weights.attribute('weight_names');
+                                    const weight_names = load_attributes_from_hdf5_group(layer_weights, 'weight_names');
                                     if (weight_names && weight_names.length > 0) {
                                         for (const weight_name of weight_names) {
                                             const weight = layer_weights.group(weight_name);
@@ -90,45 +103,31 @@ keras.ModelFactory = class {
                         }
                     }
                     else {
-                        const attributes = new Set([ 'nb_layers' ]);
-                        if (Object.keys(rootGroup.attributes).filter((name) => !attributes.has(name)).length !== 0 || rootGroup.value !== null) {
+                        const rootKeys = new Set([ 'nb_layers' ]);
+                        if (Array.from(rootGroup.attributes.keys()).filter((name) => !rootKeys.has(name)).length !== 0 || rootGroup.value !== null) {
                             throw new keras.Error('File format is not HDF5 Weights');
                         }
                         format = 'HDF5 Weights';
-                        if (Object.keys(rootGroup.attributes).length === 0 && rootGroup.value === null &&
-                            rootGroup.groups.length == 1 && rootGroup.groups[0] &&
-                            Object.keys(rootGroup.groups[0].attributes).length === 0 && rootGroup.groups[0].value === null) {
-                            rootGroup = rootGroup.groups[0];
+                        if (rootGroup.attributes.size === 0 && rootGroup.value === null && rootGroup.groups.size == 1) {
+                            const group = rootGroup.groups.values().next().value;
+                            if (group.attributes.size === 0 && group.value === null) {
+                                rootGroup = group;
+                            }
                         }
-                        if (rootGroup.groups.every((group) => Object.keys(group.attributes).length === 0 && group.groups.length == 0 && group.value !== null)) {
-                            for (const group of rootGroup.groups) {
+                        const tensorKeys = new Set([ 'name', 'shape', 'quantization' ]);
+                        const groups = Array.from(rootGroup.groups.values());
+                        if (groups.every((group) => group.attributes.size === 0 && group.groups.size === 0 && group.value !== null)) {
+                            for (const group of groups) {
                                 const variable = group.value;
                                 const tensor = new keras.Tensor(group.name, variable.shape, variable.type, null, variable.littleEndian, variable.type === 'string' ? variable.value : variable.data);
                                 weights.add('', tensor);
                             }
                         }
-                        else if (rootGroup.groups.every((group) => Object.keys(group.attributes).length === 0 && group.value === null)) {
-                            for (const group of rootGroup.groups) {
-                                const moduleName = group.attributes.name || group.name;
-                                for (const variableGroup of group.groups) {
-                                    if (Object.keys(variableGroup.attributes).length !== 0 || variableGroup.groups.length !== 0) {
-                                        throw new keras.Error('Group is not HDF5 tensor variable.');
-                                    }
-                                    const variable = variableGroup.value;
-                                    if (!variable) {
-                                        throw new keras.Error('Variable value is not HDF5 tensor.');
-                                    }
-                                    const name = moduleName ? [ moduleName, variableGroup.name ].join('/') : moduleName.name;
-                                    const tensor = new keras.Tensor(name, variable.shape, variable.type, null, variable.littleEndian, variable.type === 'string' ? variable.value : variable.data);
-                                    weights.add(moduleName, tensor);
-                                }
-                            }
-                        }
-                        else if (rootGroup.groups.every((group) => group.value === null && group.groups.every((variable) => Object.keys(variable.attributes).length === 0 && variable.value !== null))) {
-                            for (const group of rootGroup.groups) {
-                                const moduleName = group.attributes.name || group.name;
-                                for (const variableGroup of group.groups) {
-                                    if (Object.keys(variableGroup.attributes).length !== 0 || variableGroup.groups.length !== 0) {
+                        else if (groups.every((group) => group.value === null && Array.from(group.attributes.keys()).filter((key) => !tensorKeys.has(key)).length === 0 && Array.from(group.groups.values()).every((variable) => variable.attributes.size === 0 && variable.value !== null))) {
+                            for (const group of groups) {
+                                const moduleName = group.attributes.has('name') ? group.attributes.get('name') : group.name;
+                                for (const variableGroup of group.groups.values()) {
+                                    if (variableGroup.attributes.size !== 0 || variableGroup.groups.size !== 0) {
                                         throw new keras.Error('Variable format is not HDF5 Weights');
                                     }
                                     const variable = variableGroup.value;
@@ -143,12 +142,19 @@ keras.ModelFactory = class {
                         }
                         else {
                             const walk = function(group) {
-                                if (Object.keys(group.attributes).length === 0 && group.value === null && group.groups.length > 0) {
-                                    for (const subGroup of group.groups) {
+                                if (group.attributes.size === 0 && group.value === null && group.groups.size > 0) {
+                                    for (const subGroup of group.groups.values()) {
                                         walk(subGroup);
                                     }
+                                    return;
+                                }
+                                const subKeys = new Set([ 'index', 'need_grad' ]);
+                                const attributes = Array.from(group.attributes.keys());
+                                const match = attributes.filter((key) => !subKeys.has(key)).length === 0;
+                                if (match && attributes.length !== 0) {
+                                    format = 'nnabla HDF5 Weights';
                                 }
-                                else if (Object.keys(group.attributes).length === 0 && group.value !== null && group.groups.length === 0) {
+                                if (match && group.value !== null && group.groups.size === 0) {
                                     const variable = group.value;
                                     const variableName = group.path;
                                     let moduleName = variableName;
@@ -159,10 +165,9 @@ keras.ModelFactory = class {
                                     }
                                     const tensor = new keras.Tensor(variableName, variable.shape, variable.type, null, variable.littleEndian, variable.type === 'string' ? variable.value : variable.data);
                                     weights.add(moduleName, tensor);
+                                    return;
                                 }
-                                else {
-                                    throw new keras.Error('Module group format is not HDF5 Weights');
-                                }
+                                throw new keras.Error('Module group format is not HDF5 Weights');
                             };
                             walk(rootGroup);
                         }
@@ -173,11 +178,15 @@ keras.ModelFactory = class {
                     if (!rootGroup && !model_config.class_name) {
                         throw new keras.Error('\'class_name\' is not present.');
                     }
-                    return new keras.Model(metadata, format, producer, backend, model_config, weights);
+                    return new keras.Model(metadata, format, '', backend, model_config, weights);
                 });
             }
             const obj = context.open('json');
             if (obj) {
+                let rootGroup = null;
+                let model_config = null;
+                let producer = '';
+                const manifests = [];
                 if (obj && Array.isArray(obj) && obj.every((manifest) => Array.isArray(manifest.weights) && Array.isArray(manifest.paths))) {
                     format = 'TensorFlow.js Weights';
                     rootGroup = {};
@@ -670,7 +679,8 @@ keras.Node = class {
                     }
                 }
                 if (name !== 'name' && value !== null) {
-                    this._attributes.push(new keras.Attribute(metadata.attribute(this.type, name), name, value));
+                    const attribute = new keras.Attribute(metadata.attribute(this.type, name), name, value);
+                    this._attributes.push(attribute);
                 }
             }
         }
@@ -1201,44 +1211,6 @@ keras.Metadata = class {
     }
 };
 
-keras.Group = class {
-
-    constructor(group) {
-        this._group = group;
-    }
-
-    attribute(name) {
-        let value = this._group.attribute(name);
-        if (!value) {
-            if (this._group.attribute(name + '0')) {
-                let index = 0;
-                value = [];
-                for (;;) {
-                    const chunk = this._group.attribute(name + index.toString());
-                    if (!chunk) {
-                        break;
-                    }
-                    value = value.concat(chunk);
-                    index++;
-                }
-            }
-        }
-        return value;
-    }
-
-    group(name) {
-        const value = this._group.group(name);
-        if (value) {
-            return new keras.Group(value);
-        }
-        return null;
-    }
-
-    get value() {
-        return this._group.value;
-    }
-};
-
 keras.Weights = class {
 
     constructor() {

+ 7 - 0
test/models.json

@@ -2785,6 +2785,13 @@
     "format": "MXNet Model Archive v1.0",
     "link":   "https://github.com/lutzroeder/netron/issues/286"
   },
+  {
+    "type":   "nnabla",
+    "target": "tecogan_model.h5",
+    "source": "https://nnabla.org/pretrained-models/nnabla-examples/GANs/tecogan/tecogan_model.h5",
+    "format": "nnabla HDF5 Weights",
+    "link":   "https://github.com/sony/nnabla-examples/issues/192"
+  },
   {
     "type":   "ncnn",
     "target": "centerface.param,centerface.bin",