Browse Source

Fix onnx convert script

Lutz Roeder 7 years ago
parent
commit
4603401d27
8 changed files with 12 additions and 14 deletions
  1. 1 1
      tools/caffe
  2. 1 1
      tools/cntk
  3. 1 1
      tools/keras
  4. 1 1
      tools/mxnet
  5. 1 3
      tools/onnx
  6. 5 5
      tools/onnx-script.py
  7. 1 1
      tools/pytorch
  8. 1 1
      tools/sklearn

+ 1 - 1
tools/caffe

@@ -23,7 +23,7 @@ git_sync() {
         echo "Clone ${2}..."
         git -C "${third_party}" clone --recursive ${2}
     fi
-    git submodule update --init
+    git -C "${third_party}" submodule update --init
 }
 
 clean() {

+ 1 - 1
tools/cntk

@@ -23,7 +23,7 @@ git_sync() {
         echo "Clone ${2}..."
         git -C "${third_party}" clone --recursive ${2} ${1}
     fi
-    git submodule update --init
+    git -C "${third_party}" submodule update --init
 }
 
 clean() {

+ 1 - 1
tools/keras

@@ -27,7 +27,7 @@ git_sync() {
         echo "Clone ${2}..."
         git -C "${third_party}" clone --recursive ${2}
     fi
-    git submodule update --init
+    git -C "${third_party}" submodule update --init
 }
 
 clean() {

+ 1 - 1
tools/mxnet

@@ -30,7 +30,7 @@ git_sync() {
         echo "Clone ${2}..."
         git -C "${third_party}" clone --recursive ${2} ${1}
     fi
-    git submodule update --init
+    git -C "${third_party}" submodule update --init
 }
 
 clean() {

+ 1 - 3
tools/onnx

@@ -33,7 +33,7 @@ git_sync() {
         echo "Clone ${2}..."
         git -C "${third_party}" clone --recursive ${2}
     fi
-    git submodule update --init
+    git -C "${third_party}" submodule update --init
 }
 
 clean() {
@@ -80,8 +80,6 @@ metadata() {
 convert() {
     bold "onnx convert"
     source ${virtualenv}/bin/activate
-    ${pip} install --quiet sklearn
-    ${pip} install --quiet lightgbm
     ${pip} install --quiet ${third_party}/onnxmltools
     ${python} ${tools}/onnx-script.py convert ${1}
     deactivate

+ 5 - 5
tools/onnx-script.py

@@ -197,9 +197,9 @@ def pip_import(package):
     import importlib
     try:
         importlib.import_module(package)
-    except ImportError:
-        import pip
-        pip.main([ 'install', package ])
+    except:
+        import subprocess
+        subprocess.call([ 'pip', 'install', '--quiet', package ])
     finally:
         globals()[package] = importlib.import_module(package)
 
@@ -225,9 +225,9 @@ def convert():
         onnx_model = onnxmltools.convert.convert_keras(keras_model)
         onnxmltools.utils.save_model(onnx_model, base + '.onnx')
     elif extension == '.pkl':
-        from sklearn.externals import joblib
+        pip_import('sklearn')
         import onnxmltools
-        sklearn_model = joblib.load(file)
+        sklearn_model = sklearn.externals.joblib.load(file)
         onnx_model = onnxmltools.convert.convert_sklearn(sklearn_model)
         onnxmltools.utils.save_model(onnx_model, base + '.onnx')
     base, extension = os.path.splitext(file)

+ 1 - 1
tools/pytorch

@@ -34,7 +34,7 @@ git_sync() {
         echo "Clone ${2}..."
         git -C "${third_party}" clone --recursive ${2}
     fi
-    git submodule update --init
+    git -C "${third_party}" submodule update --init
 }
 
 clean() {

+ 1 - 1
tools/sklearn

@@ -28,7 +28,7 @@ git_sync() {
         echo "Clone ${2}..."
         git -C "${third_party}" clone --recursive ${2} ${1}
     fi
-    git submodule update --init
+    git -C "${third_party}" submodule update --init
 }
 
 sync() {