|
|
@@ -1,6 +1,9 @@
|
|
|
|
|
|
from __future__ import unicode_literals
|
|
|
|
|
|
+import onnx
|
|
|
+print(onnx.__file__)
|
|
|
+
|
|
|
import json
|
|
|
import io
|
|
|
import sys
|
|
|
@@ -190,7 +193,42 @@ def generate_json(schemas, json_file):
|
|
|
fout.write(line)
|
|
|
fout.write('\n')
|
|
|
|
|
|
-if __name__ == '__main__':
|
|
|
+def metadata():
|
|
|
schemas = defs.get_all_schemas_with_history()
|
|
|
schemas = sorted(schemas, key=lambda schema: schema.name)
|
|
|
- generate_json(schemas, '../../src/onnx-metadata.json')
|
|
|
+ generate_json(schemas, '../src/onnx-metadata.json')
|
|
|
+
|
|
|
+def convert():
|
|
|
+ file = sys.argv[2];
|
|
|
+ base, extension = os.path.splitext(file)
|
|
|
+ if extension == '.mlmodel':
|
|
|
+ import coremltools
|
|
|
+ import onnxmltools
|
|
|
+ coreml_model = coremltools.utils.load_spec(file)
|
|
|
+ onnx_model = onnxmltools.convert.convert_coreml(coreml_model)
|
|
|
+ onnxmltools.utils.save_model(onnx_model, base + '.onnx')
|
|
|
+ elif extension == '.h5':
|
|
|
+ import keras
|
|
|
+ import onnxmltools
|
|
|
+ keras_model = keras.models.load_model(file)
|
|
|
+ onnx_model = onnxmltools.convert.convert_keras(keras_model)
|
|
|
+ onnxmltools.utils.save_model(onnx_model, base + '.onnx')
|
|
|
+ elif extension == '.pkl':
|
|
|
+ from sklearn.externals import joblib
|
|
|
+ import onnxmltools
|
|
|
+ sklearn_model = joblib.load(file)
|
|
|
+ onnx_model = onnxmltools.convert.convert_sklearn(sklearn_model)
|
|
|
+ onnxmltools.utils.save_model(onnx_model, base + '.onnx')
|
|
|
+ base, extension = os.path.splitext(file)
|
|
|
+ if extension == '.onnx':
|
|
|
+ import onnx
|
|
|
+ from google.protobuf import text_format
|
|
|
+ onnx_model = onnx.load(file)
|
|
|
+ text = text_format.MessageToString(onnx_model)
|
|
|
+ with open(base + '.pbtxt', 'w') as text_file:
|
|
|
+ text_file.write(text)
|
|
|
+
|
|
|
+if __name__ == '__main__':
|
|
|
+ command_table = { 'metadata': metadata }
|
|
|
+ command = sys.argv[1];
|
|
|
+ command_table[command]()
|