Bläddra i källkod

Move model creation to ModelFactory

Lutz Roeder 8 år sedan
förälder
incheckning
d696d5345d
9 ändrade filer med 265 tillägg och 261 borttagningar
  1. 26 17
      src/caffe-model.js
  2. 26 17
      src/caffe2-model.js
  3. 26 20
      src/coreml-model.js
  4. 51 51
      src/keras-model.js
  5. 9 8
      src/mxnet-model.js
  6. 19 17
      src/onnx-model.js
  7. 52 49
      src/tf-model.js
  8. 28 23
      src/tflite-model.js
  9. 28 59
      src/view.js

+ 26 - 17
src/caffe-model.js

@@ -4,34 +4,36 @@
 
 var caffe = null;
 
-class CaffeModel {
+class CaffeModelFactory {
+
+    match(buffer, identifier) {
+        var extension = identifier.split('.').pop();
+        return extension == 'caffemodel';
+    }
 
-    static open(buffer, identifier, host, callback) { 
+    open(buffer, identifier, host, callback) { 
         host.import('/caffe.js', (err) => {
             if (err) {
                 callback(err, null);
             }
             else {
                 caffe = protobuf.roots.caffe.caffe;
-                CaffeModel.create(buffer, identifier, host, (err, model) => {
-                    callback(err, model);
-                });
+                try {
+                    var netParameter = caffe.NetParameter.decode(buffer);
+                    var model = new CaffeModel(netParameter);
+                    CaffeOperatorMetadata.open(host, (err, metadata) => {
+                        callback(null, model);
+                    });
+                }
+                catch (error) {
+                    callback(new CaffeError(error.message), null);
+                }
             }
         });
     }
+}
 
-    static create(buffer, identifier, host, callback) {
-        try {
-            var netParameter = caffe.NetParameter.decode(buffer);
-            var model = new CaffeModel(netParameter);
-            CaffeOperatorMetadata.open(host, (err, metadata) => {
-                callback(null, model);
-            });
-        }
-        catch (err) {
-            callback(err, null);
-        }
-    }
+class CaffeModel {
 
     constructor(netParameter) {
         if (netParameter.layers && netParameter.layers.length > 0) {
@@ -561,4 +563,11 @@ class CaffeOperatorMetadata
         }
         return true;
     }
+}
+
+class CaffeError extends Error {
+    constructor(message) {
+        super(message);
+        this.name = 'Caffe Error';
+    }
 }

+ 26 - 17
src/caffe2-model.js

@@ -4,34 +4,36 @@
 
 var caffe2 = null;
 
-class Caffe2Model {
+class Caffe2ModelFactory {
+
+    match(buffer, identifier) {
+        return identifier.endsWith('predict_net.pb');
+    }    
 
-    static open(buffer, identifier, host, callback) { 
+    open(buffer, identifier, host, callback) {
         host.import('/caffe2.js', (err) => {
             if (err) {
                 callback(err, null);
             }
             else {
                 caffe2 = protobuf.roots.caffe2.caffe2;
-                Caffe2Model.create(buffer, identifier, host, (err, model) => {
-                    callback(err, model);
-                });
+                try {
+                    var netDef = caffe2.NetDef.decode(buffer);
+                    var model = new Caffe2Model(netDef);
+                    Caffe2OperatorMetadata.open(host, (err, metadata) => {
+                        callback(null, model);
+                    });
+                }
+                catch (error) {
+                    callback(new Caffe2Error(error.message), null);
+                }
             }
         });
     }
 
-    static create(buffer, identifier, host, callback) {
-        try {
-            var netDef = caffe2.NetDef.decode(buffer);
-            var model = new Caffe2Model(netDef);
-            Caffe2OperatorMetadata.open(host, (err, metadata) => {
-                callback(null, model);
-            });
-        }
-        catch (err) {
-            callback(err, null);
-        }
-    }
+}
+
+class Caffe2Model {
 
     constructor(netDef) {
         var graph = new Caffe2Graph(netDef);
@@ -471,3 +473,10 @@ class Caffe2OperatorMetadata
         return true;
     }
 }
+
+class Caffe2Error extends Error {
+    constructor(message) {
+        super(message);
+        this.name = 'Caffe2 Error';
+    }
+}

+ 26 - 20
src/coreml-model.js

@@ -1,37 +1,37 @@
 /*jshint esversion: 6 */
 
-// Experimental
-
 var coreml = null;
 
-class CoreMLModel {
+class CoreMLModelFactory {
+
+    match(buffer, identifier) {
+        var extension = identifier.split('.').pop();
+        return extension == 'mlmodel';
+    }
 
-    static open(buffer, identifier, host, callback) { 
+    open(buffer, identifier, host, callback) { 
         host.import('/coreml.js', (err) => {
             if (err) {
                 callback(err, null);
             }
             else {
                 coreml = protobuf.roots.coreml.CoreML.Specification;
-                CoreMLModel.create(buffer, identifier, host, (err, model) => {
-                    callback(err, model);
-                });
+                try {
+                    var decodedBuffer = coreml.Model.decode(buffer);
+                    var model = new CoreMLModel(decodedBuffer);
+                    CoreMLOperatorMetadata.open(host, (err, metadata) => {
+                        callback(null, model);
+                    });
+                }
+                catch (error) {
+                    callback(new CoreMLError(error.message), null);
+                }
             }
         });
     }
+}
 
-    static create(buffer, identifier, host, callback) {
-        try {
-            var decodedBuffer = coreml.Model.decode(buffer);
-            var model = new CoreMLModel(decodedBuffer);
-            CoreMLOperatorMetadata.open(host, (err, metadata) => {
-                callback(null, model);
-            });
-        }
-        catch (err) {
-            callback(err, null);
-        }
-    }
+class CoreMLModel {
 
     constructor(model) {
         this._specificationVersion = model.specificationVersion;
@@ -849,5 +849,11 @@ class CoreMLOperatorMetadata
         }
         return '';
     }
-
 }
+
+class CoreMLError extends Error {
+    constructor(message) {
+        super(message);
+        this.name = 'CoreML Error';
+    }
+}

+ 51 - 51
src/keras-model.js

@@ -1,66 +1,66 @@
 /*jshint esversion: 6 */
 
-class KerasModel {
+class KerasModelFactory {
+
+    match(buffer, identifier) {
+        var extension = identifier.split('.').pop();
+        return (extension == 'keras' || extension == 'h5' || extension == 'json');
+    }
 
-    static open(buffer, identifier, host, callback) { 
+    open(buffer, identifier, host, callback) {
         host.import('/hdf5.js', (err) => {
             if (err) {
                 callback(err, null);
             }
             else {
-                KerasModel.create(buffer, identifier, host, (err, model) => {
-                    callback(err, model);
-                });
-            }
-        });
-    }
-
-    static create(buffer, identifier, host, callback) {
-        try {
-            var format = 'Keras';
-            var rootGroup = null;
-            var rootJson = null;
-            var model_config = null;
-
-            var extension = identifier.split('.').pop();
-            if (extension == 'keras' || extension == 'h5') {
-                var file = new hdf5.File(buffer);
-                rootGroup = file.rootGroup;
-                var modelConfigJson = rootGroup.attributes.model_config;
-                if (!modelConfigJson) {
-                    throw new KerasError('HDF5 file does not contain a \'model_config\' graph. Use \'save()\' instead of \'save_weights()\' to save both the graph and weights.');
+                try {
+                    var format = 'Keras';
+                    var rootGroup = null;
+                    var rootJson = null;
+                    var model_config = null;
+                    var extension = identifier.split('.').pop();
+                    if (extension == 'keras' || extension == 'h5') {
+                        var file = new hdf5.File(buffer);
+                        rootGroup = file.rootGroup;
+                        var modelConfigJson = rootGroup.attributes.model_config;
+                        if (!modelConfigJson) {
+                            callback(new KerasError('HDF5 file does not contain a \'model_config\' graph. Use \'save()\' instead of \'save_weights()\' to save both the graph and weights.'), null);
+                            return;
+                        }
+                        model_config = JSON.parse(modelConfigJson);
+                    }
+                    else if (extension == 'json') {
+                        var decoder = new window.TextDecoder('utf-8');
+                        var json = decoder.decode(buffer);
+                        model_config = JSON.parse(json);
+                        if (model_config && model_config.modelTopology && model_config.modelTopology.model_config) {
+                            format = 'TensorFlow.js ' + format;
+                            rootJson = model_config;
+                            model_config = model_config.modelTopology.model_config;
+                        }
+                    }
+                    if (!model_config) {
+                        callback(new KerasError('\'model_config\' is not present.'));
+                    }
+                    else if (!model_config.class_name) {
+                        callback(new KerasError('\'class_name\' is not present.'), null);
+                    }
+                    else {
+                        var model = new KerasModel(format, model_config, rootGroup, rootJson);
+                        KerasOperatorMetadata.open(host, (err, metadata) => {
+                            callback(null, model);
+                        });
+                    }
                 }
-                model_config = JSON.parse(modelConfigJson);
-            }
-            else if (extension == 'json') {
-                var decoder = new window.TextDecoder('utf-8');
-                var json = decoder.decode(buffer);
-                model_config = JSON.parse(json);
-                if (model_config && model_config.modelTopology && model_config.modelTopology.model_config) {
-                    format = 'TensorFlow.js ' + format;
-                    rootJson = model_config;
-                    model_config = model_config.modelTopology.model_config;
+                catch (error) {
+                    callback(new KerasError(error.message), null);
                 }
             }
-
-            if (!model_config) {
-                throw new KerasError('model_config is not present.');
-            }
-
-            if (!model_config.class_name) {
-                throw new KerasError('class_name is not present.');
-            }
-    
-            var model = new KerasModel(format, model_config, rootGroup, rootJson);
-
-            KerasOperatorMetadata.open(host, (err, metadata) => {
-                callback(null, model);
-            });
-        }
-        catch (err) {
-            callback(err, null);
-        }
+        });
     }
+}
+
+class KerasModel {
 
     constructor(format, model_config, rootGroup, rootJson) {
         this._format = format;

+ 9 - 8
src/mxnet-model.js

@@ -2,29 +2,30 @@
 
 // Experimental
 
-class MXNetModel {
+class MXNetModelFactory {
 
-    static open(buffer, identifier, host, callback) { 
-        MXNetModel.create(buffer, identifier, host, (err, model) => {
-            callback(err, model);
-        });
+    match(buffer, identifier) {
+        return identifier.endsWith('-symbol.json');
     }
 
-    static create(buffer, identifier, host, callback) {
+    open(buffer, identifier, host, callback) {
         try {
             var decoder = new TextDecoder('utf-8');
             var json = decoder.decode(buffer);
-
             var model = new MXNetModel(json);
             MXNetOperatorMetadata.open(host, (err, metadata) => {
                 callback(null, model);
             });
         }
         catch (err) {
-            callback(err, null);
+            callback(new MXNetError(err.message), null);
         }
     }
 
+}
+
+class MXNetModel {
+
     constructor(json) {
         var model = JSON.parse(json);
         if (!model) {

+ 19 - 17
src/onnx-model.js

@@ -2,34 +2,36 @@
 
 var onnx = null;
 
-class OnnxModel {
+class OnnxModelFactory {
+
+    match(buffer, identifier) {
+        var extension = identifier.split('.').pop();
+        return (identifier != 'saved_model.pb') && (identifier != 'predict_net.pb') && (extension == 'onnx' || extension == 'pb');
+    }
 
-    static open(buffer, identifier, host, callback) { 
+    open(buffer, identifier, host, callback) { 
         host.import('/onnx.js', (err) => {
             if (err) {
                 callback(err, null);
             }
             else {
                 onnx = protobuf.roots.onnx.onnx;
-                OnnxModel.create(buffer, host, (err, model) => {
-                    callback(err, model);
-                });
+                try {
+                    var model = onnx.ModelProto.decode(buffer);
+                    var result = new OnnxModel(model);
+                    OnnxOperatorMetadata.open(host, (err, metadata) => {
+                        callback(null, result);
+                    });
+                }
+                catch (error) {
+                    callback(new OnnxError(error.message), null);
+                }
             }
         });
     }
+}
 
-    static create(buffer, host, callback) {
-        try {
-            var model = onnx.ModelProto.decode(buffer);
-            var result = new OnnxModel(model);
-            OnnxOperatorMetadata.open(host, (err, metadata) => {
-                callback(null, result);
-            });
-        }
-        catch (err) {
-            callback(err, null);
-        }
-    }
+class OnnxModel {
 
     constructor(model) {
         this._model = model;

+ 52 - 49
src/tf-model.js

@@ -4,67 +4,69 @@
 
 var tensorflow = null;
 
-class TensorFlowModel {
+class TensorFlowModelFactory {
+
+    match(buffer, identifier) {
+        var extension = identifier.split('.').pop();
+        return (identifier != 'predict_net.pb') && (extension == 'pb' || extension == 'meta');
+    }
 
-    static open(buffer, identifier, host, callback) { 
+    open(buffer, identifier, host, callback) { 
         host.import('/tf.js', (err) => {
             if (err) {
                 callback(err, null);
             }
             else {
                 tensorflow = protobuf.roots.tf.tensorflow;
-                var model = TensorFlowModel.create(buffer, identifier, host, (err, model) => {
-                    callback(err, model);
-                });
-            }
-        });
-    }
-
-    static create(buffer, identifier, host, callback) {
-        try {
-            var model = null;
-            var format = null;
-            if (identifier == 'saved_model.pb') {
-                model = tensorflow.SavedModel.decode(buffer);
-                format = 'TensorFlow Saved Model';
-                if (model.savedModelSchemaVersion) {
-                    format = format + ' v' + model.savedModelSchemaVersion.toString();
-                }
-            }
-            else {
-                var metaGraphDef = null;
-                var extension = identifier.split('.').pop();
-                if (extension != 'meta') {
-                    try {
-                        var graphDef = tensorflow.GraphDef.decode(buffer);
-                        metaGraphDef = new tensorflow.MetaGraphDef();
-                        metaGraphDef.graphDef = graphDef;
-                        metaGraphDef.anyInfo = identifier;
-                        format = 'TensorFlow Graph';
+                try {
+                    var model = null;
+                    var format = null;
+                    if (identifier == 'saved_model.pb') {
+                        model = tensorflow.SavedModel.decode(buffer);
+                        format = 'TensorFlow Saved Model';
+                        if (model.savedModelSchemaVersion) {
+                            format = format + ' v' + model.savedModelSchemaVersion.toString();
+                        }
                     }
-                    catch (err) {
+                    else {
+                        var metaGraphDef = null;
+                        var extension = identifier.split('.').pop();
+                        if (extension != 'meta') {
+                            try {
+                                var graphDef = tensorflow.GraphDef.decode(buffer);
+                                metaGraphDef = new tensorflow.MetaGraphDef();
+                                metaGraphDef.graphDef = graphDef;
+                                metaGraphDef.anyInfo = identifier;
+                                format = 'TensorFlow Graph';
+                            }
+                            catch (metaError) {
+                            }
+                        }
+        
+                        if (!metaGraphDef) {
+                            metaGraphDef = tensorflow.MetaGraphDef.decode(buffer);
+                            format = 'TensorFlow MetaGraph';
+                        }
+        
+                        model = new tensorflow.SavedModel();
+                        model.metaGraphs.push(metaGraphDef);
                     }
+        
+                    var result = new TensorFlowModel(model, format);
+        
+                    TensorFlowOperatorMetadata.open(host, (err, metadata) => {
+                        callback(null, result);
+                    });
                 }
-
-                if (!metaGraphDef) {
-                    metaGraphDef = tensorflow.MetaGraphDef.decode(buffer);
-                    format = 'TensorFlow MetaGraph';
-                }
-
-                model = new tensorflow.SavedModel();
-                model.metaGraphs.push(metaGraphDef);
+                catch (error) {
+                    callback(new TensorFlowError(error.message), null);
+                }    
             }
-
-            var result = new TensorFlowModel(model, format);
-
-            TensorFlowOperatorMetadata.open(host, (err, metadata) => {
-                callback(null, result);
-            });
-        }
-        catch (err) {
-            callback(err, null);
-        }    
+        });
     }
+}
+
+class TensorFlowModel {
 
     constructor(model, format) {
         this._model = model;
@@ -821,6 +823,7 @@ class TensorFlowGraphOperatorMetadata {
             'Dequantize': 'Tensor',
             'Identity': 'Control',
             'BatchNormWithGlobalNormalization': 'Normalization',
+            'FusedBatchNorm': 'Normalization',
             // 'VariableV2':
             // 'Assign':
             // 'BiasAdd':

+ 28 - 23
src/tflite-model.js

@@ -1,37 +1,42 @@
 /*jshint esversion: 6 */
 
-class TensorFlowLiteModel {
-    
-    static open(buffer, identifier, host, callback) { 
+class TensorFlowLiteModelFactory {
+
+
+    match(buffer, identifier) {
+        var extension = identifier.split('.').pop();
+        return extension == 'tflite';
+    }
+
+    open(buffer, identifier, host, callback) {
         host.import('/tflite.js', (err) => {
             if (err) {
                 callback(err, null);
             }
             else {
-                TensorFlowLiteModel.create(buffer, identifier, host, (err, model) => {
-                    callback(err, model);
-                });
+                try {
+                    var byteBuffer = new flatbuffers.ByteBuffer(buffer);
+                    if (!tflite.Model.bufferHasIdentifier(byteBuffer))
+                    {
+                        callback(new TensorFlowLiteError('Invalid FlatBuffers identifier.'));
+                    }
+                    else {
+                        var model = tflite.Model.getRootAsModel(byteBuffer);
+                        model = new TensorFlowLiteModel(model);
+                        TensorFlowLiteOperatorMetadata.open(host, (err, metadata) => {
+                            callback(null, model);
+                        });
+                    }
+                }
+                catch (error) {
+                    callback(new TensorFlowLiteError(error.message), null);
+                }
             }
         });
     }
+}
 
-    static create(buffer, identifier, host, callback) { 
-        try {
-            var byteBuffer = new flatbuffers.ByteBuffer(buffer);
-            if (!tflite.Model.bufferHasIdentifier(byteBuffer))
-            {
-                throw new TensorFlowLiteError('Invalid FlatBuffers identifier.');
-            }
-            var model = tflite.Model.getRootAsModel(byteBuffer);
-            model = new TensorFlowLiteModel(model);
-            TensorFlowLiteOperatorMetadata.open(host, (err, metadata) => {
-                callback(null, model);
-            });
-        }
-        catch (err) {
-            callback(err, null);
-        }
-    }
+class TensorFlowLiteModel {
 
     constructor(model) {
         this._model = model;

+ 28 - 59
src/view.js

@@ -69,66 +69,35 @@ class View {
     }
 
     loadBuffer(buffer, identifier, callback) {
-        var model = null;
-        var err = null;
-    
-        var extension = identifier.split('.').pop();
-
-        if (extension == 'tflite') {
-            TensorFlowLiteModel.open(buffer, identifier, this._host, (err, model) => {
-                callback(err, model);
-           });
-        }
-        else if (extension == 'onnx') {
-            OnnxModel.open(buffer, identifier, this._host, (err, model) => {
-                callback(err, model);
-            });
-        }
-        else if (extension == 'mlmodel') {
-            CoreMLModel.open(buffer, identifier, this._host, (err, model) => {
-                callback(err, model);
-            });
-        }
-        else if (extension == 'caffemodel') {
-            CaffeModel.open(buffer, identifier, this._host, (err, model) => {
-                callback(err, model);
-            });
-        }
-        else if (identifier.endsWith('predict_net.pb')) {
-            Caffe2Model.open(buffer, identifier, this._host, (err, model) => {
-                callback(err, model);
-            });
-        }
-        else if (identifier.endsWith('-symbol.json')) {
-            MXNetModel.open(buffer, identifier, this._host, (err, model) => {
-                callback(err, model);
-            });
-        }
-        else if (extension == 'keras' || extension == 'h5' || extension == 'json') {
-            KerasModel.open(buffer, identifier, this._host, (err, model) => {
-                callback(err, model);
-            });
-        }
-        else if (identifier == 'saved_model.pb' || extension == 'meta') {
-            TensorFlowModel.open(buffer, identifier, this._host, (err, model) => {
-                callback(err, model);
-            });
-        }
-        else if (extension == 'pb') {
-            OnnxModel.open(buffer, identifier, this._host, (err, model) => {
-                if (!err) {
-                    callback(err, model);
-                }
-                else {
-                    TensorFlowModel.open(buffer, identifier, this._host, (err, model) => {
+        var modelFactoryRegistry = [
+            new OnnxModelFactory(),
+            new MXNetModelFactory(),
+            new KerasModelFactory(),
+            new CoreMLModelFactory(),
+            new CaffeModelFactory(),
+            new Caffe2ModelFactory(), 
+            new TensorFlowLiteModelFactory(),
+            new TensorFlowModelFactory()
+        ];
+        var matches = modelFactoryRegistry.filter((factory) => factory.match(buffer, identifier));
+        var next = () => {
+            if (matches.length > 0) {
+                var modelFactory = matches.shift();
+                modelFactory.open(buffer, identifier, this._host, (err, model) => {
+                    if (model || matches.length == 0) {
                         callback(err, model);
-                    });
-                }
-            });
-        }
-        else {
-            callback(new Error('Unsupported file extension \'.' + extension + '\'.'), null);
-        }
+                    }
+                    else {
+                        next();
+                    }
+                });
+            }
+            else {
+                var extension = identifier.split('.').pop();
+                callback(new Error('Unsupported file extension \'.' + extension + '\'.'), null);
+            }
+        };
+        next();
     }
 
     openBuffer(buffer, identifier, callback) {