|
|
@@ -1,66 +1,66 @@
|
|
|
/*jshint esversion: 6 */
|
|
|
|
|
|
-class KerasModel {
|
|
|
+class KerasModelFactory {
|
|
|
+
|
|
|
+ match(buffer, identifier) {
|
|
|
+ var extension = identifier.split('.').pop();
|
|
|
+ return (extension == 'keras' || extension == 'h5' || extension == 'json');
|
|
|
+ }
|
|
|
|
|
|
- static open(buffer, identifier, host, callback) {
|
|
|
+ open(buffer, identifier, host, callback) {
|
|
|
host.import('/hdf5.js', (err) => {
|
|
|
if (err) {
|
|
|
callback(err, null);
|
|
|
}
|
|
|
else {
|
|
|
- KerasModel.create(buffer, identifier, host, (err, model) => {
|
|
|
- callback(err, model);
|
|
|
- });
|
|
|
- }
|
|
|
- });
|
|
|
- }
|
|
|
-
|
|
|
- static create(buffer, identifier, host, callback) {
|
|
|
- try {
|
|
|
- var format = 'Keras';
|
|
|
- var rootGroup = null;
|
|
|
- var rootJson = null;
|
|
|
- var model_config = null;
|
|
|
-
|
|
|
- var extension = identifier.split('.').pop();
|
|
|
- if (extension == 'keras' || extension == 'h5') {
|
|
|
- var file = new hdf5.File(buffer);
|
|
|
- rootGroup = file.rootGroup;
|
|
|
- var modelConfigJson = rootGroup.attributes.model_config;
|
|
|
- if (!modelConfigJson) {
|
|
|
- throw new KerasError('HDF5 file does not contain a \'model_config\' graph. Use \'save()\' instead of \'save_weights()\' to save both the graph and weights.');
|
|
|
+ try {
|
|
|
+ var format = 'Keras';
|
|
|
+ var rootGroup = null;
|
|
|
+ var rootJson = null;
|
|
|
+ var model_config = null;
|
|
|
+ var extension = identifier.split('.').pop();
|
|
|
+ if (extension == 'keras' || extension == 'h5') {
|
|
|
+ var file = new hdf5.File(buffer);
|
|
|
+ rootGroup = file.rootGroup;
|
|
|
+ var modelConfigJson = rootGroup.attributes.model_config;
|
|
|
+ if (!modelConfigJson) {
|
|
|
+ callback(new KerasError('HDF5 file does not contain a \'model_config\' graph. Use \'save()\' instead of \'save_weights()\' to save both the graph and weights.'), null);
|
|
|
+ return;
|
|
|
+ }
|
|
|
+ model_config = JSON.parse(modelConfigJson);
|
|
|
+ }
|
|
|
+ else if (extension == 'json') {
|
|
|
+ var decoder = new window.TextDecoder('utf-8');
|
|
|
+ var json = decoder.decode(buffer);
|
|
|
+ model_config = JSON.parse(json);
|
|
|
+ if (model_config && model_config.modelTopology && model_config.modelTopology.model_config) {
|
|
|
+ format = 'TensorFlow.js ' + format;
|
|
|
+ rootJson = model_config;
|
|
|
+ model_config = model_config.modelTopology.model_config;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if (!model_config) {
|
|
|
+ callback(new KerasError('\'model_config\' is not present.'));
|
|
|
+ }
|
|
|
+ else if (!model_config.class_name) {
|
|
|
+ callback(new KerasError('\'class_name\' is not present.'), null);
|
|
|
+ }
|
|
|
+ else {
|
|
|
+ var model = new KerasModel(format, model_config, rootGroup, rootJson);
|
|
|
+ KerasOperatorMetadata.open(host, (err, metadata) => {
|
|
|
+ callback(null, model);
|
|
|
+ });
|
|
|
+ }
|
|
|
}
|
|
|
- model_config = JSON.parse(modelConfigJson);
|
|
|
- }
|
|
|
- else if (extension == 'json') {
|
|
|
- var decoder = new window.TextDecoder('utf-8');
|
|
|
- var json = decoder.decode(buffer);
|
|
|
- model_config = JSON.parse(json);
|
|
|
- if (model_config && model_config.modelTopology && model_config.modelTopology.model_config) {
|
|
|
- format = 'TensorFlow.js ' + format;
|
|
|
- rootJson = model_config;
|
|
|
- model_config = model_config.modelTopology.model_config;
|
|
|
+ catch (error) {
|
|
|
+ callback(new KerasError(error.message), null);
|
|
|
}
|
|
|
}
|
|
|
-
|
|
|
- if (!model_config) {
|
|
|
- throw new KerasError('model_config is not present.');
|
|
|
- }
|
|
|
-
|
|
|
- if (!model_config.class_name) {
|
|
|
- throw new KerasError('class_name is not present.');
|
|
|
- }
|
|
|
-
|
|
|
- var model = new KerasModel(format, model_config, rootGroup, rootJson);
|
|
|
-
|
|
|
- KerasOperatorMetadata.open(host, (err, metadata) => {
|
|
|
- callback(null, model);
|
|
|
- });
|
|
|
- }
|
|
|
- catch (err) {
|
|
|
- callback(err, null);
|
|
|
- }
|
|
|
+ });
|
|
|
}
|
|
|
+}
|
|
|
+
|
|
|
+class KerasModel {
|
|
|
|
|
|
constructor(format, model_config, rootGroup, rootJson) {
|
|
|
this._format = format;
|