Browse Source

Add infer and optimize to onnx-script

Lutz Roeder 7 years ago
parent
commit
0d7724911c
2 changed files with 36 additions and 1 deletion
  1. 14 0
      tools/onnx
  2. 22 1
      tools/onnx-script.py

+ 14 - 0
tools/onnx

@@ -72,6 +72,18 @@ convert() {
     deactivate
 }
 
+# infer <model>: run onnx-script.py's "infer" (shape inference) command on ${1}
+# inside the project virtualenv, mirroring the existing convert() helper.
+infer() {
+    source ${virtualenv}/bin/activate
+    ${python} ${tools}/onnx-script.py infer ${1}
+    deactivate
+}
+
+# optimize <model>: run onnx-script.py's "optimize" command on ${1}
+# inside the project virtualenv, mirroring the existing convert() helper.
+optimize() {
+    source ${virtualenv}/bin/activate
+    ${python} ${tools}/onnx-script.py optimize ${1}
+    deactivate
+}
+
 while [ "$#" != 0 ]; do
     command="$1" && shift
     case "${command}" in
@@ -80,5 +92,7 @@ while [ "$#" != 0 ]; do
         "schema") schema;;
         "metadata") metadata;;
         "convert") convert ${1} && shift;;
+        # New sub-commands: each consumes one model-path argument, like convert.
+        "infer") infer ${1} && shift;;
+        "optimize") optimize ${1} && shift;;
     esac
 done

+ 22 - 1
tools/onnx-script.py

@@ -229,7 +229,28 @@ def convert():
         with open(base + '.pbtxt', 'w') as text_file:
             text_file.write(text)
 
+def optimize():
+    import onnx
+    from onnx import optimizer
+    file = sys.argv[2];
+    base, extension = os.path.splitext(file)
+    onnx_model = onnx.load(file)
+    passes = optimizer.get_available_passes()
+    optimized_model = optimizer.optimize(onnx_model, passes)
+    onnx.save(optimized_model, base + '.optimized.onnx')
+
+
+def infer():
+    import onnx
+    import onnx.shape_inference
+    from onnx import shape_inference
+    file = sys.argv[2];
+    base, extension = os.path.splitext(file)
+    onnx_model = onnx.load(base + '.onnx')
+    onnx_model = onnx.shape_inference.infer_shapes(onnx_model);
+    onnx.save(onnx_model, base + '.shape.onnx')
+
 if __name__ == '__main__':
-    command_table = { 'metadata': metadata, 'convert': convert }
+    # Dispatch table: first CLI argument selects the handler; an unknown
+    # command name raises KeyError here.
+    command_table = { 'metadata': metadata, 'convert': convert, 'optimize': optimize, 'infer': infer }
     command = sys.argv[1];
     command_table[command]()