|
|
@@ -13,12 +13,12 @@ keras.ModelFactory = class {
|
|
|
const group = await context.peek('hdf5');
|
|
|
if (group && group.attributes && group.attributes.get('CLASS') !== 'hickle') {
|
|
|
if (identifier === 'model.weights.h5') {
|
|
|
- return context.match('keras.model.weights.h5', group);
|
|
|
+ return context.set('keras.model.weights.h5', group);
|
|
|
}
|
|
|
if (identifier === 'parameter.h5') {
|
|
|
- return context.match('hdf5.parameter.h5', group);
|
|
|
+ return context.set('hdf5.parameter.h5', group);
|
|
|
}
|
|
|
- return context.match('keras.h5', group);
|
|
|
+ return context.set('keras.h5', group);
|
|
|
}
|
|
|
const json = await context.peek('json');
|
|
|
if (json) {
|
|
|
@@ -26,27 +26,27 @@ keras.ModelFactory = class {
|
|
|
return null;
|
|
|
}
|
|
|
if (json.model_config || (json.class_name && json.config)) {
|
|
|
- return context.match('keras.config.json', json);
|
|
|
+ return context.set('keras.config.json', json);
|
|
|
}
|
|
|
if (identifier === 'metadata.json' && json.keras_version) {
|
|
|
- return context.match('keras.metadata.json', json);
|
|
|
+ return context.set('keras.metadata.json', json);
|
|
|
}
|
|
|
}
|
|
|
const container = await tfjs.Container.open(context);
|
|
|
if (container) {
|
|
|
- return context.match('tfjs', container);
|
|
|
+ return context.set('tfjs', container);
|
|
|
}
|
|
|
const pickle = await context.peek('pkl');
|
|
|
if (pickle && pickle.__class__ &&
|
|
|
pickle.__class__.__module__ === 'keras.engine.sequential' &&
|
|
|
pickle.__class__.__name__ === 'Sequential') {
|
|
|
- return context.match('tfjs.pickle', pickle);
|
|
|
+ return context.set('tfjs.pickle', pickle);
|
|
|
}
|
|
|
// model.weights.npz
|
|
|
const entries = await context.peek('npz');
|
|
|
const regex = /^(__root__|layers\/.+|_layer_checkpoint_dependencies\/.+)\.npy$/;
|
|
|
if (entries instanceof Map && entries.size > 0 && Array.from(entries).every(([name]) => regex.test(name))) {
|
|
|
- return context.match('keras.model.weights.npz', entries);
|
|
|
+ return context.set('keras.model.weights.npz', entries);
|
|
|
}
|
|
|
// keras_metadata.pb
|
|
|
if (extension === 'pb' && context.stream && context.stream.length > 16) {
|
|
|
@@ -56,7 +56,7 @@ keras.ModelFactory = class {
|
|
|
const buffer = stream.peek(Math.min(stream.length, 1024));
|
|
|
const content = String.fromCharCode.apply(null, buffer);
|
|
|
if (/root"/.test(content) && /\{\s*"class_name"\s*:/.test(content)) {
|
|
|
- return context.match('keras.pb.SavedMetadata');
|
|
|
+ return context.set('keras.pb.SavedMetadata');
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
@@ -222,7 +222,7 @@ keras.ModelFactory = class {
|
|
|
};
|
|
|
switch (context.type) {
|
|
|
case 'keras.config.json': {
|
|
|
- const obj = context.target;
|
|
|
+ const obj = context.value;
|
|
|
const config = obj.model_config ? obj.model_config : obj;
|
|
|
const backend = obj.backend || '';
|
|
|
let version = obj.keras_version ? obj.keras_version : null;
|
|
|
@@ -238,7 +238,7 @@ keras.ModelFactory = class {
|
|
|
return open_model(format, '', backend, config, null);
|
|
|
}
|
|
|
case 'keras.model.weights.h5': {
|
|
|
- const group = context.target;
|
|
|
+ const group = context.value;
|
|
|
const weights_store = read_weights_hdf5(group);
|
|
|
const metadata = await request_json(context, 'metadata.json');
|
|
|
let config = await request_json(context, 'config.json');
|
|
|
@@ -252,7 +252,7 @@ keras.ModelFactory = class {
|
|
|
return await open_model(format, '', '', config, null);
|
|
|
}
|
|
|
case 'keras.model.weights.npz': {
|
|
|
- const entries = context.target;
|
|
|
+ const entries = context.value;
|
|
|
const weights_store = read_weights_numpy(entries);
|
|
|
const metadata = await request_json(context, 'metadata.json');
|
|
|
let config = await request_json(context, 'config.json');
|
|
|
@@ -266,7 +266,7 @@ keras.ModelFactory = class {
|
|
|
return await open_model(format, '', '', config, null);
|
|
|
}
|
|
|
case 'keras.metadata.json': {
|
|
|
- const metadata = context.target;
|
|
|
+ const metadata = context.value;
|
|
|
let config = await request_json(context, 'config.json');
|
|
|
const name = config ? 'Keras' : 'Keras Weights';
|
|
|
const format = name + (metadata.keras_version ? ` v${metadata.keras_version}` : '');
|
|
|
@@ -317,7 +317,7 @@ keras.ModelFactory = class {
|
|
|
return null;
|
|
|
};
|
|
|
const weights = new keras.Weights();
|
|
|
- const group = context.target;
|
|
|
+ const group = context.value;
|
|
|
const root_group = find_root_group(group);
|
|
|
const model_config = read_model_config(root_group);
|
|
|
if (model_config) {
|
|
|
@@ -452,12 +452,12 @@ keras.ModelFactory = class {
|
|
|
return open_model(format, '', '', null, weights);
|
|
|
}
|
|
|
case 'tfjs': {
|
|
|
- const target = context.target;
|
|
|
+ const target = context.value;
|
|
|
await target.read();
|
|
|
return open_model(target.format, target.producer, target.backend, target.config, target.weights);
|
|
|
}
|
|
|
case 'keras.pickle': {
|
|
|
- const obj = context.target;
|
|
|
+ const obj = context.value;
|
|
|
const execution = new python.Execution();
|
|
|
const decoder = new TextDecoder('utf-8');
|
|
|
const format = `Keras Pickle${obj.keras_version ? ` v${decoder.decode(obj.keras_version)}` : ''}`;
|