Преглед изворни кода

Add PyTorch model zoo script

Lutz Roeder пре 7 година
родитељ
комит
4e451b83a1
29 измењених фајлова са 349 додато и 125 уклоњено
  1. 2 2
      package.json
  2. 1 1
      src/caffe.js
  3. 1 1
      src/caffe2.js
  4. 1 1
      src/cntk.js
  5. 1 1
      src/coreml.js
  6. 1 1
      src/onnx.js
  7. 6 4
      src/pytorch.js
  8. 1 1
      src/sklearn.js
  9. 1 1
      src/tf.js
  10. 24 19
      src/tflite.js
  11. 0 1
      src/view-browser.html
  12. 6 2
      src/view-browser.js
  13. 5 15
      src/view-electron.js
  14. 1 1
      src/view.js
  15. 60 0
      test/models.json
  16. 8 4
      test/test.js
  17. 12 1
      tools/caffe
  18. 12 1
      tools/cntk
  19. 21 4
      tools/coreml
  20. 15 1
      tools/keras
  21. 1 2
      tools/keras-script.py
  22. 20 3
      tools/mxnet
  23. 26 5
      tools/onnx
  24. 48 9
      tools/pytorch
  25. 12 1
      tools/pytorch-script.py
  26. 14 2
      tools/sklearn
  27. 24 7
      tools/tf
  28. 20 6
      tools/tflite
  29. 5 28
      tools/update

+ 2 - 2
package.json

