|
|
@@ -17,26 +17,41 @@ class KerasModel {
|
|
|
|
|
|
static create(buffer, identifier, host, callback) {
|
|
|
try {
|
|
|
- var version = null;
|
|
|
- var backend = null;
|
|
|
- var json = null;
|
|
|
+ var format = 'Keras';
|
|
|
var rootGroup = null;
|
|
|
+ var rootJson = null;
|
|
|
+ var model_config = null;
|
|
|
|
|
|
var extension = identifier.split('.').pop();
|
|
|
if (extension == 'keras' || extension == 'h5') {
|
|
|
var file = new hdf5.File(buffer);
|
|
|
rootGroup = file.rootGroup;
|
|
|
- json = rootGroup.attributes.model_config;
|
|
|
- if (!json) {
|
|
|
+ var modelConfigJson = rootGroup.attributes.model_config;
|
|
|
+ if (!modelConfigJson) {
|
|
|
throw new KerasError('HDF5 file does not contain a \'model_config\' graph. Use \'save()\' instead of \'save_weights()\' to save both the graph and weights.');
|
|
|
}
|
|
|
+ model_config = JSON.parse(modelConfigJson);
|
|
|
}
|
|
|
else if (extension == 'json') {
|
|
|
var decoder = new window.TextDecoder('utf-8');
|
|
|
- json = decoder.decode(buffer);
|
|
|
+ var json = decoder.decode(buffer);
|
|
|
+ model_config = JSON.parse(json);
|
|
|
+ if (model_config && model_config.modelTopology && model_config.modelTopology.model_config) {
|
|
|
+ format = 'TensorFlow.js ' + format;
|
|
|
+ rootJson = model_config;
|
|
|
+ model_config = model_config.modelTopology.model_config;
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
- var model = new KerasModel(json, rootGroup);
|
|
|
+ if (!model_config) {
|
|
|
+ throw new KerasError('model_config is not present.');
|
|
|
+ }
|
|
|
+
|
|
|
+ if (!model_config.class_name) {
|
|
|
+ throw new KerasError('class_name is not present.');
|
|
|
+ }
|
|
|
+
|
|
|
+ var model = new KerasModel(format, model_config, rootGroup, rootJson);
|
|
|
|
|
|
KerasOperatorMetadata.open(host, (err, metadata) => {
|
|
|
callback(null, model);
|
|
|
@@ -47,30 +62,58 @@ class KerasModel {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- constructor(json, rootGroup) {
|
|
|
- var model = JSON.parse(json);
|
|
|
- if (!model.class_name) {
|
|
|
- throw new KerasError('class_name is not present.');
|
|
|
- }
|
|
|
- if (rootGroup && rootGroup.attributes.keras_version) {
|
|
|
- this._version = rootGroup.attributes.keras_version;
|
|
|
+ constructor(format, model_config, rootGroup, rootJson) {
|
|
|
+ this._format = format;
|
|
|
+ this._graphs = [];
|
|
|
+
|
|
|
+ var model_weights = null;
|
|
|
+ var weightsManifest = null;
|
|
|
+ if (rootGroup) {
|
|
|
+ if (rootGroup.attributes.keras_version) {
|
|
|
+ this._version = rootGroup.attributes.keras_version;
|
|
|
+ }
|
|
|
+ if (rootGroup.attributes.backend) {
|
|
|
+ this._backend = rootGroup.attributes.backend;
|
|
|
+ }
|
|
|
+ model_weights = rootGroup.group('model_weights');
|
|
|
}
|
|
|
- if (rootGroup && rootGroup.attributes.backend) {
|
|
|
- this._backend = rootGroup.attributes.backend;
|
|
|
+ else if (rootJson) {
|
|
|
+ if (rootJson.modelTopology && rootJson.modelTopology.keras_version) {
|
|
|
+ this._version = rootJson.modelTopology.keras_version;
|
|
|
+ }
|
|
|
+ if (rootJson.modelTopology && rootJson.modelTopology.backend) {
|
|
|
+ this._backend = rootJson.modelTopology.backend;
|
|
|
+ }
|
|
|
+ if (rootJson.weightsManifest) {
|
|
|
+ weightsManifest = {};
|
|
|
+ rootJson.weightsManifest.forEach((manifest) => {
|
|
|
+ var match = false;
|
|
|
+ var key = null;
|
|
|
+ manifest.weights.forEach((weights) => {
|
|
|
+ var name = weights.name.split('/').shift();
|
|
|
+ if (key == null) {
|
|
|
+ key = name;
|
|
|
+ match = true;
|
|
|
+ }
|
|
|
+ else if (key != name) {
|
|
|
+ match = false;
|
|
|
+ }
|
|
|
+ });
|
|
|
+ if (match) {
|
|
|
+ weightsManifest[key] = manifest;
|
|
|
+ }
|
|
|
+ });
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
- var model_weights = rootGroup ? rootGroup.group('model_weights') : null;
|
|
|
- this._activeGraph = new KerasGraph(model, model_weights);
|
|
|
- this._graphs = [ this._activeGraph ];
|
|
|
+ this._activeGraph = new KerasGraph(model_config, model_weights, weightsManifest);
|
|
|
+ this._graphs.push(this._activeGraph);
|
|
|
}
|
|
|
|
|
|
get properties() {
|
|
|
var results = [];
|
|
|
|
|
|
- var format = 'Keras';
|
|
|
- if (this._version) {
|
|
|
- format = format + ' v' + this._version;
|
|
|
- }
|
|
|
+ var format = this._format + (this._version ? (' v' + this._version) : '');
|
|
|
results.push({ name: 'Format', value: format });
|
|
|
|
|
|
if (this._backend) {
|
|
|
@@ -95,7 +138,7 @@ class KerasModel {
|
|
|
|
|
|
class KerasGraph {
|
|
|
|
|
|
- constructor(model, model_weights) {
|
|
|
+ constructor(model, model_weights, weightsManifest) {
|
|
|
if (model.name) {
|
|
|
this._name = model.name;
|
|
|
}
|
|
|
@@ -109,10 +152,10 @@ class KerasGraph {
|
|
|
|
|
|
switch (model.class_name) {
|
|
|
case 'Sequential':
|
|
|
- this.loadSequential(model.config, model_weights, '');
|
|
|
+ this.loadSequential(model.config, model_weights, weightsManifest, '');
|
|
|
break;
|
|
|
case 'Model':
|
|
|
- this.loadModel(model.config, model_weights, '', null, null);
|
|
|
+ this.loadModel(model.config, model_weights, weightsManifest, '', null, null);
|
|
|
break;
|
|
|
default:
|
|
|
throw new KerasError('\'' + model.class_name + '\' is not supported.');
|
|
|
@@ -139,7 +182,7 @@ class KerasGraph {
|
|
|
return this._nodes;
|
|
|
}
|
|
|
|
|
|
- loadModel(config, model_weights, group, inputs, outputs) {
|
|
|
+ loadModel(config, model_weights, weightsManifest, group, inputs, outputs) {
|
|
|
if (group) {
|
|
|
this._groups = true;
|
|
|
}
|
|
|
@@ -238,13 +281,13 @@ class KerasGraph {
|
|
|
if (config.layers) {
|
|
|
config.layers.forEach((layer) => {
|
|
|
if (nodeMap[layer.name]) {
|
|
|
- this.loadNode(layer, layer._inputs, layer._outputs, model_weights, group);
|
|
|
+ this.loadNode(layer, layer._inputs, layer._outputs, model_weights, weightsManifest, group);
|
|
|
}
|
|
|
});
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- loadSequential(config, model_weights, group) {
|
|
|
+ loadSequential(config, model_weights, weightsManifest, group) {
|
|
|
if (group) {
|
|
|
this._groups = true;
|
|
|
}
|
|
|
@@ -267,7 +310,7 @@ class KerasGraph {
|
|
|
}
|
|
|
connection = name;
|
|
|
var outputs = [ connection ];
|
|
|
- this.loadNode(layer, inputs, outputs, model_weights, group);
|
|
|
+ this.loadNode(layer, inputs, outputs, model_weights, weightsManifest, group);
|
|
|
});
|
|
|
this._outputs.push({
|
|
|
id: connection,
|
|
|
@@ -276,24 +319,15 @@ class KerasGraph {
|
|
|
});
|
|
|
}
|
|
|
|
|
|
- loadNode(layer, inputs, outputs, model_weights, group) {
|
|
|
+ loadNode(layer, inputs, outputs, model_weights, weightsManifest, group) {
|
|
|
var class_name = layer.class_name;
|
|
|
switch (class_name) {
|
|
|
case 'Model':
|
|
|
- this.loadModel(layer.config, model_weights, layer.name, inputs, outputs);
|
|
|
+ this.loadModel(layer.config, model_weights, weightsManifest, layer.name, inputs, outputs);
|
|
|
break;
|
|
|
default:
|
|
|
var config = layer.config;
|
|
|
- var weights = null;
|
|
|
- if (model_weights) {
|
|
|
- if (group) {
|
|
|
- weights = model_weights.group(group);
|
|
|
- }
|
|
|
- else if (config) {
|
|
|
- weights = model_weights.group(config.name);
|
|
|
- }
|
|
|
- }
|
|
|
- this._nodes.push(new KerasNode(class_name, config, inputs, outputs, group, weights));
|
|
|
+ this._nodes.push(new KerasNode(class_name, config, inputs, outputs, group, model_weights, weightsManifest));
|
|
|
break;
|
|
|
}
|
|
|
}
|
|
|
@@ -318,7 +352,7 @@ class KerasGraph {
|
|
|
|
|
|
class KerasNode {
|
|
|
|
|
|
- constructor(operator, config, inputs, outputs, group, weights) {
|
|
|
+ constructor(operator, config, inputs, outputs, group, model_weights, weightsManifest) {
|
|
|
if (group) {
|
|
|
this._group = group;
|
|
|
}
|
|
|
@@ -336,23 +370,44 @@ class KerasNode {
|
|
|
|
|
|
var name = this.name;
|
|
|
this._initializers = {};
|
|
|
- if (weights) {
|
|
|
- var weight_names = weights.attributes.weight_names;
|
|
|
- if (weight_names) {
|
|
|
- if (group) {
|
|
|
- weight_names = weight_names.filter(weight => weight.startsWith(name + '/'));
|
|
|
- }
|
|
|
- weight_names.forEach((weight_name) => {
|
|
|
- var weight_variable = weights.group(weight_name);
|
|
|
- if (weight_variable) {
|
|
|
- var variable = weight_variable.value;
|
|
|
- if (variable) {
|
|
|
- this._inputs.push(weight_name);
|
|
|
- this._initializers[weight_name] = new KerasTensor(variable);
|
|
|
+
|
|
|
+ if (model_weights) {
|
|
|
+ var weights = null;
|
|
|
+ if (group) {
|
|
|
+ weights = model_weights.group(group);
|
|
|
+ }
|
|
|
+ else if (config) {
|
|
|
+ weights = model_weights.group(config.name);
|
|
|
+ }
|
|
|
+ if (weights) {
|
|
|
+ var weight_names = weights.attributes.weight_names;
|
|
|
+ if (weight_names) {
|
|
|
+ if (group) {
|
|
|
+ weight_names = weight_names.filter(weight => weight.startsWith(name + '/'));
|
|
|
+ }
|
|
|
+ weight_names.forEach((weight_name) => {
|
|
|
+ var weight_variable = weights.group(weight_name);
|
|
|
+ if (weight_variable) {
|
|
|
+ var variable = weight_variable.value;
|
|
|
+ if (variable) {
|
|
|
+ this._inputs.push(weight_name);
|
|
|
+ this._initializers[weight_name] = new KerasTensor(variable.type, variable.shape, variable.rawData, '');
|
|
|
+ }
|
|
|
}
|
|
|
+ });
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ else if (weightsManifest) {
|
|
|
+ var manifest = weightsManifest[name];
|
|
|
+ if (manifest) {
|
|
|
+ manifest.weights.forEach((weights) => {
|
|
|
+ if (weights.name) {
|
|
|
+ this._inputs.push(weights.name);
|
|
|
+ this._initializers[weights.name] = new KerasTensor(weights.dtype, weights.shape, null, manifest.paths.join(';'));
|
|
|
}
|
|
|
});
|
|
|
- }
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
|
|
|
@@ -472,22 +527,35 @@ class KerasAttribute {
|
|
|
|
|
|
class KerasTensor {
|
|
|
|
|
|
- constructor(variable) {
|
|
|
- this._variable = variable;
|
|
|
+ constructor(type, shape, data, reference) {
|
|
|
+ this._type = type;
|
|
|
+ this._shape = shape;
|
|
|
+ this._data = data;
|
|
|
+ this._reference = reference;
|
|
|
}
|
|
|
|
|
|
get kind() {
|
|
|
- return 'Initializer';
|
|
|
+ return 'Weights';
|
|
|
+ }
|
|
|
+
|
|
|
+ get name() {
|
|
|
+ return this._name;
|
|
|
}
|
|
|
|
|
|
get type() {
|
|
|
- return this._variable.type + JSON.stringify(this._variable.shape);
|
|
|
+ return this._type + JSON.stringify(this._shape);
|
|
|
+ }
|
|
|
+
|
|
|
+ get reference() {
|
|
|
+ return this._reference;
|
|
|
}
|
|
|
|
|
|
get value() {
|
|
|
- var rawData = this._variable.rawData;
|
|
|
- if (rawData) {
|
|
|
- switch (this._variable.type) {
|
|
|
+ if (this._reference) {
|
|
|
+ return null;
|
|
|
+ }
|
|
|
+ if (this._data) {
|
|
|
+ switch (this._type) {
|
|
|
case 'float16':
|
|
|
this._precision = 16;
|
|
|
break;
|
|
|
@@ -500,15 +568,13 @@ class KerasTensor {
|
|
|
default:
|
|
|
return 'Tensor data type is not supported.';
|
|
|
}
|
|
|
- this._shape = this._variable.shape;
|
|
|
- this._rawData = new DataView(rawData.buffer, rawData.byteOffset, rawData.byteLength);
|
|
|
+ this._rawData = new DataView(this._data.buffer, this._data.byteOffset, this._data.byteLength);
|
|
|
this._index = 0;
|
|
|
this._count = 0;
|
|
|
var result = this.read(0);
|
|
|
delete this._index;
|
|
|
delete this._count;
|
|
|
delete this._rawData;
|
|
|
- delete this._shape;
|
|
|
delete this._precision;
|
|
|
return JSON.stringify(result, null, 4);
|
|
|
}
|