Browse Source

Add Caffe2 module exports

Lutz Roeder 7 years ago
parent
commit
f99e46aa11
5 changed files with 64 additions and 61 deletions
  1. 1 1
      setup.py
  2. 60 57
      src/caffe2.js
  3. 1 1
      src/view-browser.html
  4. 1 1
      src/view-electron.html
  5. 1 1
      src/view.js

+ 1 - 1
setup.py

@@ -91,7 +91,7 @@ setuptools.setup(
             'keras.js', 'keras-metadata.json', 'hdf5.js',
             'coreml.js', 'coreml-metadata.json', 'coreml-proto.js',
             'caffe.js', 'caffe-metadata.json', 'caffe-proto.js',
-            'caffe2-model.js', 'caffe2-metadata.json', 'caffe2-proto.js',
+            'caffe2.js', 'caffe2-metadata.json', 'caffe2-proto.js',
             'mxnet.js', 'mxnet-metadata.json',
             'cntk-model.js', 'cntk-metadata.json', 'cntk-proto.js',
             'pytorch.js', 'pytorch-metadata.json', 'pickle.js',

+ 60 - 57
src/caffe2-model.js → src/caffe2.js

@@ -1,8 +1,8 @@
 /*jshint esversion: 6 */
 
-var caffe2 = null;
+var caffe2 = caffe2 || {};
 
-class Caffe2ModelFactory {
+caffe2.ModelFactory = class {
 
     match(context, host) {
         var identifier = context.identifier;
@@ -28,31 +28,31 @@ class Caffe2ModelFactory {
             var extension = context.identifier.split('.').pop();
             if (extension == 'pbtxt' || extension == 'prototxt') {
                 try {
-                    caffe2 = protobuf.roots.caffe2.caffe2;
-                    netDef = caffe2.NetDef.decodeText(context.text);
+                    caffe2.proto = protobuf.roots.caffe2.caffe2;
+                    netDef = caffe2.proto.NetDef.decodeText(context.text);
                 }
                 catch (error) {
                     host.exception(error, false);
-                    callback(new Caffe2Error("File text format is not caffe2.NetDef (" + error.message + ") in '" + identifier + "'."), null);
+                    callback(new caffe2.Error("File text format is not caffe2.NetDef (" + error.message + ") in '" + identifier + "'."), null);
                     return;
                 }    
             }
             else {
                 try {
-                    caffe2 = protobuf.roots.caffe2.caffe2;
-                    netDef = caffe2.NetDef.decode(context.buffer);
+                    caffe2.proto = protobuf.roots.caffe2.caffe2;
+                    netDef = caffe2.proto.NetDef.decode(context.buffer);
                 }
                 catch (error) {
-                    callback(new Caffe2Error("File format is not caffe2.NetDef (" + error.message + ") in '" + identifier + "'."), null);
+                    callback(new caffe2.Error("File format is not caffe2.NetDef (" + error.message + ") in '" + identifier + "'."), null);
                     return;
                 }    
             }
-            Caffe2OperatorMetadata.open(host, (err, metadata) => {
+            caffe2.OperatorMetadata.open(host, (err, metadata) => {
                 context.request('init_net.pb', null, (err, data) => {
                     var init = null;
                     if (!err && data) {
                         try {
-                            init = caffe2.NetDef.decode(data);
+                            init = caffe2.proto.NetDef.decode(data);
                         }
                         catch (error) {
                         }
@@ -60,11 +60,11 @@ class Caffe2ModelFactory {
 
                     var model = null;
                     try {
-                        model = new Caffe2Model(netDef, init);
+                        model = new caffe2.Model(netDef, init);
                     }
                     catch (error) {
                         host.exception(error, false);
-                        callback(new Caffe2Error(error.message), null);
+                        callback(new caffe2.Error(error.message), null);
                         return;
                     }
                     callback(null, model);
@@ -73,12 +73,12 @@ class Caffe2ModelFactory {
         });
     }
 
-}
+};
 
-class Caffe2Model {
+caffe2.Model = class {
 
     constructor(netDef, init) {
-        var graph = new Caffe2Graph(netDef, init);
+        var graph = new caffe2.Graph(netDef, init);
         this._graphs = [ graph ];
     }
 
@@ -89,9 +89,9 @@ class Caffe2Model {
     get graphs() {
         return this._graphs;
     }
-}
+};
 
-class Caffe2Graph {
+caffe2.Graph = class {
 
     constructor(netDef, init) {
         this._name = netDef.name ? netDef.name : '';
@@ -155,20 +155,20 @@ class Caffe2Graph {
 
         netDef.op.forEach((op) => {
             this._operators[op.type] = (this._operators[op.type] || 0) + 1;
-            this._nodes.push(new Caffe2Node(op, initializers));
+            this._nodes.push(new caffe2.Node(op, initializers));
         });
 
         this._inputs = [];
         var inputs = Object.keys(initializers);
         inputs.forEach((input) => {
             if (inputs.length == 1 || !input.startsWith('caffe.')) {
-                this._inputs.push(new Caffe2Argument(input, [ new Caffe2Connection(input, null, null) ]));
+                this._inputs.push(new caffe2.Argument(input, [ new caffe2.Connection(input, null, null) ]));
             }
         });
 
         this._outputs = [];
         netDef.external_output.forEach((output) => {
-            this._outputs.push(new Caffe2Argument(output, [ new Caffe2Connection(output, null, null) ]));
+            this._outputs.push(new caffe2.Argument(output, [ new caffe2.Connection(output, null, null) ]));
         });
     }
 
@@ -195,9 +195,9 @@ class Caffe2Graph {
     get operators() {
         return this._operators;
     }
-}
+};
 
-class Caffe2Argument {
+caffe2.Argument = class {
     constructor(name, connections) {
         this._name = name;
         this._connections = connections;
@@ -214,9 +214,9 @@ class Caffe2Argument {
     get connections() {
         return this._connections;
     }
-}
+};
 
-class Caffe2Connection {
+caffe2.Connection = class {
     constructor(id, type, initializer) {
         this._id = id;
         this._type = type || null;
@@ -237,9 +237,9 @@ class Caffe2Connection {
     get initializer() {
         return this._initializer;
     }
-}
+};
 
-class Caffe2Node {
+caffe2.Node = class {
 
     constructor(op, initializers) {
         if (op.name) {
@@ -254,7 +254,7 @@ class Caffe2Node {
 
         this._attributes = [];
         op.arg.forEach((arg) => {
-            this._attributes.push(new Caffe2Attribute(this, arg));
+            this._attributes.push(new caffe2.Attribute(this, arg));
         });
 
         this._initializers = {};
@@ -262,7 +262,7 @@ class Caffe2Node {
             if (index > 0) {
                 var tensor = initializers[input];
                 if (tensor) {
-                    this._initializers[input] = new Caffe2Tensor(input, tensor, 'Initializer');
+                    this._initializers[input] = new caffe2.Tensor(input, tensor, 'Initializer');
                     delete initializers[input];
                 }
             }
@@ -282,28 +282,28 @@ class Caffe2Node {
     }
 
     get category() {
-        var schema = Caffe2OperatorMetadata.operatorMetadata.getSchema(this._operator);
+        var schema = caffe2.OperatorMetadata.operatorMetadata.getSchema(this._operator);
         return (schema && schema.category) ? schema.category : null;
     }
 
     get documentation() {
-        return Caffe2OperatorMetadata.operatorMetadata.getOperatorDocumentation(this._operator);
+        return caffe2.OperatorMetadata.operatorMetadata.getOperatorDocumentation(this._operator);
     }
 
     get inputs() {
-        var inputs = Caffe2OperatorMetadata.operatorMetadata.getInputs(this._operator, this._inputs);
+        var inputs = caffe2.OperatorMetadata.operatorMetadata.getInputs(this._operator, this._inputs);
         return inputs.map((input) => {
-            return new Caffe2Argument(input.name, input.connections.map((connection) => {
-                return new Caffe2Connection(connection.id, null, this._initializers[connection.id]);
+            return new caffe2.Argument(input.name, input.connections.map((connection) => {
+                return new caffe2.Connection(connection.id, null, this._initializers[connection.id]);
             }));
         });
     }
 
     get outputs() {
-        var outputs = Caffe2OperatorMetadata.operatorMetadata.getOutputs(this._operator, this._outputs);
+        var outputs = caffe2.OperatorMetadata.operatorMetadata.getOutputs(this._operator, this._outputs);
         return outputs.map((output) => {
-            return new Caffe2Argument(output.name, output.connections.map((connection) => {
-                return new Caffe2Connection(connection.id, null, null);
+            return new caffe2.Argument(output.name, output.connections.map((connection) => {
+                return new caffe2.Connection(connection.id, null, null);
             }));
         });
     }
@@ -311,9 +311,9 @@ class Caffe2Node {
     get attributes() {
         return this._attributes;
     }
-}
+};
 
-class Caffe2Attribute {
+caffe2.Attribute = class {
 
     constructor(node, arg) {
         this._node = node;
@@ -337,7 +337,7 @@ class Caffe2Attribute {
             this._value = arg.i;
         }
 
-        var schema = Caffe2OperatorMetadata.operatorMetadata.getAttributeSchema(this._node.operator, this._name);
+        var schema = caffe2.OperatorMetadata.operatorMetadata.getAttributeSchema(this._node.operator, this._name);
         if (schema) {
             if (schema.hasOwnProperty('type')) {
                 this._type = schema.type;
@@ -368,9 +368,9 @@ class Caffe2Attribute {
     get visible() {
         return this._visible == false ? false : true;
     }
-}
+};
 
-class Caffe2Tensor {
+caffe2.Tensor = class {
 
     constructor(name, tensor, kind) {
         this._name = name;
@@ -389,7 +389,7 @@ class Caffe2Tensor {
         if (args.values) {
             this._values = args.values;
         }
-        this._type = new Caffe2TensorType(tensor.dataType, new Caffe2TensorShape(shape));
+        this._type = new caffe2.TensorType(tensor.dataType, new caffe2.TensorShape(shape));
     }
 
     get name() {
@@ -493,9 +493,9 @@ class Caffe2Tensor {
         }
         return results;
     }
-}
+};
 
-class Caffe2TensorType {
+caffe2.TensorType = class {
 
     constructor(dataType, shape) {
         this._dataType = dataType;
@@ -513,9 +513,9 @@ class Caffe2TensorType {
     toString() {
         return this.dataType + this._shape.toString();
     }
-}
+};
 
-class Caffe2TensorShape {
+caffe2.TensorShape = class {
 
     constructor(dimensions) {
         this._dimensions = dimensions;
@@ -528,20 +528,18 @@ class Caffe2TensorShape {
     toString() {
         return this._dimensions ? ('[' + this._dimensions.map((dimension) => dimension.toString()).join(',') + ']') : '';
     }
+};
 
-}
-
-class Caffe2OperatorMetadata 
-{
+caffe2.OperatorMetadata = class {
 
     static open(host, callback) {
-        if (Caffe2OperatorMetadata.operatorMetadata) {
-            callback(null, Caffe2OperatorMetadata.operatorMetadata);
+        if (caffe2.OperatorMetadata.operatorMetadata) {
+            callback(null, caffe2.OperatorMetadata.operatorMetadata);
         }
         else {
             host.request(null, 'caffe2-metadata.json', 'utf-8', (err, data) => {
-                Caffe2OperatorMetadata.operatorMetadata = new Caffe2OperatorMetadata(data);
-                callback(null, Caffe2OperatorMetadata.operatorMetadata);
+                caffe2.OperatorMetadata.operatorMetadata = new caffe2.OperatorMetadata(data);
+                callback(null, caffe2.OperatorMetadata.operatorMetadata);
             });
         }    
     }
@@ -689,11 +687,16 @@ class Caffe2OperatorMetadata
         }
         return null;
     }
-}
+};
+
+caffe2.Error = class extends Error {
 
-class Caffe2Error extends Error {
     constructor(message) {
         super(message);
         this.name = 'Error loading Caffe2 model.';
     }
-}
+};
+
+if (module && module.exports) {
+    module.exports.ModelFactory = caffe2.ModelFactory;
+}

+ 1 - 1
src/view-browser.html

@@ -135,7 +135,7 @@
 <script type='text/javascript' src='keras.js'></script>
 <script type='text/javascript' src='coreml.js'></script>
 <script type='text/javascript' src='caffe.js'></script>
-<script type='text/javascript' src='caffe2-model.js'></script>
+<script type='text/javascript' src='caffe2.js'></script>
 <script type='text/javascript' src='pytorch.js'></script>
 <script type='text/javascript' src='mxnet.js'></script>
 <script type='text/javascript' src='cntk-model.js'></script>

+ 1 - 1
src/view-electron.html

@@ -124,7 +124,7 @@
 <script type='text/javascript' src='keras.js'></script>
 <script type='text/javascript' src='coreml.js'></script>
 <script type='text/javascript' src='caffe.js'></script>
-<script type='text/javascript' src='caffe2-model.js'></script>
+<script type='text/javascript' src='caffe2.js'></script>
 <script type='text/javascript' src='pytorch.js'></script>
 <script type='text/javascript' src='sklearn-model.js'></script>
 <script type='text/javascript' src='mxnet.js'></script>

+ 1 - 1
src/view.js

@@ -1191,7 +1191,7 @@ class ModelFactoryService {
             new keras.ModelFactory(),
             new coreml.ModelFactory(),
             new caffe.ModelFactory(),
-            new Caffe2ModelFactory(), 
+            new caffe2.ModelFactory(), 
             new pytorch.ModelFactory(),
             new TensorFlowLiteModelFactory(),
             new TensorFlowModelFactory(),