فهرست منبع

Update keras.js

Lutz Roeder 4 سال پیش
والد
کامیت
7f12b1da1d
1 فایل‌های تغییر یافته به همراه 199 افزوده شده و 194 حذف شده
  1. 199 194
      source/keras.js

+ 199 - 194
source/keras.js

@@ -6,41 +6,71 @@ var json = json || require('./json');
 keras.ModelFactory = class {
 
     match(context) {
-        const stream = context.stream;
-        const signature = [ 0x89, 0x48, 0x44, 0x46, 0x0D, 0x0A, 0x1A, 0x0A ];
-        if (stream.length > signature.length && stream.peek(signature.length).every((value, index) => value === signature[index])) {
-            return true;
-        }
-        const obj = context.open('json');
-        if (obj) {
-            if (obj.mxnet_version) {
-                return false;
-            }
-            if (obj.nodes && obj.arg_nodes && obj.heads) {
-                return false;
-            }
-            if (obj.modelTopology && (obj.format === 'layers-model' || obj.modelTopology.class_name || obj.modelTopology.model_config)) {
-                return true;
-            }
-            if (obj.model_config || (obj.class_name && obj.config)) {
-                return true;
-            }
-            if (Array.isArray(obj) && obj.every((item) => item.weights && item.paths)) {
-                return true;
-            }
-        }
-        return false;
+        return this._format(context).length > 0;
     }
 
     open(context) {
-        return keras.Metadata.open(context).then((metadata) => {
-            let format = 'Keras';
-            let backend = '';
+        const openModel = (format, producer, backend, config, weights) => {
+            return keras.Metadata.open(context).then((metadata) => {
+                return new keras.Model(metadata, format, producer, backend, config, weights);
+            });
+        };
+        const openShards = (manifests, shards) => {
             const weights = new keras.Weights();
-            const stream = context.stream;
-            const signature = [ 0x89, 0x48, 0x44, 0x46, 0x0D, 0x0A, 0x1A, 0x0A ];
-            if (stream.length > signature.length && stream.peek(signature.length).every((value, index) => value === signature[index])) {
+            const dtype_size_map = new Map([ [ 'float16', 2 ], [ 'float32', 4 ], [ 'float64', 8 ], [ 'int8', 1 ], [ 'int16', 2 ], [ 'int32', 4 ], [ 'int64', 8 ], [ 'uint8', 1 ], [ 'uint16', 2 ], [ 'uint32', 4 ], [ 'uint64', 8 ] ]);
+            for (const manifest of manifests) {
+                let buffer = null;
+                if (Array.isArray(manifest.paths) && manifest.paths.length > 0 && manifest.paths.every((path) => shards.has(path))) {
+                    const list = manifest.paths.map((path) => shards.get(path));
+                    const size = list.reduce((a, b) => a + b.length, 0);
+                    buffer = new Uint8Array(size);
+                    let offset = 0;
+                    for (const item of list) {
+                        buffer.set(item, offset);
+                        offset += item.length;
+                    }
+                }
+                let offset = 0;
+                for (const weight of manifest.weights) {
+                    const dtype = weight.quantization && weight.quantization.dtype ? weight.quantization.dtype : weight.dtype;
+                    if (!dtype_size_map.has(dtype)) {
+                        throw new keras.Error("Unknown weight data type size '" + dtype + "'.");
+                    }
+                    const itemsize = dtype_size_map.get(dtype);
+                    const size = weight.shape.reduce((a, b) => a * b, 1);
+                    const length = itemsize * size;
+                    const data = buffer ? buffer.slice(offset, offset + length) : null;
+                    weights.add(weight.identifier, new keras.Tensor(weight.name, weight.shape, dtype, weight.quantization, true, data));
+                    offset += length;
+                }
+            }
+            return Promise.resolve(weights);
+        };
+        const openManifests = (manifests) => {
+            const shards = new Map();
+            for (const manifest of manifests) {
+                for (const path of manifest.paths) {
+                    if (!shards.has(path)) {
+                        shards.set(path, context.request(path, null));
+                    }
+                }
+            }
+            const promises = shards.values();
+            return Promise.all(promises).then((streams) => {
+                for (const key of shards.keys()) {
+                    shards.set(key, streams.shift().peek());
+                }
+                return openShards(manifests, shards);
+            }).catch(() => {
+                shards.clear();
+                return openShards(manifests, shards);
+            });
+        };
+        const stream = context.stream;
+        switch (this._format(context)) {
+            case 'keras.h5': {
                 return context.require('./hdf5').then((hdf5) => {
+                    const weights = new keras.Weights();
                     const file = hdf5.File.open(stream);
                     const rootGroup = file.rootGroup;
                     const read_model_config = (group) => {
@@ -69,9 +99,9 @@ keras.ModelFactory = class {
                     };
                     const model_config = read_model_config(rootGroup);
                     if (model_config) {
-                        backend = rootGroup.attributes.get('backend') || '';
+                        const backend = rootGroup.attributes.get('backend') || '';
                         const version = rootGroup.attributes.get('keras_version') || '';
-                        format = format + (version ? ' v' + version : '');
+                        const format = 'Keras' + (version ? ' v' + version : '');
                         const model_weights_group = rootGroup.group('model_weights');
                         if (model_weights_group) {
                             const layer_names = load_attributes_from_hdf5_group(model_weights_group, 'layer_names');
@@ -98,13 +128,13 @@ keras.ModelFactory = class {
                         if (!model_config.class_name) {
                             throw new keras.Error("'class_name' is not present.");
                         }
-                        return new keras.Model(metadata, format, '', backend, model_config, weights);
+                        return openModel(format, '', backend, model_config, weights);
                     }
                     const layer_names = load_attributes_from_hdf5_group(rootGroup, 'layer_names');
                     if (layer_names && Array.isArray(layer_names)) {
                         const version = rootGroup.attributes.get('keras_version') || '';
-                        format = 'Keras Weights' + (version ? ' v' + version : '');
-                        backend = rootGroup.attributes.get('backend') || '';
+                        const format = 'Keras Weights' + (version ? ' v' + version : '');
+                        const backend = rootGroup.attributes.get('backend') || '';
                         for (const layer_name of layer_names) {
                             const layer_weights = rootGroup.group(layer_name);
                             if (layer_weights) {
@@ -124,182 +154,157 @@ keras.ModelFactory = class {
                                 }
                             }
                         }
-                        return new keras.Model(metadata, format, '', backend, null, weights);
+                        return openModel(format, '', backend, null, weights);
                     }
-                    const rootKeys = new Set(rootGroup.attributes.keys());
-                    rootKeys.delete('nb_layers');
-                    if (rootKeys.size > 0 || rootGroup.value !== null) {
-                        throw new keras.Error('File format is not HDF5 Weights');
-                    }
-                    format = 'HDF5 Weights';
-                    let weightsGroup = rootGroup;
-                    if (rootGroup.attributes.size === 0 && rootGroup.value === null && rootGroup.groups.size == 1) {
-                        const group = rootGroup.groups.values().next().value;
-                        if (group.attributes.size === 0 && group.value === null) {
-                            weightsGroup = group;
-                        }
-                    }
-                    const tensorKeys = new Set([ 'name', 'shape', 'quantization' ]);
-                    const groups = Array.from(weightsGroup.groups.values());
-                    if (groups.every((group) => group.attributes.size === 0 && group.groups.length == 0 && group.value !== null)) {
-                        for (const group of groups) {
-                            const variable = group.value;
-                            const tensor = new keras.Tensor(group.name, variable.shape, variable.type, null, variable.littleEndian, variable.type === 'string' ? variable.value : variable.data);
-                            weights.add('', tensor);
+                    else {
+                        const rootKeys = new Set(rootGroup.attributes.keys());
+                        rootKeys.delete('nb_layers');
+                        if (rootKeys.size > 0 || rootGroup.value !== null) {
+                            throw new keras.Error('File format is not HDF5 Weights');
                         }
-                        return new keras.Model(metadata, format, '', backend, null, weights);
-                    }
-                    if (groups.every((group) => group.value === null && Array.from(group.attributes.keys()).filter((key) => !tensorKeys.has(key)).length === 0 && Array.from(group.groups.values()).every((variable) => Object.keys(variable.attributes).length === 0 && variable.value !== null))) {
-                        for (const group of groups) {
-                            const moduleName = group.attributes.has('name') ? group.attributes.get('name') : group.name;
-                            for (const variableGroup of group.groups.values()) {
-                                if (variableGroup.attributes.size !== 0 || variableGroup.groups.size !== 0) {
-                                    throw new keras.Error('Variable format is not HDF5 Weights');
-                                }
-                                const variable = variableGroup.value;
-                                if (!variable) {
-                                    throw new keras.Error('Variable value is not HDF5 Weights');
-                                }
-                                const name = moduleName ? [ moduleName, variableGroup.name ].join('/') : moduleName.name;
-                                const tensor = new keras.Tensor(name, variable.shape, variable.type, null, variable.littleEndian, variable.type === 'string' ? variable.value : variable.data);
-                                weights.add(moduleName, tensor);
+                        let format = 'HDF5 Weights';
+                        let weightsGroup = rootGroup;
+                        if (rootGroup.attributes.size === 0 && rootGroup.value === null && rootGroup.groups.size == 1) {
+                            const group = rootGroup.groups.values().next().value;
+                            if (group.attributes.size === 0 && group.value === null) {
+                                weightsGroup = group;
                             }
                         }
-                        return new keras.Model(metadata, format, '', backend, null, weights);
-                    }
-                    const walk = function(group) {
-                        if (group.attributes.size === 0 && group.value === null && group.groups.size > 0) {
-                            for (const subGroup of group.groups.values()) {
-                                walk(subGroup);
+                        const tensorKeys = new Set([ 'name', 'shape', 'quantization' ]);
+                        const groups = Array.from(weightsGroup.groups.values());
+                        if (groups.every((group) => group.attributes.size === 0 && group.groups.length == 0 && group.value !== null)) {
+                            for (const group of groups) {
+                                const variable = group.value;
+                                const tensor = new keras.Tensor(group.name, variable.shape, variable.type, null, variable.littleEndian, variable.type === 'string' ? variable.value : variable.data);
+                                weights.add('', tensor);
                             }
-                            return;
-                        }
-                        const subKeys = new Set([ 'index', 'need_grad' ]);
-                        const attribtues = Array.from(group.attributes.keys());
-                        const match = attribtues.filter((key) => !subKeys.has(key)).length === 0;
-                        if (match && attribtues.length !== 0) {
-                            format = 'nnabla HDF5 Weights';
+                            return openModel(format, '', '', null, weights);
                         }
-                        if (match && group.value !== null && group.groups.size === 0) {
-                            const variable = group.value;
-                            const variableName = group.path;
-                            let moduleName = variableName;
-                            const parts = variableName.split('/');
-                            if (parts.length > 1) {
-                                parts.pop();
-                                moduleName = parts.join('/');
+                        if (groups.every((group) => group.value === null && Array.from(group.attributes.keys()).filter((key) => !tensorKeys.has(key)).length === 0 && Array.from(group.groups.values()).every((variable) => Object.keys(variable.attributes).length === 0 && variable.value !== null))) {
+                            for (const group of groups) {
+                                const moduleName = group.attributes.has('name') ? group.attributes.get('name') : group.name;
+                                for (const variableGroup of group.groups.values()) {
+                                    if (variableGroup.attributes.size !== 0 || variableGroup.groups.size !== 0) {
+                                        throw new keras.Error('Variable format is not HDF5 Weights');
+                                    }
+                                    const variable = variableGroup.value;
+                                    if (!variable) {
+                                        throw new keras.Error('Variable value is not HDF5 Weights');
+                                    }
+                                    const name = moduleName ? [ moduleName, variableGroup.name ].join('/') : moduleName.name;
+                                    const tensor = new keras.Tensor(name, variable.shape, variable.type, null, variable.littleEndian, variable.type === 'string' ? variable.value : variable.data);
+                                    weights.add(moduleName, tensor);
+                                }
                             }
-                            const tensor = new keras.Tensor(variableName, variable.shape, variable.type, null, variable.littleEndian, variable.type === 'string' ? variable.value : variable.data);
-                            weights.add(moduleName, tensor);
-                            return;
+                            return openModel(format, '', '', null, weights);
                         }
-                        throw new keras.Error('Module group format is not HDF5 Weights');
-                    };
-                    walk(weightsGroup);
-                    return new keras.Model(metadata, format, '', backend, null, weights);
-                });
-            }
-            const obj = context.open('json');
-            if (obj) {
-                let rootGroup = null;
-                let model_config = null;
-                let producer = '';
-                const manifests = [];
-                if (obj && Array.isArray(obj) && obj.every((manifest) => Array.isArray(manifest.weights) && Array.isArray(manifest.paths))) {
-                    format = 'TensorFlow.js Weights';
-                    rootGroup = {};
-                    manifests.push(...obj);
-                    for (const manifest of manifests) {
-                        for (const weight of manifest.weights) {
-                            const parts = weight.name.split('/');
-                            parts.pop();
-                            weight.identifier = parts.join('/');
-                        }
-                    }
-                }
-                else {
-                    if (obj.keras_version) {
-                        const version = obj.keras_version;
-                        format = format + (version ? (' v' + version) : '');
-                    }
-                    if (obj.backend) {
-                        backend = obj.backend;
-                    }
-                    model_config = obj;
-                    if (model_config && model_config.modelTopology) {
-                        backend = model_config.modelTopology.backend;
-                        const version = model_config.modelTopology.keras_version;
-                        format = format + (version ? (' v' + version) : '');
-                        format = 'TensorFlow.js ' + (model_config.format ? model_config.format : format);
-                        producer = model_config.convertedBy || model_config.generatedBy || '';
-                        manifests.push(...model_config.weightsManifest);
-                        for (const manifest of manifests) {
-                            for (const weight of manifest.weights) {
-                                weight.identifier = '';
+                        const walk = function(group) {
+                            if (group.attributes.size === 0 && group.value === null && group.groups.size > 0) {
+                                for (const subGroup of group.groups.values()) {
+                                    walk(subGroup);
+                                }
+                                return;
                             }
-                        }
-                        model_config = model_config.modelTopology;
+                            const subKeys = new Set([ 'index', 'need_grad' ]);
+                            const attribtues = Array.from(group.attributes.keys());
+                            const match = attribtues.filter((key) => !subKeys.has(key)).length === 0;
+                            if (match && attribtues.length !== 0) {
+                                format = 'nnabla HDF5 Weights';
+                            }
+                            if (match && group.value !== null && group.groups.size === 0) {
+                                const variable = group.value;
+                                const variableName = group.path;
+                                let moduleName = variableName;
+                                const parts = variableName.split('/');
+                                if (parts.length > 1) {
+                                    parts.pop();
+                                    moduleName = parts.join('/');
+                                }
+                                const tensor = new keras.Tensor(variableName, variable.shape, variable.type, null, variable.littleEndian, variable.type === 'string' ? variable.value : variable.data);
+                                weights.add(moduleName, tensor);
+                                return;
+                            }
+                            throw new keras.Error('Module group format is not HDF5 Weights');
+                        };
+                        walk(weightsGroup);
+                        return openModel(format, '', '', null, weights);
                     }
-                    if (model_config.model_config) {
-                        model_config = model_config.model_config;
+                });
+            }
+            case 'keras.json': {
+                const obj = context.open('json');
+                const format = 'Keras' + (obj.keras_version ? ' v' + obj.keras_version : '');
+                const backend = obj.backend || '';
+                const config = obj.model_config ? obj.model_config : obj;
+                const weights = new keras.Weights();
+                return openModel(format, '', backend, config, weights);
+            }
+            case 'keras.json.tfjs': {
+                const obj = context.open('json');
+                const modelTopology = obj.modelTopology;
+                const backend = modelTopology.backend || '';
+                const format = 'TensorFlow.js ' + (obj.format ? obj.format : 'Keras' + (modelTopology.keras_version ? (' v' + modelTopology.keras_version) : ''));
+                const producer = obj.convertedBy || obj.generatedBy || '';
+                const manifests = obj.weightsManifest;
+                for (const manifest of manifests) {
+                    for (const weight of manifest.weights) {
+                        weight.identifier = '';
                     }
                 }
-                if (!rootGroup && !model_config) {
-                    throw new keras.Error('\'model_config\' is not present.');
-                }
-                if (!rootGroup && !model_config.class_name) {
-                    throw new keras.Error('\'class_name\' is not present.');
-                }
-                const shards = new Map();
+                const model_config = modelTopology.model_config ? modelTopology.model_config : modelTopology;
+                return openManifests(manifests).then((weights) => {
+                    return openModel(format, producer, backend, model_config, weights);
+                });
+            }
+            case 'keras.json.tfjs.weights': {
+                const obj = context.open('json');
+                const manifests = [];
+                const format = 'TensorFlow.js Weights';
+                manifests.push(...obj);
                 for (const manifest of manifests) {
-                    for (const path of manifest.paths) {
-                        if (!shards.has(path)) {
-                            shards.set(path, context.request(path, null));
-                        }
+                    for (const weight of manifest.weights) {
+                        const parts = weight.name.split('/');
+                        parts.pop();
+                        weight.identifier = parts.join('/');
                     }
                 }
-                const create = (shards) => {
-                    const dtype_size_map = new Map([ [ 'float16', 2 ], [ 'float32', 4 ], [ 'float64', 8 ], [ 'int8', 1 ], [ 'int16', 2 ], [ 'int32', 4 ], [ 'int64', 8 ], [ 'uint8', 1 ], [ 'uint16', 2 ], [ 'uint32', 4 ], [ 'uint64', 8 ] ]);
-                    for (const manifest of manifests) {
-                        let buffer = null;
-                        if (Array.isArray(manifest.paths) && manifest.paths.length > 0 && manifest.paths.every((path) => shards.has(path))) {
-                            const list = manifest.paths.map((path) => shards.get(path));
-                            const size = list.reduce((a, b) => a + b.length, 0);
-                            buffer = new Uint8Array(size);
-                            let offset = 0;
-                            for (const item of list) {
-                                buffer.set(item, offset);
-                                offset += item.length;
-                            }
-                        }
-                        let offset = 0;
-                        for (const weight of manifest.weights) {
-                            const dtype = weight.quantization && weight.quantization.dtype ? weight.quantization.dtype : weight.dtype;
-                            if (!dtype_size_map.has(dtype)) {
-                                throw new keras.Error("Unknown weight data type size '" + dtype + "'.");
-                            }
-                            const itemsize = dtype_size_map.get(dtype);
-                            const size = weight.shape.reduce((a, b) => a * b, 1);
-                            const length = itemsize * size;
-                            const data = buffer ? buffer.slice(offset, offset + length) : null;
-                            weights.add(weight.identifier, new keras.Tensor(weight.name, weight.shape, dtype, weight.quantization, true, data));
-                            offset += length;
-                        }
-                    }
-                    return new keras.Model(metadata, format, producer, backend, model_config, weights);
-                };
-                return Promise.all(shards.values()).then((streams) => {
-                    for (const key of shards.keys()) {
-                        shards.set(key, streams.shift().peek());
-                    }
-                    return create(shards);
-                }).catch(() => {
-                    shards.clear();
-                    return create(shards);
+                return openManifests(manifests).then((weights) => {
+                    return openModel(format, '', '', null, weights);
                 });
             }
-            throw new keras.Error('Unsupported Keras format.');
-        });
+            default: {
+                throw new keras.Error("Unsupported Keras format '" + this._format(context) + "'.");
+            }
+        }
+    }
+
+    _format(context) {
+        const stream = context.stream;
+        const signature = [ 0x89, 0x48, 0x44, 0x46, 0x0D, 0x0A, 0x1A, 0x0A ];
+        if (stream.length > signature.length && stream.peek(signature.length).every((value, index) => value === signature[index])) {
+            return 'keras.h5';
+        }
+        const obj = context.open('json');
+        if (obj) {
+            if (obj.mxnet_version) {
+                return '';
+            }
+            if (obj.nodes && obj.arg_nodes && obj.heads) {
+                return '';
+            }
+            if (obj.modelTopology) {
+                if (obj.format === 'layers-model' || obj.modelTopology.class_name || obj.modelTopology.model_config) {
+                    return 'keras.json.tfjs';
+                }
+            }
+            if (obj.model_config || (obj.class_name && obj.config)) {
+                return 'keras.json';
+            }
+            if (Array.isArray(obj) && obj.every((item) => item.weights && item.paths)) {
+                return 'keras.json.tfjs.weights';
+            }
+        }
+        return '';
     }
 };