@@ -11,9 +11,9 @@
     "repository": "lutzroeder/netron",
     "main": "src/app.js",
     "scripts": {
-        "install": "",
         "start": "[ -d node_modules ] || npm install && npx electron .",
-        "server": "[ -d node_modules ] || npm install && rm -rf build/python && python setup.py --quiet build && PYTHONPATH=build/python/lib python -c 'import netron; netron.main()' $@"
+        "server": "[ -d node_modules ] || npm install && rm -rf build/python && python setup.py --quiet build && PYTHONPATH=build/python/lib python -c 'import netron; netron.main()' $@",
+        "test": "[ -d node_modules ] || npm install && node ./test/test.js"
     },
     "dependencies": {
         "d3": "latest",

+ 1 - 1
src/caffe.js

@@ -26,7 +26,7 @@ caffe.ModelFactory = class {
     }
 
     open(context, host, callback) { 
-        host.require('caffe-proto', (err) => {
+        host.require('./caffe-proto', (err, module) => {
             if (err) {
                 callback(err, null);
                 return;

+ 1 - 1
src/caffe2.js

@@ -21,7 +21,7 @@ caffe2.ModelFactory = class {
     }    
 
     open(context, host, callback) {
-        host.require('caffe2-proto', (err) => {
+        host.require('./caffe2-proto', (err, module) => {
             if (err) {
                 callback(err, null);
                 return;

+ 1 - 1
src/cntk.js

@@ -21,7 +21,7 @@ cntk.ModelFactory = class {
     }
 
     open(context, host, callback) { 
-        host.require('cntk-proto', (err) => {
+        host.require('./cntk-proto', (err, module) => {
             if (err) {
                 callback(err, null);
                 return;

+ 1 - 1
src/coreml.js

@@ -12,7 +12,7 @@ coreml.ModelFactory = class {
     }
 
     open(context, host, callback) { 
-        host.require('coreml-proto', (err) => {
+        host.require('./coreml-proto', (err, module) => {
             if (err) {
                 callback(err, null);
                 return;

+ 1 - 1
src/onnx.js

@@ -35,7 +35,7 @@ onnx.ModelFactory = class {
     }
 
     open(context, host, callback) { 
-        host.require('onnx-proto', (err) => {
+        host.require('./onnx-proto', (err, module) => {
             if (err) {
                 callback(err, null);
                 return;

+ 6 - 4
src/pytorch.js

@@ -22,18 +22,18 @@ pytorch.ModelFactory = class {
     }
 
     open(context, host, callback) { 
-        host.require('pickle', (err) => {
+        host.require('./pickle', (err, pickle) => {
             if (err) {
                 callback(err, null);
                 return;
             }
             pytorch.OperatorMetadata.open(host, (err, metadata) => {
-                this._openModel(context, host, callback);
-            });
+                this._openModel(context, host, pickle, callback);
+            });        
         });
     }
 
-    _openModel(context, host, callback) {
+    _openModel(context, host, pickle, callback) {
         try {
             var identifier = context.identifier;
             var unpickler = new pickle.Unpickler(context.buffer);
@@ -103,6 +103,8 @@ pytorch.ModelFactory = class {
             constructorTable['torch.nn.modules.pooling.AdaptiveAvgPool3d'] = function() {};
             constructorTable['torch.nn.modules.rnn.LSTM'] = function () {};
             constructorTable['torch.nn.modules.sparse.Embedding'] = function () {};
+            constructorTable['torchvision.models.squeezenet.Fire'] = function () {};
+            constructorTable['torchvision.models.squeezenet.SqueezeNet'] = function () {};
             constructorTable['torch.nn.modules.upsampling.Upsample'] = function() {};
             constructorTable['torchvision.models.alexnet.AlexNet'] = function () {};
             constructorTable['torchvision.models.densenet.DenseNet'] = function () {};

+ 1 - 1
src/sklearn.js

@@ -24,7 +24,7 @@ sklearn.ModelFactory = class {
     }
 
     open(context, host, callback) { 
-        host.require('pickle', (err) => {
+        host.require('./pickle', (err, pickle) => {
             if (err) {
                 callback(err, null);
                 return;

+ 1 - 1
src/tf.js

@@ -32,7 +32,7 @@ tf.ModelFactory = class {
     }
 
     open(context, host, callback) { 
-        host.require('tf-proto', (err) => {
+        host.require('./tf-proto', (err, module) => {
             if (err) {
                 callback(err, null);
                 return;

+ 24 - 19
src/tflite.js

@@ -3,7 +3,6 @@
 var tflite = tflite || {};
 var flatbuffers = flatbuffers || require('flatbuffers').flatbuffers;
 var base = base || require('./base');
-var tflite_schema = tflite_schema || require('./tflite-schema');
 
 tflite.ModelFactory = class {
 
@@ -13,32 +12,38 @@ tflite.ModelFactory = class {
     }
 
     open(context, host, callback) {
-        var model = null;
-        try {
-            var buffer = context.buffer;
-            var byteBuffer = new flatbuffers.ByteBuffer(buffer);
-            tflite.schema = tflite_schema;
-            if (!tflite.schema.Model.bufferHasIdentifier(byteBuffer))
-            {
-                var identifier = (buffer && buffer.length >= 8 && buffer.slice(4, 8).every((c) => c >= 32 && c <= 127)) ? String.fromCharCode.apply(null, buffer.slice(4, 8)) : '';
-                callback(new tflite.Error("Invalid FlatBuffers identifier '" + identifier + "' in '" + context.identifier + "'."));
+        host.require('./tflite-schema', (err, tflite_schema) => {
+            if (err) {
+                callback(err, null);
                 return;
             }
-            model = tflite.schema.Model.getRootAsModel(byteBuffer);
-        }
-        catch (error) {
-            host.exception(error, false);
-            callback(new tflite.Error(error.message), null);
-        }
-
-        tflite.OperatorMetadata.open(host, (err, metadata) => {
+            var model = null;
             try {
-                callback(null, new tflite.Model(model));
+                var buffer = context.buffer;
+                var byteBuffer = new flatbuffers.ByteBuffer(buffer);
+                tflite.schema = tflite_schema;
+                if (!tflite.schema.Model.bufferHasIdentifier(byteBuffer))
+                {
+                    var identifier = (buffer && buffer.length >= 8 && buffer.slice(4, 8).every((c) => c >= 32 && c <= 127)) ? String.fromCharCode.apply(null, buffer.slice(4, 8)) : '';
+                    callback(new tflite.Error("Invalid FlatBuffers identifier '" + identifier + "' in '" + context.identifier + "'."));
+                    return;
+                }
+                model = tflite.schema.Model.getRootAsModel(byteBuffer);
             }
             catch (error) {
                 host.exception(error, false);
                 callback(new tflite.Error(error.message), null);
             }
+    
+            tflite.OperatorMetadata.open(host, (err, metadata) => {
+                try {
+                    callback(null, new tflite.Model(model));
+                }
+                catch (error) {
+                    host.exception(error, false);
+                    callback(new tflite.Error(error.message), null);
+                }
+            });
         });
     }
 };

+ 0 - 1
src/view-browser.html

@@ -133,7 +133,6 @@
 <script type='text/javascript' src='hdf5.js'></script>
 <script type='text/javascript' src='onnx.js'></script>
 <script type='text/javascript' src='tf.js'></script>
-<script type='text/javascript' src='tflite-schema.js'></script>
 <script type='text/javascript' src='tflite.js'></script>
 <script type='text/javascript' src='keras.js'></script>
 <script type='text/javascript' src='coreml.js'></script>

+ 6 - 2
src/view-browser.js

@@ -117,6 +117,7 @@ host.BrowserHost = class {
     }
 
     require(id, callback) {
+        window.module = { exports: {} };
         var script = document.scripts.namedItem(id);
         if (script) {
             callback(null);
@@ -127,10 +128,13 @@ host.BrowserHost = class {
         script.setAttribute('type', 'text/javascript');
         script.setAttribute('src', this._url(id + '.js'));
         script.onload = () => {
-            callback(null);
+            var exports = window.module.exports;
+            delete window.module;
+            callback(null, exports);
         };
         script.onerror = (e) => {
-            callback(new Error('The script \'' + e.target.src + '\' failed to load.'));
+            delete window.module;
+            callback(new Error('The script \'' + e.target.src + '\' failed to load.'), null);
         };
         document.head.appendChild(script);
     }

+ 5 - 15
src/view-electron.js

@@ -162,22 +162,12 @@ host.ElectronHost = class {
     }
 
     require(id, callback) {
-        var script = document.scripts.namedItem(id);
-        if (script) {
-            callback(null);
-            return;
+        try {
+            callback(null, require(id));
+        }
+        catch (err) {
+            callback(err, null);
         }
-        script = document.createElement('script');
-        script.setAttribute('id', id);
-        script.setAttribute('type', 'text/javascript');
-        script.setAttribute('src', path.join(__dirname, id + '.js'));
-        script.onload = () => {
-            callback(null);
-        };
-        script.onerror = (e) => {
-            callback(new Error('The script \'' + e.target.src + '\' failed to load.'));
-        };
-        document.head.appendChild(script);
     }
 
     save(name, extension, defaultPath, callback) {

+ 1 - 1
src/view.js

@@ -802,7 +802,7 @@ view.View = class {
                 this.showOperatorDocumentation(node);
             });
             view.on('export-tensor', (sender, tensor) => {
-                this._host.require('numpy', (err) => {
+                this._host.require('./numpy', (err, numpy) => {
                     if (!err) {
                         var defaultPath = tensor.name ? tensor.name.split('/').join('_').split(':').join('_').split('.').join('_') : 'tensor';
                         this._host.save('NumPy Array', 'npy', defaultPath, (file) => {

+ 60 - 0
test/models.json

@@ -1227,6 +1227,66 @@
     "target": "mxnet/vgg19_bn.model",
     "source": "https://s3.amazonaws.com/mxnet-model-server/onnx-vgg19_bn/vgg19_bn.model"
   },
+  {
+    "target": "pytorch/alexnet.pth",
+    "link":   "https://pytorch.org/docs/stable/torchvision/models.html",
+    "script": [ "../tools/pytorch", "sync install zoo" ],
+    "status": "script"
+  },
+  {
+    "target": "pytorch/densenet161.pth",
+    "link":   "https://pytorch.org/docs/stable/torchvision/models.html",
+    "script": [ "../tools/pytorch", "sync install zoo" ],
+    "status": "script"
+  },
+  {
+    "target": "pytorch/densenet121.pth",
+    "link":   "https://pytorch.org/docs/stable/torchvision/models.html",
+    "script": [ "../tools/pytorch", "sync install zoo" ],
+    "status": "script"
+  },
+  {
+    "target": "pytorch/inception_v3.pth",
+    "link":   "https://pytorch.org/docs/stable/torchvision/models.html",
+    "script": [ "../tools/pytorch", "sync install zoo" ],
+    "status": "script"
+  },
+  {
+    "target": "pytorch/resnet18.pth",
+    "link":   "https://pytorch.org/docs/stable/torchvision/models.html",
+    "script": [ "../tools/pytorch", "sync install zoo" ],
+    "status": "script"
+  },
+  {
+    "target": "pytorch/resnet50.pth",
+    "link":   "https://pytorch.org/docs/stable/torchvision/models.html",
+    "script": [ "../tools/pytorch", "sync install zoo" ],
+    "status": "script"
+  },
+  {
+    "target": "pytorch/resnet101.pth",
+    "link":   "https://pytorch.org/docs/stable/torchvision/models.html",
+    "script": [ "../tools/pytorch", "sync install zoo" ],
+    "status": "script"
+  },
+  {
+    "target": "pytorch/squeezenet1_0.pth",
+    "link":   "https://pytorch.org/docs/stable/torchvision/models.html",
+    "script": [ "../tools/pytorch", "sync install zoo" ],
+    "status": "script"
+  },
+  {
+    "target": "pytorch/vgg11_bn.pth",
+    "link":   "https://pytorch.org/docs/stable/torchvision/models.html",
+    "script": [ "../tools/pytorch", "sync install zoo" ],
+    "status": "script"
+  },
+  {
+    "target": "pytorch/vgg16.pth",
+    "link":   "https://pytorch.org/docs/stable/torchvision/models.html",
+    "script": [ "../tools/pytorch", "sync install zoo" ],
+    "status": "script"
+  },
   {
     "target": "tf/densenet.pb",
     "source": "https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/densenet_2018_04_27.tgz[densenet/densenet.pb]"

+ 8 - 4
test/app.js → test/test.js

@@ -15,6 +15,7 @@ const gzip = require('../src/gzip');
 const tar = require('../src/tar');
 
 global.TextDecoder = require('util').TextDecoder;
+global.protobuf = protobuf;
 
 var models = JSON.parse(fs.readFileSync(__dirname + '/models.json', 'utf-8'));
 var folder = __dirname + '/data';
@@ -22,10 +23,13 @@ var folder = __dirname + '/data';
 class TestHost {
 
     require(id, callback) {
-        var filename = path.join(path.join(__dirname, '../src'), id + '.js');
-        var data = fs.readFileSync(filename, 'utf-8');
-        eval(data);
-        callback(null);
+        try {
+            var file = path.join(path.join(__dirname, '../src'), id + '.js');
+            callback(null, require(file));
+        }
+        catch (err) {
+            callback(err, null);
+        }
     }
 
     request(base, file, encoding, callback) {

+ 12 - 1
tools/caffe

@@ -10,6 +10,10 @@ third_party=${root}/third_party
 
 identifier=caffe
 
+bold() {
+    echo "$(tty -s && tput bold)$1$(tty -s && tput sgr0)" 
+}
+
 git_sync() {
     mkdir -p "${third_party}"
     if [ -d "${third_party}/${1}" ]; then
@@ -22,12 +26,18 @@ git_sync() {
     git submodule update --init
 }
 
+clean() {
+    bold "caffe clean"
+    rm -rf ${third_party}/caffe
+}
+
 sync() {
+    bold "caffe sync"
     git_sync caffe https://github.com/BVLC/caffe.git
 }
 
 schema() {
-    echo "Generate 'caffe.js'"
+    bold "caffe schema"
     cp ${third_party}/${identifier}/src/caffe/proto/caffe.proto ${tools}/caffe.proto
     node ${tools}/caffe-schema.js ${tools}/caffe.proto
     ${node_modules}/protobufjs/bin/pbjs -t static-module -w closure --no-encode --no-delimited --no-comments --keep-case --decode-text -r caffe -o ${src}/caffe-proto.js ${tools}/caffe.proto
@@ -38,6 +48,7 @@ schema() {
 while [ "$#" != 0 ]; do
     command="$1" && shift
     case "${command}" in
+        "clean") clean;;
         "sync") sync;;
         "schema") schema;;
     esac

+ 12 - 1
tools/cntk

@@ -10,6 +10,10 @@ third_party=${root}/third_party
 
 identifier=cntk
 
+bold() {
+    echo "$(tty -s && tput bold)$1$(tty -s && tput sgr0)" 
+}
+
 git_sync() {
     mkdir -p "${third_party}"
     if [ -d "${third_party}/${1}" ]; then
@@ -22,12 +26,18 @@ git_sync() {
     git submodule update --init
 }
 
+clean() {
+    bold "cntk clean"
+    rm -rf ${third_party}/cntk
+}
+
 sync() {
+    bold "cntk sync"
     git_sync cntk https://github.com/Microsoft/CNTK.git
 }
 
 schema() {
-    echo "Generate 'cntk.js'"
+    bold "cntk schema"
     ${node_modules}/protobufjs/bin/pbjs -t static-module -w closure --no-encode --no-delimited --no-comments --keep-case -r cntk -o ${src}/cntk-proto.js ${third_party}/${identifier}/Source/CNTKv2LibraryDll/proto/CNTK.proto
     node ${tools}/update_pbjs.js array ${src}/cntk-proto.js value float 1
 }
@@ -35,6 +45,7 @@ schema() {
 while [ "$#" != 0 ]; do
     command="$1" && shift
     case "${command}" in
+        "clean") clean;;
         "sync") sync;;
         "schema") schema;;
     esac

+ 21 - 4
tools/coreml

@@ -11,8 +11,17 @@ third_party=${root}/third_party
 identifier=coremltools
 virtualenv=${root}/build/virtualenv/${identifier}
 
-python=${python:-python}
-pip=${pip:-pip}
+if [ $(which python3) ] && [ $(which pip3) ]; then
+    python="python3"
+    pip="pip3"
+else
+    python="python"
+    pip="pip"
+fi
+
+bold() {
+    echo "$(tty -s && tput bold)$1$(tty -s && tput sgr0)" 
+}
 
 git_sync () {
     mkdir -p "${third_party}"
@@ -25,12 +34,18 @@ git_sync () {
     fi
 }
 
+clean() {
+    bold "coreml clean"
+    rm -rf ${third_party}/coremltools
+}
+
 sync() {
+    bold "coreml sync"
     git_sync coremltools https://github.com/apple/coremltools.git
 }
 
 install() {
-    echo "Install coremltools"
+    bold "coreml install"
     virtualenv --quiet -p ${python} ${virtualenv}
     source ${virtualenv}/bin/activate
     ${pip} install --quiet ${third_party}/${identifier}
@@ -38,12 +53,13 @@ install() {
 }
 
 schema() {
-    echo "Generate 'coreml.js'"
+    bold "coreml schema"
     ${node_modules}/protobufjs/bin/pbjs -t static-module -w closure --no-encode --no-delimited --no-comments --keep-case -r coreml -o ${src}/coreml-proto.js ${third_party}/${identifier}/mlmodel/format/Model.proto
     node ${tools}/update_pbjs.js array ${src}/coreml-proto.js floatValue float 2
 }
 
 convert() {
+    bold "coreml convert"
     source ${virtualenv}/bin/activate
     ${pip} install --quiet onnx
     ${pip} install --quiet sklearn
@@ -54,6 +70,7 @@ convert() {
 while [ "$#" != 0 ]; do
     command="$1" && shift
     case "${command}" in
+        "clean") clean;;
         "sync") sync;;
         "install") install;;
         "schema") schema;;

+ 15 - 1
tools/keras

@@ -14,6 +14,10 @@ pip="pip"
 identifier=keras
 virtualenv=${build}/virtualenv/${identifier}
 
+bold() {
+    echo "$(tty -s && tput bold)$1$(tty -s && tput sgr0)" 
+}
+
 git_sync() {
     mkdir -p "${third_party}"
     if [ -d "${third_party}/${1}" ]; then
@@ -26,11 +30,19 @@ git_sync() {
     git submodule update --init
 }
 
+clean() {
+    bold "keras clean"
+    rm -rf ${virtualenv}
+    rm -rf ${third_party}/${identifier}
+}
+
 sync() {
+    bold "keras sync"
     git_sync keras https://github.com/keras-team/keras.git
 }
 
 install() {
+    bold "keras install"
     virtualenv --quiet -p ${python} ${virtualenv}
     source ${virtualenv}/bin/activate
     ${pip} install --quiet tensorflow
@@ -39,8 +51,8 @@ install() {
 }
 
 metadata() {
+    bold "keras metadata"
     source ${virtualenv}/bin/activate
-    echo "Update 'keras-metadata.json'"
     pushd ${tools} > /dev/null
     ${python} keras-script.py metadata
     popd > /dev/null
@@ -48,6 +60,7 @@ metadata() {
 }
 
 zoo() {
+    bold "keras zoo"
     source ${virtualenv}/bin/activate
     pushd ${tools} > /dev/null
     ${python} keras-script.py zoo keras.applications.densenet.DenseNet121 ${test}/data/keras/DenseNet121.h5
@@ -68,6 +81,7 @@ zoo() {
 while [ "$#" != 0 ]; do
     command="$1" && shift
     case "${command}" in
+        "clean") clean;;
         "sync") sync;;
         "install") install;;
         "metadata") metadata;;

+ 1 - 2
tools/keras-script.py

@@ -338,13 +338,12 @@ def metadata():
             fout.write('\n')
 
 def zoo():
-    from pydoc import locate
     type = sys.argv[2];
     file = sys.argv[3];
     directory = os.path.dirname(file);
     if not os.path.exists(directory):
         os.makedirs(directory)
-    model = locate(type)()
+    model = pydoc.locate(type)()
     model.save(file);
 
 if __name__ == '__main__':

+ 20 - 3
tools/mxnet

@@ -7,11 +7,20 @@ src=${root}/src
 tools=${root}/tools
 third_party=${root}/third_party
 
-python=${python:-python}
-pip=${pip:-pip}
+if [ $(which python3) ] && [ $(which pip3) ]; then
+    python="python3"
+    pip="pip3"
+else
+    python="python"
+    pip="pip"
+fi
 
 identifier=mxnet
 
+bold() {
+    echo "$(tty -s && tput bold)$1$(tty -s && tput sgr0)" 
+}
+
 git_sync() {
     mkdir -p "${third_party}"
     if [ -d "${third_party}/${1}" ]; then
@@ -24,18 +33,26 @@ git_sync() {
     git submodule update --init
 }
 
+clean() {
+    bold "mxnet clean"
+    rm -rf ${virtualenv}
+    rm -rf ${third_party}/${identifier}
+}
+
 sync() {
+    bold "mxnet sync"
     git_sync mxnet https://github.com/apache/incubator-mxnet.git
 }
 
 metadata() {
-    echo "Update 'mxnet-script.json'"
+    bold "mxnet metadata"
     # ${python} mxnet-script.py
 }
 
 while [ "$#" != 0 ]; do
     command="$1" && shift
     case "${command}" in
+        "clean") clean;;
         "sync") sync;;
         "metadata") metadata;;
     esac

+ 26 - 5
tools/onnx

@@ -12,8 +12,17 @@ tools=${root}/tools
 identifier=onnx
 virtualenv=${build}/virtualenv/${identifier}
 
-python=${python:-python}
-pip=${pip:-pip}
+if [ $(which python3) ] && [ $(which pip3) ]; then
+    python="python3"
+    pip="pip3"
+else
+    python="python"
+    pip="pip"
+fi
+
+bold() {
+    echo "$(tty -s && tput bold)$1$(tty -s && tput sgr0)" 
+}
 
 git_sync() {
     mkdir -p "${third_party}"
@@ -27,13 +36,21 @@ git_sync() {
     git submodule update --init
 }
 
+clean() {
+    bold "onnx clean"
+    rm -rf ${virtualenv}
+    rm -rf ${third_party}/${identifier}
+    rm -rf ${third_party}/onnxmltools
+}
+
 sync() {
+    bold "onnx sync"
     git_sync onnx https://github.com/onnx/onnx.git
     git_sync onnxmltools https://github.com/onnx/onnxmltools.git
 }
 
 install() {
-    echo "Install ONNX"
+    bold "onnx install"
     virtualenv --quiet -p ${python} ${virtualenv}
     source ${virtualenv}/bin/activate
     export ONNX_ML=1
@@ -43,7 +60,7 @@ install() {
 }
 
 schema() {
-    echo "Generate 'onnx-proto.js'"
+    bold "onnx schema"
     source ${virtualenv}/bin/activate
     ${node_modules}/protobufjs/bin/pbjs -t static-module -w closure --no-encode --no-delimited --no-comments --keep-case --decode-text -r onnx -o ${src}/onnx-proto.js ${third_party}/${identifier}/onnx/onnx-ml.proto ${third_party}/${identifier}/onnx/onnx-operators-ml.proto
     node ${tools}/update_pbjs.js array ${src}/onnx-proto.js float_data float 1
@@ -52,7 +69,7 @@ schema() {
 }
 
 metadata() {
-    echo "Generate 'onnx-metadata.json'"
+    bold "onnx metadata"
     source ${virtualenv}/bin/activate
     pushd ${tools} > /dev/null
     ${python} onnx-script.py metadata
@@ -61,6 +78,7 @@ metadata() {
 }
 
 convert() {
+    bold "onnx convert"
     source ${virtualenv}/bin/activate
     ${pip} install --quiet coremltools
     ${pip} install --quiet tensorflow
@@ -73,12 +91,14 @@ convert() {
 }
 
 infer() {
+    bold "onnx infer"
     source ${virtualenv}/bin/activate
     ${python} ${tools}/onnx-script.py infer ${1}
     deactivate
 }
 
 optimize() {
+    bold "onnx optimize"
     source ${virtualenv}/bin/activate
     ${python} ${tools}/onnx-script.py optimize ${1}
     deactivate
@@ -87,6 +107,7 @@ optimize() {
 while [ "$#" != 0 ]; do
     command="$1" && shift
     case "${command}" in
+        "clean") clean;;
         "sync") sync;;
         "install") install;;
         "schema") schema;;

+ 48 - 9
tools/pytorch

@@ -6,14 +6,24 @@ root=$(cd $(dirname ${0})/..; pwd)
 build=${root}/build
 node_modules=${root}/node_modules
 src=${root}/src
+test=${root}/test
 tools=${root}/tools
 third_party=${root}/third_party
 
 identifier=pytorch
 virtualenv=${build}/virtualenv/${identifier}
 
-python="python"
-pip="pip"
+if [ $(which python3) ] && [ $(which pip3) ]; then
+    python="python3"
+    pip="pip3"
+else
+    python="python"
+    pip="pip"
+fi
+
+bold() {
+    echo "$(tty -s && tput bold)$1$(tty -s && tput sgr0)" 
+}
 
 git_sync() {
     mkdir -p "${third_party}"
@@ -27,12 +37,19 @@ git_sync() {
     git submodule update --init
 }
 
+clean() {
+    bold "pytorch clean"
+    rm -rf ${virtualenv}
+    rm -rf ${third_party}/${identifier}
+}
+
 sync() {
+    bold "pytorch sync"
     git_sync pytorch https://github.com/pytorch/pytorch.git
 }
 
 install() {
-echo "Install Caffe2"
+    bold "pytorch install"
 if [ "$(uname -s)" == "Darwin" ] && [ "$(which brew)" != "" ]; then
 brew bundle --file=- <<-EOS
 brew "automake"
@@ -43,26 +60,46 @@ EOS
 fi
     virtualenv --quiet -p ${python} ${virtualenv}
     source ${virtualenv}/bin/activate
-    ${pip} install --quiet future leveldb numpy protobuf pydot python-gflags pyyaml scikit-image setuptools six hypothesis typing
     pushd "${third_party}/pytorch" > /dev/null
-    ${python} setup.py install
+    ${pip} install --quiet future leveldb numpy protobuf pydot python-gflags pyyaml scikit-image setuptools six hypothesis typing tqdm
+    ${python} setup.py install --quiet
+    ${pip} install --quiet torchvision
     popd > /dev/null
     deactivate
 }
 
 schema() {
-    echo "Generate 'caffe2.js'"
+    bold "caffe2 schema"
     ${node_modules}/protobufjs/bin/pbjs -t static-module -w closure --no-encode --no-delimited --no-comments --keep-case --decode-text -r caffe2 -o ${src}/caffe2-proto.js ${third_party}/pytorch/caffe2/proto/caffe2.proto
     node ${tools}/update_pbjs.js enumeration ${src}/caffe2-proto.js floats float 1
 }
 
 metadata() {
+    bold "pytorch metadata"
     source ${virtualenv}/bin/activate
     pushd ${tools} > /dev/null
-    echo "Generate 'caffe2-metadata.json'"
-    ${python} caffe2-script.py metadata
-    echo "Generate 'pytorch-metadata.json'"
     ${python} pytorch-script.py metadata
+    bold "caffe2 metadata"
+    ${python} caffe2-script.py metadata
+    popd > /dev/null
+    deactivate
+}
+
+zoo() {
+    bold "pytorch zoo"
+    source ${virtualenv}/bin/activate
+    pushd ${tools} > /dev/null
+    ${python} pytorch-script.py zoo torchvision.models.alexnet ${test}/data/pytorch/alexnet.pth
+    ${python} pytorch-script.py zoo torchvision.models.densenet121 ${test}/data/pytorch/densenet121.pth
+    ${python} pytorch-script.py zoo torchvision.models.densenet161 ${test}/data/pytorch/densenet161.pth
+    ${python} pytorch-script.py zoo torchvision.models.inception_v3 ${test}/data/pytorch/inception_v3.pth
+    ${python} pytorch-script.py zoo torchvision.models.resnet101 ${test}/data/pytorch/resnet101.pth
+    ${python} pytorch-script.py zoo torchvision.models.resnet18 ${test}/data/pytorch/resnet18.pth
+    ${python} pytorch-script.py zoo torchvision.models.resnet50 ${test}/data/pytorch/resnet50.pth
+    ${python} pytorch-script.py zoo torchvision.models.squeezenet1_0 ${test}/data/pytorch/squeezenet1_0.pth
+    ${python} pytorch-script.py zoo torchvision.models.vgg11_bn ${test}/data/pytorch/vgg11_bn.pth
+    ${python} pytorch-script.py zoo torchvision.models.vgg16 ${test}/data/pytorch/vgg16.pth
+    rm -rf ~/.torch/models
     popd > /dev/null
     deactivate
 }
@@ -70,9 +107,11 @@ metadata() {
 while [ "$#" != 0 ]; do
     command="$1" && shift
     case "${command}" in
+        "clean") clean;;
         "sync") sync;;
         "install") install;;
         "schema") schema;;
         "metadata") metadata;;
+        "zoo") zoo;;
     esac
 done

+ 12 - 1
tools/pytorch-script.py

@@ -44,7 +44,18 @@ def metadata():
             fout.write(line)
             fout.write('\n')
 
+def zoo():
+    import torch
+    type = sys.argv[2];
+    file = sys.argv[3];
+    directory = os.path.dirname(file);
+    if not os.path.exists(directory):
+        os.makedirs(directory)
+    print(type)
+    model = pydoc.locate(type)(pretrained=True)
+    torch.save(model, file);
+
 if __name__ == '__main__':
-    command_table = { 'metadata': metadata }
+    command_table = { 'metadata': metadata, 'zoo': zoo }
     command = sys.argv[1];
     command_table[command]()

+ 14 - 2
tools/sklearn

@@ -7,8 +7,17 @@ virtualenv=${root}/build/virtualenv/scikit-learn
 tools=${root}/tools
 third_party=${root}/third_party
 
-python=${python:-python}
-pip=${pip:-pip}
+if [ $(which python3) ] && [ $(which pip3) ]; then
+    python="python3"
+    pip="pip3"
+else
+    python="python"
+    pip="pip"
+fi
+
+bold() {
+    echo "$(tty -s && tput bold)$1$(tty -s && tput sgr0)" 
+}
 
 git_sync() {
     mkdir -p "${third_party}"
@@ -23,12 +32,14 @@ git_sync() {
 }
 
 sync() {
+    bold "sklearn sync"
     git_sync scikit-learn https://github.com/scikit-learn/scikit-learn.git
     git_sync lightgbm https://github.com/Microsoft/LightGBM.git
     git_sync xgboost https://github.com/dmlc/xgboost.git
 }
 
 install() {
+    bold "sklearn install"
     virtualenv --quiet -p ${python} ${virtualenv}
     source ${virtualenv}/bin/activate
     echo "Install scikit-learn"
@@ -39,6 +50,7 @@ install() {
 }
 
 metadata() {
+    bold "sklearn metadata"
     source ${virtualenv}/bin/activate
     echo "Update 'sklearn-metadata.json'"
     pushd ${tools} > /dev/null

+ 24 - 7
tools/tf

@@ -12,8 +12,17 @@ third_party=${root}/third_party
 identifier=tensorflow
 virtualenv=${build}/virtualenv/${identifier}
 
-python=${python:-python}
-pip=${pip:-pip}
+if [ $(which python3) ] && [ $(which pip3) ]; then
+    python="python3"
+    pip="pip3"
+else
+    python="python"
+    pip="pip"
+fi
+
+bold() {
+    echo "$(tty -s && tput bold)$1$(tty -s && tput sgr0)" 
+}
 
 git_sync () {
     mkdir -p "${third_party}"
@@ -26,20 +35,27 @@ git_sync () {
     fi
 }
 
+clean() {
+    bold "tf clean"
+    rm -rf ${virtualenv}
+    rm -rf ${third_party}/${identifier}
+}
+
 sync() {
+    bold "tf sync"
     git_sync tensorflow https://github.com/tensorflow/tensorflow.git
 }
 
 install() {
+    bold "tf install"
     virtualenv --quiet -p ${python} ${virtualenv}
     source ${virtualenv}/bin/activate
-    echo "Install protobuf"
     ${pip} install --quiet protobuf
     deactivate
 }
 
 schema() {
-    echo "Generate 'tf-proto.js'"
+    bold "tf schema"
     ${node_modules}/protobufjs/bin/pbjs -t static-module -w closure --no-encode --no-delimited --no-comments --keep-case --decode-text -r tf -o ${src}/tf-proto.js \
         ${third_party}/${identifier}/tensorflow/core/protobuf/saved_model.proto \
         ${third_party}/${identifier}/tensorflow/core/protobuf/meta_graph.proto \
@@ -57,8 +73,9 @@ schema() {
 }
 
 metadata() {
-    echo "Generate 'tf-metadata.json'"
+    bold "tf metadata"
     source ${virtualenv}/bin/activate
+    pushd ${tools} > /dev/null
     protoc --proto_path ${third_party}/${identifier} tensorflow/core/framework/attr_value.proto --python_out=${tools}
     protoc --proto_path ${third_party}/${identifier} tensorflow/core/framework/tensor.proto --python_out=${tools}
     protoc --proto_path ${third_party}/${identifier} tensorflow/core/framework/types.proto --python_out=${tools}
@@ -69,16 +86,16 @@ metadata() {
     touch ${tools}/tensorflow/__init__.py
     touch ${tools}/tensorflow/core/__init__.py
     touch ${tools}/tensorflow/core/framework/__init__.py
-    pushd ${tools} > /dev/null
     ${python} tf-script.py
-    popd > /dev/null
     rm -rf ${tools}/tensorflow
+    popd > /dev/null
     deactivate
 }
 
 while [ "$#" != 0 ]; do
     command="$1" && shift
     case "${command}" in
+        "clean") clean;;
         "sync") sync;;
         "install") install;;
         "schema") schema;;

+ 20 - 6
tools/tflite

@@ -10,8 +10,12 @@ third_party=${root}/third_party
 identifier=tflite
 virtualenv=${build}/virtualenv/${identifier}
 
-python=${python:-python}
-pip=${pip:-pip}
+python="python"
+pip="pip"
+
+bold() {
+    echo "$(tty -s && tput bold)$1$(tty -s && tput sgr0)" 
+}
 
 git_sync () {
     mkdir -p "${third_party}"
@@ -24,12 +28,20 @@ git_sync () {
     fi
 }
 
+clean() {
+    bold "tflite clean"
+    rm -rf ${virtualenv}
+    rm -rf ${third_party}/tensorflow
+}
+
 sync() {
+    bold "tflite sync"
     git_sync flatbuffers https://github.com/google/flatbuffers.git
     git_sync tensorflow https://github.com/tensorflow/tensorflow.git
 }
 
-build() {
+install() {
+    bold "flatbuffers install"
     echo "Build flatbuffers..."
     pushd "${third_party}/flatbuffers" > /dev/null
     cmake -G "Unix Makefiles"
@@ -39,7 +51,7 @@ build() {
 }
 
 schema() {
-    echo "Generate '../src/tflite-schema.js'"
+    bold "tflite schema"
     cp ${third_party}/tensorflow/tensorflow/lite/schema/schema.fbs ${tools}/tflite.schema.fbs
     sed -i 's/namespace tflite\;/namespace tflite_schema\;/' ${tools}/tflite.schema.fbs
     ${third_party}/flatbuffers/flatc --no-js-exports --js ${tools}/tflite.schema.fbs
@@ -53,17 +65,19 @@ EOT
 }
 
 visualize() {
+    bold "tflite visualize"
     source ${virtualenv}/bin/activate
     ${pip} install --quiet tensorflow
-    python ${third_party}/tensorflow/tensorflow/lite/tools/visualize.py $@
+    ${python} ${third_party}/tensorflow/tensorflow/lite/tools/visualize.py $@
     deactivate
 }
 
 while [ "$#" != 0 ]; do
     command="$1" && shift
     case "${command}" in
+        "clean") clean;;
         "sync") sync;;
-        "build") build;;
+        "install") install;;
         "schema") schema;;
         "visualize") visualize ${1} ${2} && shift && shift;;
     esac

+ 5 - 28
tools/update

@@ -2,40 +2,17 @@
 
 set -e
 
-if [ $(which python3) ] && [ $(which pip3) ]; then
-    export python=python3
-    export pip=pip3
-fi
+tools=$(cd $(dirname ${0})/..; pwd)/tools
 
-root=$(cd $(dirname ${0})/..; pwd)
-tools=${root}/tools
-
-echo "Update TensorFlow"
 ${tools}/tf sync install schema metadata
-
-echo "Update Keras"
 ${tools}/keras sync install metadata
 
-echo "Update CoreML"
-${tools}/coreml sync install schema
-
-echo "Update Caffe"
 ${tools}/caffe sync schema
-
-echo "Update MXNet"
+${tools}/coreml sync install schema
+${tools}/cntk sync schema
 ${tools}/mxnet sync metadata
-
-echo "Update ONNX"
 ${tools}/onnx sync install schema metadata
-
-echo "Update TensorFlow Lite"
-${tools}/tflite sync install schema
-
-echo "Update scikit-learn"
+${tools}/pytorch sync install schema metadata
 ${tools}/sklearn sync install metadata
 
-echo "Update CNTK"
-${tools}/cntk sync schema
-
-echo "Update PyTorch"
-${tools}/pytorch sync install metadata schema
+${tools}/tflite sync install schema