# onnx-script.py
from __future__ import unicode_literals
import onnx
import json
import io
import os
import re
import sys
from onnx import defs
from onnx.defs import OpSchema
from onnx.backend.test.case import collect_snippets
# Example code snippets keyed by operator name, harvested from the ONNX
# backend test cases; merged into the generated metadata below.
snippets = collect_snippets()
# Operator name -> UI category used to group/color nodes in the viewer.
categories = {
    'Constant': 'Constant',
    'Conv': 'Layer',
    'ConvTranspose': 'Layer',
    'FC': 'Layer',
    'RNN': 'Layer',
    'LSTM': 'Layer',
    'GRU': 'Layer',
    'Gemm': 'Layer',
    'Dropout': 'Dropout',
    'Elu': 'Activation',
    'HardSigmoid': 'Activation',
    'LeakyRelu': 'Activation',
    'PRelu': 'Activation',
    'ThresholdedRelu': 'Activation',
    'Relu': 'Activation',
    'Selu': 'Activation',
    'Sigmoid': 'Activation',
    'Tanh': 'Activation',
    'LogSoftmax': 'Activation',
    'Softmax': 'Activation',
    'Softplus': 'Activation',
    'Softsign': 'Activation',
    'BatchNormalization': 'Normalization',
    'InstanceNormalization': 'Normalization',
    'LpNormalization': 'Normalization',
    'LRN': 'Normalization',
    'Flatten': 'Shape',
    'Reshape': 'Shape',
    'Tile': 'Shape',
    'Xor': 'Logic',
    'Not': 'Logic',
    'Or': 'Logic',
    'Less': 'Logic',
    'And': 'Logic',
    'Greater': 'Logic',
    'Equal': 'Logic',
    'AveragePool': 'Pool',
    'GlobalAveragePool': 'Pool',
    'GlobalLpPool': 'Pool',
    'GlobalMaxPool': 'Pool',
    'LpPool': 'Pool',
    'MaxPool': 'Pool',
    'MaxRoiPool': 'Pool',
    'Concat': 'Tensor',
    'Slice': 'Tensor',
    'Split': 'Tensor',
    'Pad': 'Tensor',
    'ImageScaler': 'Data',
    'Crop': 'Data',
    'Upsample': 'Data',
    'Transpose': 'Transform',
    'Gather': 'Transform',
    'Unsqueeze': 'Transform',
    'Squeeze': 'Transform',
}
# Lowercased OpSchema.AttrType name -> JSON type string emitted in the
# metadata; 'undefined' deliberately maps to None (attribute has no type).
attribute_type_table = {
    'undefined': None,
    'float': 'float32', 'int': 'int64', 'string': 'string', 'tensor': 'tensor', 'graph': 'graph',
    'floats': 'float32[]', 'ints': 'int64[]', 'strings': 'string[]', 'tensors': 'tensor[]', 'graphs': 'graph[]',
}
  73. def generate_json_attr_type(type):
  74. assert isinstance(type, OpSchema.AttrType)
  75. s = str(type)
  76. s = s[s.rfind('.')+1:].lower()
  77. if s in attribute_type_table:
  78. return attribute_type_table[s]
  79. return None
  80. def generate_json_attr_default_value(attr_value):
  81. if not str(attr_value):
  82. return None
  83. if attr_value.HasField('i'):
  84. return attr_value.i
  85. if attr_value.HasField('s'):
  86. return attr_value.s.decode('utf8')
  87. if attr_value.HasField('f'):
  88. return attr_value.f
  89. return None
  90. def generate_json_support_level_name(support_level):
  91. assert isinstance(support_level, OpSchema.SupportType)
  92. s = str(support_level)
  93. return s[s.rfind('.')+1:].lower()
  94. def generate_json_types(types):
  95. r = []
  96. for type in types:
  97. r.append(type)
  98. r = sorted(r)
  99. return r
  100. def format_range(value):
  101. if value == 2147483647:
  102. return '∞'
  103. return str(value)
  104. def format_description(description):
  105. def replace_line(match):
  106. link = match.group(1)
  107. url = match.group(2)
  108. if not url.startswith("http://") and not url.startswith("https://"):
  109. url = "https://github.com/onnx/onnx/blob/master/docs/" + url
  110. return "[" + link + "](" + url + ")";
  111. description = re.sub("\\[(.+)\\]\\(([^ ]+?)( \"(.+)\")?\\)", replace_line, description)
  112. return description
