|
|
@@ -209,7 +209,7 @@ def metadata():
|
|
|
generate_json(schemas, '../src/onnx-metadata.json')
|
|
|
|
|
|
def convert():
|
|
|
- file = sys.argv[2];
|
|
|
+ file = sys.argv[2]
|
|
|
base, extension = os.path.splitext(file)
|
|
|
if extension == '.mlmodel':
|
|
|
pip_import('coremltools')
|
|
|
@@ -242,8 +242,8 @@ def convert():
|
|
|
def optimize():
|
|
|
import onnx
|
|
|
from onnx import optimizer
|
|
|
- file = sys.argv[2];
|
|
|
- base, extension = os.path.splitext(file)
|
|
|
+ file = sys.argv[2]
|
|
|
+ base = os.path.splitext(file)[0]
|
|
|
onnx_model = onnx.load(file)
|
|
|
passes = optimizer.get_available_passes()
|
|
|
optimized_model = optimizer.optimize(onnx_model, passes)
|
|
|
@@ -253,13 +253,13 @@ def infer():
|
|
|
import onnx
|
|
|
import onnx.shape_inference
|
|
|
from onnx import shape_inference
|
|
|
- file = sys.argv[2];
|
|
|
- base, extension = os.path.splitext(file)
|
|
|
+ file = sys.argv[2]
|
|
|
+ base = os.path.splitext(file)[0]
|
|
|
onnx_model = onnx.load(base + '.onnx')
|
|
|
- onnx_model = onnx.shape_inference.infer_shapes(onnx_model);
|
|
|
+ onnx_model = onnx.shape_inference.infer_shapes(onnx_model)
|
|
|
onnx.save(onnx_model, base + '.shape.onnx')
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
command_table = { 'metadata': metadata, 'convert': convert, 'optimize': optimize, 'infer': infer }
|
|
|
- command = sys.argv[1];
|
|
|
+ command = sys.argv[1]
|
|
|
command_table[command]()
|