|
|
@@ -12,7 +12,7 @@ keras.ModelFactory = class {
|
|
|
const extension = identifier.lastIndexOf('.') > 0 ? identifier.split('.').pop().toLowerCase() : '';
|
|
|
const group = await context.peek('hdf5');
|
|
|
if (group && group.attributes && group.attributes.get('CLASS') !== 'hickle') {
|
|
|
- if (identifier === 'model.weights.h5') {
|
|
|
+ if (identifier.endsWith('.weights.h5')) {
|
|
|
return context.set('keras.model.weights.h5', group);
|
|
|
}
|
|
|
if (identifier === 'parameter.h5') {
|