def generate_json(schemas, json_file):
    """Serialize ONNX operator schemas into the viewer's metadata JSON file.

    For each schema a dict is built (domain, versioning, docs, inputs,
    outputs, attributes, type constraints, examples, category) and the
    whole list is written to `json_file` as pretty-printed JSON.
    """
    json_root = []
    for schema in schemas:
        json_schema = {}
        if schema.domain:
            json_schema['domain'] = schema.domain
        else:
            # Empty domain means the default ONNX operator set.
            json_schema['domain'] = 'ai.onnx'
        json_schema['since_version'] = schema.since_version
        json_schema['support_level'] = generate_json_support_level_name(schema.support_level)
        if schema.doc:
            json_schema['description'] = format_description(schema.doc.lstrip())
        if schema.inputs:
            json_schema['inputs'] = []
            for input in schema.inputs:
                json_input = {}
                json_input['name'] = input.name
                json_input['description'] = format_description(input.description)
                json_input['type'] = input.typeStr
                # 'option' is only emitted for non-default (non-Single) inputs.
                if input.option == OpSchema.FormalParameterOption.Optional:
                    json_input['option'] = 'optional'
                elif input.option == OpSchema.FormalParameterOption.Variadic:
                    json_input['option'] = 'variadic'
                json_schema['inputs'].append(json_input)
        json_schema['min_input'] = schema.min_input
        json_schema['max_input'] = schema.max_input
        if schema.outputs:
            json_schema['outputs'] = []
            for output in schema.outputs:
                json_output = {}
                json_output['name'] = output.name
                json_output['description'] = format_description(output.description)
                json_output['type'] = output.typeStr
                if output.option == OpSchema.FormalParameterOption.Optional:
                    json_output['option'] = 'optional'
                elif output.option == OpSchema.FormalParameterOption.Variadic:
                    json_output['option'] = 'variadic'
                json_schema['outputs'].append(json_output)
        json_schema['min_output'] = schema.min_output
        json_schema['max_output'] = schema.max_output
        # Human-readable range strings only when the arity is variable.
        if schema.min_input != schema.max_input:
            json_schema['inputs_range'] = format_range(schema.min_input) + ' - ' + format_range(schema.max_input);
        if schema.min_output != schema.max_output:
            json_schema['outputs_range'] = format_range(schema.min_output) + ' - ' + format_range(schema.max_output);
        if schema.attributes:
            json_schema['attributes'] = []
            # Sort by attribute name for a stable output file.
            for _, attribute in sorted(schema.attributes.items()):
                json_attribute = {}
                json_attribute['name'] = attribute.name
                json_attribute['description'] = format_description(attribute.description)
                attribute_type = generate_json_attr_type(attribute.type)
                if attribute_type:
                    json_attribute['type'] = attribute_type
                elif 'type' in json_attribute:
                    del json_attribute['type']
                json_attribute['required'] = attribute.required
                default_value = generate_json_attr_default_value(attribute.default_value)
                if default_value:
                    json_attribute['default'] = default_value
                json_schema['attributes'].append(json_attribute)
        if schema.type_constraints:
            json_schema['type_constraints'] = []
            for type_constraint in schema.type_constraints:
                json_schema['type_constraints'].append({
                    'description': type_constraint.description,
                    'type_param_str': type_constraint.type_param_str,
                    'allowed_type_strs': type_constraint.allowed_type_strs
                })
        # Attach example snippets collected from the ONNX backend tests.
        if schema.name in snippets:
            json_schema['examples'] = []
            for summary, code in sorted(snippets[schema.name]):
                json_schema['examples'].append({
                    'summary': summary,
                    'code': code
                })
        if schema.name in categories:
            json_schema['category'] = categories[schema.name]
        json_root.append({
            'name': schema.name,
            'schema': json_schema
        })
    # newline='' keeps '\n' line endings on all platforms.
    with io.open(json_file, 'w', newline='') as fout:
        json_root = json.dumps(json_root, sort_keys=True, indent=2)
        for line in json_root.splitlines():
            line = line.rstrip()
            if sys.version_info[0] < 3:
                line = unicode(line)  # Python 2 only; name undefined on Python 3.
            fout.write(line)
            fout.write('\n')
  202. def metadata():
  203. schemas = defs.get_all_schemas_with_history()
  204. schemas = sorted(schemas, key=lambda schema: schema.name)
  205. json_file = os.path.join(os.path.dirname(__file__), '../src/onnx-metadata.json')
  206. generate_json(schemas, json_file)
  207. def convert():
  208. def pip_import(package):
  209. import importlib
  210. try:
  211. importlib.import_module(package)
  212. except:
  213. import subprocess
  214. subprocess.call([ 'pip', 'install', '--quiet', package ])
  215. finally:
  216. globals()[package] = importlib.import_module(package)
  217. file = sys.argv[2]
  218. base, extension = os.path.splitext(file)
  219. if extension == '.mlmodel':
  220. pip_import('coremltools')
  221. import onnxmltools
  222. coreml_model = coremltools.utils.load_spec(file)
  223. onnx_model = onnxmltools.convert.convert_coreml(coreml_model)
  224. onnxmltools.utils.save_model(onnx_model, base + '.onnx')
  225. elif extension == '.h5':
  226. pip_import('tensorflow')
  227. pip_import('keras')
  228. import onnxmltools
  229. keras_model = keras.models.load_model(file)
  230. onnx_model = onnxmltools.convert.convert_keras(keras_model)
  231. onnxmltools.utils.save_model(onnx_model, base + '.onnx')
  232. elif extension == '.pkl':
  233. pip_import('sklearn')
  234. import onnxmltools
  235. sklearn_model = sklearn.externals.joblib.load(file)
  236. onnx_model = onnxmltools.convert.convert_sklearn(sklearn_model)
  237. onnxmltools.utils.save_model(onnx_model, base + '.onnx')
  238. base, extension = os.path.splitext(file)
  239. if extension == '.onnx':
  240. import onnx
  241. from google.protobuf import text_format
  242. onnx_model = onnx.load(file)
  243. text = text_format.MessageToString(onnx_model)
  244. with open(base + '.pbtxt', 'w') as text_file:
  245. text_file.write(text)
  246. def optimize():
  247. import onnx
  248. from onnx import optimizer
  249. file = sys.argv[2]
  250. base = os.path.splitext(file)
  251. onnx_model = onnx.load(file)
  252. passes = optimizer.get_available_passes()
  253. optimized_model = optimizer.optimize(onnx_model, passes)
  254. onnx.save(optimized_model, base + '.optimized.onnx')
  255. def infer():
  256. import onnx
  257. import onnx.shape_inference
  258. from onnx import shape_inference
  259. file = sys.argv[2]
  260. base = os.path.splitext(file)[0]
  261. onnx_model = onnx.load(base + '.onnx')
  262. onnx_model = onnx.shape_inference.infer_shapes(onnx_model)
  263. onnx.save(onnx_model, base + '.shape.onnx')
  264. if __name__ == '__main__':
  265. command_table = { 'metadata': metadata, 'convert': convert, 'optimize': optimize, 'infer': infer }
  266. command = sys.argv[1]
  267. command_table[command]()