|
|
@@ -229,7 +229,28 @@ def convert():
|
|
|
with open(base + '.pbtxt', 'w') as text_file:
|
|
|
text_file.write(text)
|
|
|
|
|
|
+def optimize():
|
|
|
+ import onnx
|
|
|
+ from onnx import optimizer
|
|
|
+ file = sys.argv[2];
|
|
|
+ base, extension = os.path.splitext(file)
|
|
|
+ onnx_model = onnx.load(file)
|
|
|
+ passes = optimizer.get_available_passes()
|
|
|
+ optimized_model = optimizer.optimize(onnx_model, passes)
|
|
|
+ onnx.save(optimized_model, base + '.optimized.onnx')
|
|
|
+
|
|
|
+
|
|
|
+def infer():
|
|
|
+ import onnx
|
|
|
+ import onnx.shape_inference
|
|
|
+ from onnx import shape_inference
|
|
|
+ file = sys.argv[2];
|
|
|
+ base, extension = os.path.splitext(file)
|
|
|
+ onnx_model = onnx.load(base + '.onnx')
|
|
|
+ onnx_model = onnx.shape_inference.infer_shapes(onnx_model);
|
|
|
+ onnx.save(onnx_model, base + '.shape.onnx')
|
|
|
+
|
|
|
if __name__ == '__main__':
|
|
|
- command_table = { 'metadata': metadata, 'convert': convert }
|
|
|
+ command_table = { 'metadata': metadata, 'convert': convert, 'optimize': optimize, 'infer': infer }
|
|
|
command = sys.argv[1];
|
|
|
command_table[command]()
|