|
|
@@ -35,38 +35,51 @@ keras.ModelFactory = class {
|
|
|
open(context) {
|
|
|
return keras.Metadata.open(context).then((metadata) => {
|
|
|
let format = 'Keras';
|
|
|
- let producer = '';
|
|
|
let backend = '';
|
|
|
- let model_config = null;
|
|
|
- let rootGroup = null;
|
|
|
const weights = new keras.Weights();
|
|
|
- const manifests = [];
|
|
|
const stream = context.stream;
|
|
|
const signature = [ 0x89, 0x48, 0x44, 0x46, 0x0D, 0x0A, 0x1A, 0x0A ];
|
|
|
if (stream.length > signature.length && stream.peek(signature.length).every((value, index) => value === signature[index])) {
|
|
|
return context.require('./hdf5').then((hdf5) => {
|
|
|
- const buffer = stream.peek();
|
|
|
- const file = new hdf5.File(buffer);
|
|
|
- rootGroup = file.rootGroup;
|
|
|
- if (rootGroup.attribute('model_config') || rootGroup.attribute('layer_names')) {
|
|
|
- const model_config_json = rootGroup.attribute('model_config');
|
|
|
- if (model_config_json) {
|
|
|
- const reader = json.TextReader.open(model_config_json);
|
|
|
- model_config = reader.read();
|
|
|
+ const file = hdf5.File.open(stream);
|
|
|
+ let rootGroup = file.rootGroup;
|
|
|
+ const read_model_config = (group) => {
|
|
|
+ if (group.attributes.has('model_config')) {
|
|
|
+ const buffer = group.attributes.get('model_config');
|
|
|
+ const reader = json.TextReader.open(buffer);
|
|
|
+ return reader.read();
|
|
|
}
|
|
|
- backend = rootGroup.attribute('backend') || '';
|
|
|
- const version = rootGroup.attribute('keras_version') || '';
|
|
|
- format = format + (version ? ' v' + version : '');
|
|
|
- let model_weights_group = rootGroup.group('model_weights');
|
|
|
- if (!model_weights_group && rootGroup.attribute('layer_names')) {
|
|
|
- model_weights_group = rootGroup;
|
|
|
+ return null;
|
|
|
+ };
|
|
|
+ const load_attributes_from_hdf5_group = (group, name) => {
|
|
|
+ if (group.attributes.has(name)) {
|
|
|
+ return group.attributes.get(name);
|
|
|
+ }
|
|
|
+ if (group.attributes.has(name + '0')) {
|
|
|
+ let index = 0;
|
|
|
+ let value = [];
|
|
|
+ while (group.attributes.has(name + index.toString())) {
|
|
|
+ const chunk = group.attributes.get(name + index.toString());
|
|
|
+ value = value.concat(chunk);
|
|
|
+ index++;
|
|
|
+ }
|
|
|
+ return value;
|
|
|
}
|
|
|
+ return null;
|
|
|
+ };
|
|
|
+ const model_config = read_model_config(rootGroup);
|
|
|
+ const layer_names = load_attributes_from_hdf5_group(rootGroup, 'layer_names');
|
|
|
+ if (model_config || (layer_names && Array.isArray(layer_names))) {
|
|
|
+ backend = rootGroup.attributes.get('backend') || '';
|
|
|
+ const version = rootGroup.attributes.get('keras_version') || '';
|
|
|
+ format = format + (version ? ' v' + version : '');
|
|
|
+ const model_weights_group = layer_names ? rootGroup : rootGroup.group('model_weights');
|
|
|
if (model_weights_group) {
|
|
|
- model_weights_group = new keras.Group(model_weights_group);
|
|
|
- for (const layer_name of model_weights_group.attribute('layer_names')) {
|
|
|
+ const layer_names = load_attributes_from_hdf5_group(model_weights_group, 'layer_names');
|
|
|
+ for (const layer_name of layer_names) {
|
|
|
const layer_weights = model_weights_group.group(layer_name);
|
|
|
if (layer_weights) {
|
|
|
- const weight_names = layer_weights.attribute('weight_names');
|
|
|
+ const weight_names = load_attributes_from_hdf5_group(layer_weights, 'weight_names');
|
|
|
if (weight_names && weight_names.length > 0) {
|
|
|
for (const weight_name of weight_names) {
|
|
|
const weight = layer_weights.group(weight_name);
|
|
|
@@ -90,45 +103,31 @@ keras.ModelFactory = class {
|
|
|
}
|
|
|
}
|
|
|
else {
|
|
|
- const attributes = new Set([ 'nb_layers' ]);
|
|
|
- if (Object.keys(rootGroup.attributes).filter((name) => !attributes.has(name)).length !== 0 || rootGroup.value !== null) {
|
|
|
+ const rootKeys = new Set([ 'nb_layers' ]);
|
|
|
+ if (Array.from(rootGroup.attributes.keys()).filter((name) => !rootKeys.has(name)).length !== 0 || rootGroup.value !== null) {
|
|
|
throw new keras.Error('File format is not HDF5 Weights');
|
|
|
}
|
|
|
format = 'HDF5 Weights';
|
|
|
- if (Object.keys(rootGroup.attributes).length === 0 && rootGroup.value === null &&
|
|
|
- rootGroup.groups.length == 1 && rootGroup.groups[0] &&
|
|
|
- Object.keys(rootGroup.groups[0].attributes).length === 0 && rootGroup.groups[0].value === null) {
|
|
|
- rootGroup = rootGroup.groups[0];
|
|
|
+ if (rootGroup.attributes.size === 0 && rootGroup.value === null && rootGroup.groups.size == 1) {
|
|
|
+ const group = rootGroup.groups.values().next().value;
|
|
|
+ if (group.attributes.size === 0 && group.value === null) {
|
|
|
+ rootGroup = group;
|
|
|
+ }
|
|
|
}
|
|
|
- if (rootGroup.groups.every((group) => Object.keys(group.attributes).length === 0 && group.groups.length == 0 && group.value !== null)) {
|
|
|
- for (const group of rootGroup.groups) {
|
|
|
+ const tensorKeys = new Set([ 'name', 'shape', 'quantization' ]);
|
|
|
+ const groups = Array.from(rootGroup.groups.values());
|
|
|
+ if (groups.every((group) => group.attributes.size === 0 && group.groups.size === 0 && group.value !== null)) {
|
|
|
+ for (const group of groups) {
|
|
|
const variable = group.value;
|
|
|
const tensor = new keras.Tensor(group.name, variable.shape, variable.type, null, variable.littleEndian, variable.type === 'string' ? variable.value : variable.data);
|
|
|
weights.add('', tensor);
|
|
|
}
|
|
|
}
|
|
|
- else if (rootGroup.groups.every((group) => Object.keys(group.attributes).length === 0 && group.value === null)) {
|
|
|
- for (const group of rootGroup.groups) {
|
|
|
- const moduleName = group.attributes.name || group.name;
|
|
|
- for (const variableGroup of group.groups) {
|
|
|
- if (Object.keys(variableGroup.attributes).length !== 0 || variableGroup.groups.length !== 0) {
|
|
|
- throw new keras.Error('Group is not HDF5 tensor variable.');
|
|
|
- }
|
|
|
- const variable = variableGroup.value;
|
|
|
- if (!variable) {
|
|
|
- throw new keras.Error('Variable value is not HDF5 tensor.');
|
|
|
- }
|
|
|
- const name = moduleName ? [ moduleName, variableGroup.name ].join('/') : moduleName.name;
|
|
|
- const tensor = new keras.Tensor(name, variable.shape, variable.type, null, variable.littleEndian, variable.type === 'string' ? variable.value : variable.data);
|
|
|
- weights.add(moduleName, tensor);
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
- else if (rootGroup.groups.every((group) => group.value === null && group.groups.every((variable) => Object.keys(variable.attributes).length === 0 && variable.value !== null))) {
|
|
|
- for (const group of rootGroup.groups) {
|
|
|
- const moduleName = group.attributes.name || group.name;
|
|
|
- for (const variableGroup of group.groups) {
|
|
|
- if (Object.keys(variableGroup.attributes).length !== 0 || variableGroup.groups.length !== 0) {
|
|
|
+ else if (groups.every((group) => group.value === null && Array.from(group.attributes.keys()).filter((key) => !tensorKeys.has(key)).length === 0 && Array.from(group.groups.values()).every((variable) => variable.attributes.size === 0 && variable.value !== null))) {
|
|
|
+ for (const group of groups) {
|
|
|
+ const moduleName = group.attributes.has('name') ? group.attributes.get('name') : group.name;
|
|
|
+ for (const variableGroup of group.groups.values()) {
|
|
|
+ if (variableGroup.attributes.size !== 0 || variableGroup.groups.size !== 0) {
|
|
|
throw new keras.Error('Variable format is not HDF5 Weights');
|
|
|
}
|
|
|
const variable = variableGroup.value;
|
|
|
@@ -143,12 +142,19 @@ keras.ModelFactory = class {
|
|
|
}
|
|
|
else {
|
|
|
const walk = function(group) {
|
|
|
- if (Object.keys(group.attributes).length === 0 && group.value === null && group.groups.length > 0) {
|
|
|
- for (const subGroup of group.groups) {
|
|
|
+ if (group.attributes.size === 0 && group.value === null && group.groups.size > 0) {
|
|
|
+ for (const subGroup of group.groups.values()) {
|
|
|
walk(subGroup);
|
|
|
}
|
|
|
+ return;
|
|
|
+ }
|
|
|
+ const subKeys = new Set([ 'index', 'need_grad' ]);
|
|
|
+ const attributes = Array.from(group.attributes.keys());
|
|
|
+ const match = attributes.filter((key) => !subKeys.has(key)).length === 0;
|
|
|
+ if (match && attributes.length !== 0) {
|
|
|
+ format = 'nnabla HDF5 Weights';
|
|
|
}
|
|
|
- else if (Object.keys(group.attributes).length === 0 && group.value !== null && group.groups.length === 0) {
|
|
|
+ if (match && group.value !== null && group.groups.size === 0) {
|
|
|
const variable = group.value;
|
|
|
const variableName = group.path;
|
|
|
let moduleName = variableName;
|
|
|
@@ -159,10 +165,9 @@ keras.ModelFactory = class {
|
|
|
}
|
|
|
const tensor = new keras.Tensor(variableName, variable.shape, variable.type, null, variable.littleEndian, variable.type === 'string' ? variable.value : variable.data);
|
|
|
weights.add(moduleName, tensor);
|
|
|
+ return;
|
|
|
}
|
|
|
- else {
|
|
|
- throw new keras.Error('Module group format is not HDF5 Weights');
|
|
|
- }
|
|
|
+ throw new keras.Error('Module group format is not HDF5 Weights');
|
|
|
};
|
|
|
walk(rootGroup);
|
|
|
}
|
|
|
@@ -173,11 +178,15 @@ keras.ModelFactory = class {
|
|
|
if (!rootGroup && !model_config.class_name) {
|
|
|
throw new keras.Error('\'class_name\' is not present.');
|
|
|
}
|
|
|
- return new keras.Model(metadata, format, producer, backend, model_config, weights);
|
|
|
+ return new keras.Model(metadata, format, '', backend, model_config, weights);
|
|
|
});
|
|
|
}
|
|
|
const obj = context.open('json');
|
|
|
if (obj) {
|
|
|
+ let rootGroup = null;
|
|
|
+ let model_config = null;
|
|
|
+ let producer = '';
|
|
|
+ const manifests = [];
|
|
|
if (obj && Array.isArray(obj) && obj.every((manifest) => Array.isArray(manifest.weights) && Array.isArray(manifest.paths))) {
|
|
|
format = 'TensorFlow.js Weights';
|
|
|
rootGroup = {};
|
|
|
@@ -670,7 +679,8 @@ keras.Node = class {
|
|
|
}
|
|
|
}
|
|
|
if (name !== 'name' && value !== null) {
|
|
|
- this._attributes.push(new keras.Attribute(metadata.attribute(this.type, name), name, value));
|
|
|
+ const attribute = new keras.Attribute(metadata.attribute(this.type, name), name, value);
|
|
|
+ this._attributes.push(attribute);
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
@@ -1201,44 +1211,6 @@ keras.Metadata = class {
|
|
|
}
|
|
|
};
|
|
|
|
|
|
-keras.Group = class {
|
|
|
-
|
|
|
- constructor(group) {
|
|
|
- this._group = group;
|
|
|
- }
|
|
|
-
|
|
|
- attribute(name) {
|
|
|
- let value = this._group.attribute(name);
|
|
|
- if (!value) {
|
|
|
- if (this._group.attribute(name + '0')) {
|
|
|
- let index = 0;
|
|
|
- value = [];
|
|
|
- for (;;) {
|
|
|
- const chunk = this._group.attribute(name + index.toString());
|
|
|
- if (!chunk) {
|
|
|
- break;
|
|
|
- }
|
|
|
- value = value.concat(chunk);
|
|
|
- index++;
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
- return value;
|
|
|
- }
|
|
|
-
|
|
|
- group(name) {
|
|
|
- const value = this._group.group(name);
|
|
|
- if (value) {
|
|
|
- return new keras.Group(value);
|
|
|
- }
|
|
|
- return null;
|
|
|
- }
|
|
|
-
|
|
|
- get value() {
|
|
|
- return this._group.value;
|
|
|
- }
|
|
|
-};
|
|
|
-
|
|
|
keras.Weights = class {
|
|
|
|
|
|
constructor() {
|