|
|
@@ -6,41 +6,71 @@ var json = json || require('./json');
|
|
|
keras.ModelFactory = class {
|
|
|
|
|
|
match(context) {
|
|
|
- const stream = context.stream;
|
|
|
- const signature = [ 0x89, 0x48, 0x44, 0x46, 0x0D, 0x0A, 0x1A, 0x0A ];
|
|
|
- if (stream.length > signature.length && stream.peek(signature.length).every((value, index) => value === signature[index])) {
|
|
|
- return true;
|
|
|
- }
|
|
|
- const obj = context.open('json');
|
|
|
- if (obj) {
|
|
|
- if (obj.mxnet_version) {
|
|
|
- return false;
|
|
|
- }
|
|
|
- if (obj.nodes && obj.arg_nodes && obj.heads) {
|
|
|
- return false;
|
|
|
- }
|
|
|
- if (obj.modelTopology && (obj.format === 'layers-model' || obj.modelTopology.class_name || obj.modelTopology.model_config)) {
|
|
|
- return true;
|
|
|
- }
|
|
|
- if (obj.model_config || (obj.class_name && obj.config)) {
|
|
|
- return true;
|
|
|
- }
|
|
|
- if (Array.isArray(obj) && obj.every((item) => item.weights && item.paths)) {
|
|
|
- return true;
|
|
|
- }
|
|
|
- }
|
|
|
- return false;
|
|
|
+ return this._format(context).length > 0;
|
|
|
}
|
|
|
|
|
|
open(context) {
|
|
|
- return keras.Metadata.open(context).then((metadata) => {
|
|
|
- let format = 'Keras';
|
|
|
- let backend = '';
|
|
|
+ const openModel = (format, producer, backend, config, weights) => {
|
|
|
+ return keras.Metadata.open(context).then((metadata) => {
|
|
|
+ return new keras.Model(metadata, format, producer, backend, config, weights);
|
|
|
+ });
|
|
|
+ };
|
|
|
+ const openShards = (manifests, shards) => {
|
|
|
const weights = new keras.Weights();
|
|
|
- const stream = context.stream;
|
|
|
- const signature = [ 0x89, 0x48, 0x44, 0x46, 0x0D, 0x0A, 0x1A, 0x0A ];
|
|
|
- if (stream.length > signature.length && stream.peek(signature.length).every((value, index) => value === signature[index])) {
|
|
|
+ const dtype_size_map = new Map([ [ 'float16', 2 ], [ 'float32', 4 ], [ 'float64', 8 ], [ 'int8', 1 ], [ 'int16', 2 ], [ 'int32', 4 ], [ 'int64', 8 ], [ 'uint8', 1 ], [ 'uint16', 2 ], [ 'uint32', 4 ], [ 'uint64', 8 ] ]);
|
|
|
+ for (const manifest of manifests) {
|
|
|
+ let buffer = null;
|
|
|
+ if (Array.isArray(manifest.paths) && manifest.paths.length > 0 && manifest.paths.every((path) => shards.has(path))) {
|
|
|
+ const list = manifest.paths.map((path) => shards.get(path));
|
|
|
+ const size = list.reduce((a, b) => a + b.length, 0);
|
|
|
+ buffer = new Uint8Array(size);
|
|
|
+ let offset = 0;
|
|
|
+ for (const item of list) {
|
|
|
+ buffer.set(item, offset);
|
|
|
+ offset += item.length;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ let offset = 0;
|
|
|
+ for (const weight of manifest.weights) {
|
|
|
+ const dtype = weight.quantization && weight.quantization.dtype ? weight.quantization.dtype : weight.dtype;
|
|
|
+ if (!dtype_size_map.has(dtype)) {
|
|
|
+ throw new keras.Error("Unknown weight data type size '" + dtype + "'.");
|
|
|
+ }
|
|
|
+ const itemsize = dtype_size_map.get(dtype);
|
|
|
+ const size = weight.shape.reduce((a, b) => a * b, 1);
|
|
|
+ const length = itemsize * size;
|
|
|
+ const data = buffer ? buffer.slice(offset, offset + length) : null;
|
|
|
+ weights.add(weight.identifier, new keras.Tensor(weight.name, weight.shape, dtype, weight.quantization, true, data));
|
|
|
+ offset += length;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return Promise.resolve(weights);
|
|
|
+ };
|
|
|
+ const openManifests = (manifests) => {
|
|
|
+ const shards = new Map();
|
|
|
+ for (const manifest of manifests) {
|
|
|
+ for (const path of manifest.paths) {
|
|
|
+ if (!shards.has(path)) {
|
|
|
+ shards.set(path, context.request(path, null));
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ const promises = shards.values();
|
|
|
+ return Promise.all(promises).then((streams) => {
|
|
|
+ for (const key of shards.keys()) {
|
|
|
+ shards.set(key, streams.shift().peek());
|
|
|
+ }
|
|
|
+ return openShards(manifests, shards);
|
|
|
+ }).catch(() => {
|
|
|
+ shards.clear();
|
|
|
+ return openShards(manifests, shards);
|
|
|
+ });
|
|
|
+ };
|
|
|
+ const stream = context.stream;
|
|
|
+ switch (this._format(context)) {
|
|
|
+ case 'keras.h5': {
|
|
|
return context.require('./hdf5').then((hdf5) => {
|
|
|
+ const weights = new keras.Weights();
|
|
|
const file = hdf5.File.open(stream);
|
|
|
const rootGroup = file.rootGroup;
|
|
|
const read_model_config = (group) => {
|
|
|
@@ -69,9 +99,9 @@ keras.ModelFactory = class {
|
|
|
};
|
|
|
const model_config = read_model_config(rootGroup);
|
|
|
if (model_config) {
|
|
|
- backend = rootGroup.attributes.get('backend') || '';
|
|
|
+ const backend = rootGroup.attributes.get('backend') || '';
|
|
|
const version = rootGroup.attributes.get('keras_version') || '';
|
|
|
- format = format + (version ? ' v' + version : '');
|
|
|
+ const format = 'Keras' + (version ? ' v' + version : '');
|
|
|
const model_weights_group = rootGroup.group('model_weights');
|
|
|
if (model_weights_group) {
|
|
|
const layer_names = load_attributes_from_hdf5_group(model_weights_group, 'layer_names');
|
|
|
@@ -98,13 +128,13 @@ keras.ModelFactory = class {
|
|
|
if (!model_config.class_name) {
|
|
|
throw new keras.Error("'class_name' is not present.");
|
|
|
}
|
|
|
- return new keras.Model(metadata, format, '', backend, model_config, weights);
|
|
|
+ return openModel(format, '', backend, model_config, weights);
|
|
|
}
|
|
|
const layer_names = load_attributes_from_hdf5_group(rootGroup, 'layer_names');
|
|
|
if (layer_names && Array.isArray(layer_names)) {
|
|
|
const version = rootGroup.attributes.get('keras_version') || '';
|
|
|
- format = 'Keras Weights' + (version ? ' v' + version : '');
|
|
|
- backend = rootGroup.attributes.get('backend') || '';
|
|
|
+ const format = 'Keras Weights' + (version ? ' v' + version : '');
|
|
|
+ const backend = rootGroup.attributes.get('backend') || '';
|
|
|
for (const layer_name of layer_names) {
|
|
|
const layer_weights = rootGroup.group(layer_name);
|
|
|
if (layer_weights) {
|
|
|
@@ -124,182 +154,157 @@ keras.ModelFactory = class {
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
- return new keras.Model(metadata, format, '', backend, null, weights);
|
|
|
+ return openModel(format, '', backend, null, weights);
|
|
|
}
|
|
|
- const rootKeys = new Set(rootGroup.attributes.keys());
|
|
|
- rootKeys.delete('nb_layers');
|
|
|
- if (rootKeys.size > 0 || rootGroup.value !== null) {
|
|
|
- throw new keras.Error('File format is not HDF5 Weights');
|
|
|
- }
|
|
|
- format = 'HDF5 Weights';
|
|
|
- let weightsGroup = rootGroup;
|
|
|
- if (rootGroup.attributes.size === 0 && rootGroup.value === null && rootGroup.groups.size == 1) {
|
|
|
- const group = rootGroup.groups.values().next().value;
|
|
|
- if (group.attributes.size === 0 && group.value === null) {
|
|
|
- weightsGroup = group;
|
|
|
- }
|
|
|
- }
|
|
|
- const tensorKeys = new Set([ 'name', 'shape', 'quantization' ]);
|
|
|
- const groups = Array.from(weightsGroup.groups.values());
|
|
|
- if (groups.every((group) => group.attributes.size === 0 && group.groups.length == 0 && group.value !== null)) {
|
|
|
- for (const group of groups) {
|
|
|
- const variable = group.value;
|
|
|
- const tensor = new keras.Tensor(group.name, variable.shape, variable.type, null, variable.littleEndian, variable.type === 'string' ? variable.value : variable.data);
|
|
|
- weights.add('', tensor);
|
|
|
+ else {
|
|
|
+ const rootKeys = new Set(rootGroup.attributes.keys());
|
|
|
+ rootKeys.delete('nb_layers');
|
|
|
+ if (rootKeys.size > 0 || rootGroup.value !== null) {
|
|
|
+ throw new keras.Error('File format is not HDF5 Weights');
|
|
|
}
|
|
|
- return new keras.Model(metadata, format, '', backend, null, weights);
|
|
|
- }
|
|
|
- if (groups.every((group) => group.value === null && Array.from(group.attributes.keys()).filter((key) => !tensorKeys.has(key)).length === 0 && Array.from(group.groups.values()).every((variable) => Object.keys(variable.attributes).length === 0 && variable.value !== null))) {
|
|
|
- for (const group of groups) {
|
|
|
- const moduleName = group.attributes.has('name') ? group.attributes.get('name') : group.name;
|
|
|
- for (const variableGroup of group.groups.values()) {
|
|
|
- if (variableGroup.attributes.size !== 0 || variableGroup.groups.size !== 0) {
|
|
|
- throw new keras.Error('Variable format is not HDF5 Weights');
|
|
|
- }
|
|
|
- const variable = variableGroup.value;
|
|
|
- if (!variable) {
|
|
|
- throw new keras.Error('Variable value is not HDF5 Weights');
|
|
|
- }
|
|
|
- const name = moduleName ? [ moduleName, variableGroup.name ].join('/') : moduleName.name;
|
|
|
- const tensor = new keras.Tensor(name, variable.shape, variable.type, null, variable.littleEndian, variable.type === 'string' ? variable.value : variable.data);
|
|
|
- weights.add(moduleName, tensor);
|
|
|
+ let format = 'HDF5 Weights';
|
|
|
+ let weightsGroup = rootGroup;
|
|
|
+ if (rootGroup.attributes.size === 0 && rootGroup.value === null && rootGroup.groups.size == 1) {
|
|
|
+ const group = rootGroup.groups.values().next().value;
|
|
|
+ if (group.attributes.size === 0 && group.value === null) {
|
|
|
+ weightsGroup = group;
|
|
|
}
|
|
|
}
|
|
|
- return new keras.Model(metadata, format, '', backend, null, weights);
|
|
|
- }
|
|
|
- const walk = function(group) {
|
|
|
- if (group.attributes.size === 0 && group.value === null && group.groups.size > 0) {
|
|
|
- for (const subGroup of group.groups.values()) {
|
|
|
- walk(subGroup);
|
|
|
+ const tensorKeys = new Set([ 'name', 'shape', 'quantization' ]);
|
|
|
+ const groups = Array.from(weightsGroup.groups.values());
|
|
|
+ if (groups.every((group) => group.attributes.size === 0 && group.groups.length == 0 && group.value !== null)) {
|
|
|
+ for (const group of groups) {
|
|
|
+ const variable = group.value;
|
|
|
+ const tensor = new keras.Tensor(group.name, variable.shape, variable.type, null, variable.littleEndian, variable.type === 'string' ? variable.value : variable.data);
|
|
|
+ weights.add('', tensor);
|
|
|
}
|
|
|
- return;
|
|
|
- }
|
|
|
- const subKeys = new Set([ 'index', 'need_grad' ]);
|
|
|
- const attribtues = Array.from(group.attributes.keys());
|
|
|
- const match = attribtues.filter((key) => !subKeys.has(key)).length === 0;
|
|
|
- if (match && attribtues.length !== 0) {
|
|
|
- format = 'nnabla HDF5 Weights';
|
|
|
+ return openModel(format, '', '', null, weights);
|
|
|
}
|
|
|
- if (match && group.value !== null && group.groups.size === 0) {
|
|
|
- const variable = group.value;
|
|
|
- const variableName = group.path;
|
|
|
- let moduleName = variableName;
|
|
|
- const parts = variableName.split('/');
|
|
|
- if (parts.length > 1) {
|
|
|
- parts.pop();
|
|
|
- moduleName = parts.join('/');
|
|
|
+ if (groups.every((group) => group.value === null && Array.from(group.attributes.keys()).filter((key) => !tensorKeys.has(key)).length === 0 && Array.from(group.groups.values()).every((variable) => Object.keys(variable.attributes).length === 0 && variable.value !== null))) {
|
|
|
+ for (const group of groups) {
|
|
|
+ const moduleName = group.attributes.has('name') ? group.attributes.get('name') : group.name;
|
|
|
+ for (const variableGroup of group.groups.values()) {
|
|
|
+ if (variableGroup.attributes.size !== 0 || variableGroup.groups.size !== 0) {
|
|
|
+ throw new keras.Error('Variable format is not HDF5 Weights');
|
|
|
+ }
|
|
|
+ const variable = variableGroup.value;
|
|
|
+ if (!variable) {
|
|
|
+ throw new keras.Error('Variable value is not HDF5 Weights');
|
|
|
+ }
|
|
|
+ const name = moduleName ? [ moduleName, variableGroup.name ].join('/') : moduleName.name;
|
|
|
+ const tensor = new keras.Tensor(name, variable.shape, variable.type, null, variable.littleEndian, variable.type === 'string' ? variable.value : variable.data);
|
|
|
+ weights.add(moduleName, tensor);
|
|
|
+ }
|
|
|
}
|
|
|
- const tensor = new keras.Tensor(variableName, variable.shape, variable.type, null, variable.littleEndian, variable.type === 'string' ? variable.value : variable.data);
|
|
|
- weights.add(moduleName, tensor);
|
|
|
- return;
|
|
|
+ return openModel(format, '', '', null, weights);
|
|
|
}
|
|
|
- throw new keras.Error('Module group format is not HDF5 Weights');
|
|
|
- };
|
|
|
- walk(weightsGroup);
|
|
|
- return new keras.Model(metadata, format, '', backend, null, weights);
|
|
|
- });
|
|
|
- }
|
|
|
- const obj = context.open('json');
|
|
|
- if (obj) {
|
|
|
- let rootGroup = null;
|
|
|
- let model_config = null;
|
|
|
- let producer = '';
|
|
|
- const manifests = [];
|
|
|
- if (obj && Array.isArray(obj) && obj.every((manifest) => Array.isArray(manifest.weights) && Array.isArray(manifest.paths))) {
|
|
|
- format = 'TensorFlow.js Weights';
|
|
|
- rootGroup = {};
|
|
|
- manifests.push(...obj);
|
|
|
- for (const manifest of manifests) {
|
|
|
- for (const weight of manifest.weights) {
|
|
|
- const parts = weight.name.split('/');
|
|
|
- parts.pop();
|
|
|
- weight.identifier = parts.join('/');
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
- else {
|
|
|
- if (obj.keras_version) {
|
|
|
- const version = obj.keras_version;
|
|
|
- format = format + (version ? (' v' + version) : '');
|
|
|
- }
|
|
|
- if (obj.backend) {
|
|
|
- backend = obj.backend;
|
|
|
- }
|
|
|
- model_config = obj;
|
|
|
- if (model_config && model_config.modelTopology) {
|
|
|
- backend = model_config.modelTopology.backend;
|
|
|
- const version = model_config.modelTopology.keras_version;
|
|
|
- format = format + (version ? (' v' + version) : '');
|
|
|
- format = 'TensorFlow.js ' + (model_config.format ? model_config.format : format);
|
|
|
- producer = model_config.convertedBy || model_config.generatedBy || '';
|
|
|
- manifests.push(...model_config.weightsManifest);
|
|
|
- for (const manifest of manifests) {
|
|
|
- for (const weight of manifest.weights) {
|
|
|
- weight.identifier = '';
|
|
|
+ const walk = function(group) {
|
|
|
+ if (group.attributes.size === 0 && group.value === null && group.groups.size > 0) {
|
|
|
+ for (const subGroup of group.groups.values()) {
|
|
|
+ walk(subGroup);
|
|
|
+ }
|
|
|
+ return;
|
|
|
}
|
|
|
- }
|
|
|
- model_config = model_config.modelTopology;
|
|
|
+ const subKeys = new Set([ 'index', 'need_grad' ]);
|
|
|
+ const attribtues = Array.from(group.attributes.keys());
|
|
|
+ const match = attribtues.filter((key) => !subKeys.has(key)).length === 0;
|
|
|
+ if (match && attribtues.length !== 0) {
|
|
|
+ format = 'nnabla HDF5 Weights';
|
|
|
+ }
|
|
|
+ if (match && group.value !== null && group.groups.size === 0) {
|
|
|
+ const variable = group.value;
|
|
|
+ const variableName = group.path;
|
|
|
+ let moduleName = variableName;
|
|
|
+ const parts = variableName.split('/');
|
|
|
+ if (parts.length > 1) {
|
|
|
+ parts.pop();
|
|
|
+ moduleName = parts.join('/');
|
|
|
+ }
|
|
|
+ const tensor = new keras.Tensor(variableName, variable.shape, variable.type, null, variable.littleEndian, variable.type === 'string' ? variable.value : variable.data);
|
|
|
+ weights.add(moduleName, tensor);
|
|
|
+ return;
|
|
|
+ }
|
|
|
+ throw new keras.Error('Module group format is not HDF5 Weights');
|
|
|
+ };
|
|
|
+ walk(weightsGroup);
|
|
|
+ return openModel(format, '', '', null, weights);
|
|
|
}
|
|
|
- if (model_config.model_config) {
|
|
|
- model_config = model_config.model_config;
|
|
|
+ });
|
|
|
+ }
|
|
|
+ case 'keras.json': {
|
|
|
+ const obj = context.open('json');
|
|
|
+ const format = 'Keras' + (obj.keras_version ? ' v' + obj.keras_version : '');
|
|
|
+ const backend = obj.backend || '';
|
|
|
+ const config = obj.model_config ? obj.model_config : obj;
|
|
|
+ const weights = new keras.Weights();
|
|
|
+ return openModel(format, '', backend, config, weights);
|
|
|
+ }
|
|
|
+ case 'keras.json.tfjs': {
|
|
|
+ const obj = context.open('json');
|
|
|
+ const modelTopology = obj.modelTopology;
|
|
|
+ const backend = modelTopology.backend || '';
|
|
|
+ const format = 'TensorFlow.js ' + (obj.format ? obj.format : 'Keras' + (modelTopology.keras_version ? (' v' + modelTopology.keras_version) : ''));
|
|
|
+ const producer = obj.convertedBy || obj.generatedBy || '';
|
|
|
+ const manifests = obj.weightsManifest;
|
|
|
+ for (const manifest of manifests) {
|
|
|
+ for (const weight of manifest.weights) {
|
|
|
+ weight.identifier = '';
|
|
|
}
|
|
|
}
|
|
|
- if (!rootGroup && !model_config) {
|
|
|
- throw new keras.Error('\'model_config\' is not present.');
|
|
|
- }
|
|
|
- if (!rootGroup && !model_config.class_name) {
|
|
|
- throw new keras.Error('\'class_name\' is not present.');
|
|
|
- }
|
|
|
- const shards = new Map();
|
|
|
+ const model_config = modelTopology.model_config ? modelTopology.model_config : modelTopology;
|
|
|
+ return openManifests(manifests).then((weights) => {
|
|
|
+ return openModel(format, producer, backend, model_config, weights);
|
|
|
+ });
|
|
|
+ }
|
|
|
+ case 'keras.json.tfjs.weights': {
|
|
|
+ const obj = context.open('json');
|
|
|
+ const manifests = [];
|
|
|
+ const format = 'TensorFlow.js Weights';
|
|
|
+ manifests.push(...obj);
|
|
|
for (const manifest of manifests) {
|
|
|
- for (const path of manifest.paths) {
|
|
|
- if (!shards.has(path)) {
|
|
|
- shards.set(path, context.request(path, null));
|
|
|
- }
|
|
|
+ for (const weight of manifest.weights) {
|
|
|
+ const parts = weight.name.split('/');
|
|
|
+ parts.pop();
|
|
|
+ weight.identifier = parts.join('/');
|
|
|
}
|
|
|
}
|
|
|
- const create = (shards) => {
|
|
|
- const dtype_size_map = new Map([ [ 'float16', 2 ], [ 'float32', 4 ], [ 'float64', 8 ], [ 'int8', 1 ], [ 'int16', 2 ], [ 'int32', 4 ], [ 'int64', 8 ], [ 'uint8', 1 ], [ 'uint16', 2 ], [ 'uint32', 4 ], [ 'uint64', 8 ] ]);
|
|
|
- for (const manifest of manifests) {
|
|
|
- let buffer = null;
|
|
|
- if (Array.isArray(manifest.paths) && manifest.paths.length > 0 && manifest.paths.every((path) => shards.has(path))) {
|
|
|
- const list = manifest.paths.map((path) => shards.get(path));
|
|
|
- const size = list.reduce((a, b) => a + b.length, 0);
|
|
|
- buffer = new Uint8Array(size);
|
|
|
- let offset = 0;
|
|
|
- for (const item of list) {
|
|
|
- buffer.set(item, offset);
|
|
|
- offset += item.length;
|
|
|
- }
|
|
|
- }
|
|
|
- let offset = 0;
|
|
|
- for (const weight of manifest.weights) {
|
|
|
- const dtype = weight.quantization && weight.quantization.dtype ? weight.quantization.dtype : weight.dtype;
|
|
|
- if (!dtype_size_map.has(dtype)) {
|
|
|
- throw new keras.Error("Unknown weight data type size '" + dtype + "'.");
|
|
|
- }
|
|
|
- const itemsize = dtype_size_map.get(dtype);
|
|
|
- const size = weight.shape.reduce((a, b) => a * b, 1);
|
|
|
- const length = itemsize * size;
|
|
|
- const data = buffer ? buffer.slice(offset, offset + length) : null;
|
|
|
- weights.add(weight.identifier, new keras.Tensor(weight.name, weight.shape, dtype, weight.quantization, true, data));
|
|
|
- offset += length;
|
|
|
- }
|
|
|
- }
|
|
|
- return new keras.Model(metadata, format, producer, backend, model_config, weights);
|
|
|
- };
|
|
|
- return Promise.all(shards.values()).then((streams) => {
|
|
|
- for (const key of shards.keys()) {
|
|
|
- shards.set(key, streams.shift().peek());
|
|
|
- }
|
|
|
- return create(shards);
|
|
|
- }).catch(() => {
|
|
|
- shards.clear();
|
|
|
- return create(shards);
|
|
|
+ return openManifests(manifests).then((weights) => {
|
|
|
+ return openModel(format, '', '', null, weights);
|
|
|
});
|
|
|
}
|
|
|
- throw new keras.Error('Unsupported Keras format.');
|
|
|
- });
|
|
|
+ default: {
|
|
|
+ throw new keras.Error("Unsupported Keras format '" + this._format(context) + "'.");
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ _format(context) {
|
|
|
+ const stream = context.stream;
|
|
|
+ const signature = [ 0x89, 0x48, 0x44, 0x46, 0x0D, 0x0A, 0x1A, 0x0A ];
|
|
|
+ if (stream.length > signature.length && stream.peek(signature.length).every((value, index) => value === signature[index])) {
|
|
|
+ return 'keras.h5';
|
|
|
+ }
|
|
|
+ const obj = context.open('json');
|
|
|
+ if (obj) {
|
|
|
+ if (obj.mxnet_version) {
|
|
|
+ return '';
|
|
|
+ }
|
|
|
+ if (obj.nodes && obj.arg_nodes && obj.heads) {
|
|
|
+ return '';
|
|
|
+ }
|
|
|
+ if (obj.modelTopology) {
|
|
|
+ if (obj.format === 'layers-model' || obj.modelTopology.class_name || obj.modelTopology.model_config) {
|
|
|
+ return 'keras.json.tfjs';
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if (obj.model_config || (obj.class_name && obj.config)) {
|
|
|
+ return 'keras.json';
|
|
|
+ }
|
|
|
+ if (Array.isArray(obj) && obj.every((item) => item.weights && item.paths)) {
|
|
|
+ return 'keras.json.tfjs.weights';
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return '';
|
|
|
}
|
|
|
};
|
|
|
|