# onnx_metadata.py
  1. ''' ONNX metadata script '''
  2. import collections
  3. import json
  4. import os
  5. import re
  6. import onnx.backend.test.case # pylint: disable=import-error
  7. import onnx.defs # pylint: disable=import-error
  8. def _read(path):
  9. with open(path, 'r', encoding='utf-8') as file:
  10. return file.read()
  11. def _write(path, content):
  12. with open(path, 'w', encoding='utf-8') as file:
  13. file.write(content)
  14. categories = {
  15. 'Constant': 'Constant',
  16. 'Conv': 'Layer',
  17. 'ConvInteger': 'Layer',
  18. 'ConvTranspose': 'Layer',
  19. 'FC': 'Layer',
  20. 'RNN': 'Layer',
  21. 'LSTM': 'Layer',
  22. 'GRU': 'Layer',
  23. 'Gemm': 'Layer',
  24. 'FusedConv': 'Layer',
  25. 'Dropout': 'Dropout',
  26. 'Elu': 'Activation',
  27. 'HardSigmoid': 'Activation',
  28. 'LeakyRelu': 'Activation',
  29. 'PRelu': 'Activation',
  30. 'ThresholdedRelu': 'Activation',
  31. 'Relu': 'Activation',
  32. 'Selu': 'Activation',
  33. 'Sigmoid': 'Activation',
  34. 'Tanh': 'Activation',
  35. 'LogSoftmax': 'Activation',
  36. 'Softmax': 'Activation',
  37. 'Softplus': 'Activation',
  38. 'Softsign': 'Activation',
  39. 'Clip': 'Activation',
  40. 'BatchNormalization': 'Normalization',
  41. 'InstanceNormalization': 'Normalization',
  42. 'LpNormalization': 'Normalization',
  43. 'LRN': 'Normalization',
  44. 'Flatten': 'Shape',
  45. 'Reshape': 'Shape',
  46. 'Tile': 'Shape',
  47. 'AveragePool': 'Pool',
  48. 'GlobalAveragePool': 'Pool',
  49. 'GlobalLpPool': 'Pool',
  50. 'GlobalMaxPool': 'Pool',
  51. 'LpPool': 'Pool',
  52. 'MaxPool': 'Pool',
  53. 'MaxRoiPool': 'Pool',
  54. 'Concat': 'Tensor',
  55. 'Slice': 'Tensor',
  56. 'Split': 'Tensor',
  57. 'Pad': 'Tensor',
  58. 'ImageScaler': 'Data',
  59. 'Crop': 'Data',
  60. 'Upsample': 'Data',
  61. 'Transpose': 'Transform',
  62. 'Gather': 'Transform',
  63. 'Unsqueeze': 'Transform',
  64. 'Squeeze': 'Transform',
  65. }
  66. attribute_type_table = {
  67. 'undefined': None,
  68. 'float': 'float32', 'int': 'int64', 'string': 'string',
  69. 'tensor': 'tensor', 'graph': 'graph',
  70. 'floats': 'float32[]', 'ints': 'int64[]', 'strings': 'string[]',
  71. 'tensors': 'tensor[]', 'graphs': 'graph[]',
  72. }
  73. def _get_attr_type(attribute_type, attribute_name, op_type, op_domain):
  74. key = op_domain + ':' + op_type + ':' + attribute_name
  75. if key in (':Cast:to', ':EyeLike:dtype', ':RandomNormal:dtype'):
  76. return 'DataType'
  77. value = str(attribute_type)
  78. value = value[value.rfind('.')+1:].lower()
  79. if value in attribute_type_table:
  80. return attribute_type_table[value]
  81. return None
  82. def _get_attr_default_value(attr_value):
  83. if not str(attr_value):
  84. return None
  85. if attr_value.HasField('i'):
  86. return attr_value.i
  87. if attr_value.HasField('s'):
  88. return attr_value.s.decode('utf8')
  89. if attr_value.HasField('f'):
  90. return attr_value.f
  91. return None
  92. def _generate_json_support_level_name(support_level):
  93. value = str(support_level)
  94. return value[value.rfind('.')+1:].lower()
  95. def _format_description(description):
  96. def replace_line(match):
  97. link = match.group(1)
  98. url = match.group(2)
  99. if not url.startswith("http://") and not url.startswith("https://"):
  100. url = "https://github.com/onnx/onnx/blob/master/docs/" + url
  101. return "[" + link + "](" + url + ")"
  102. description = re.sub("\\[(.+)\\]\\(([^ ]+?)( \"(.+)\")?\\)", replace_line, description)
  103. return description
  104. def _update_attributes(json_schema, schema):
  105. json_schema['attributes'] = []
  106. for _ in collections.OrderedDict(schema.attributes.items()).values():
  107. json_attribute = {}
  108. json_attribute['name'] = _.name
  109. attribute_type = _get_attr_type(_.type, _.name, schema.name, schema.domain)
  110. if attribute_type:
  111. json_attribute['type'] = attribute_type
  112. elif 'type' in json_attribute:
  113. del json_attribute['type']
  114. json_attribute['required'] = _.required
  115. default_value = _get_attr_default_value(_.default_value)
  116. if default_value:
  117. json_attribute['default'] = default_value
  118. json_attribute['description'] = _format_description(_.description)
  119. json_schema['attributes'].append(json_attribute)
  120. def _update_inputs(json_schema, inputs):
  121. json_schema['inputs'] = []
  122. for _ in inputs:
  123. json_input = {}
  124. json_input['name'] = _.name
  125. json_input['type'] = _.type_str
  126. if _.option == onnx.defs.OpSchema.FormalParameterOption.Optional:
  127. json_input['option'] = 'optional'
  128. elif _.option == onnx.defs.OpSchema.FormalParameterOption.Variadic:
  129. json_input['list'] = True
  130. json_input['description'] = _format_description(_.description)
  131. json_schema['inputs'].append(json_input)
  132. def _update_outputs(json_schema, outputs):
  133. json_schema['outputs'] = []
  134. for _ in outputs:
  135. json_output = {}
  136. json_output['name'] = _.name
  137. json_output['type'] = _.type_str
  138. if _.option == onnx.defs.OpSchema.FormalParameterOption.Optional:
  139. json_output['option'] = 'optional'
  140. elif _.option == onnx.defs.OpSchema.FormalParameterOption.Variadic:
  141. json_output['list'] = True
  142. json_output['description'] = _format_description(_.description)
  143. json_schema['outputs'].append(json_output)
  144. def _update_type_constraints(json_schema, type_constraints):
  145. json_schema['type_constraints'] = []
  146. for _ in type_constraints:
  147. json_schema['type_constraints'].append({
  148. 'description': _.description,
  149. 'type_param_str': _.type_param_str,
  150. 'allowed_type_strs': _.allowed_type_strs
  151. })
  152. def _update_snippets(json_schema, snippets):
  153. json_schema['examples'] = []
  154. for summary, code in sorted(snippets):
  155. lines = code.splitlines()
  156. while len(lines) > 0 and re.search("\\s*#", lines[-1]):
  157. lines.pop()
  158. if len(lines) > 0 and len(lines[-1]) == 0:
  159. lines.pop()
  160. json_schema['examples'].append({
  161. 'summary': summary,
  162. 'code': '\n'.join(lines)
  163. })
  164. def _format_range(value):
  165. return '∞' if value == 2147483647 else str(value)
def _metadata():
    # Regenerate source/onnx-metadata.json: one JSON entry per onnx operator
    # schema version, merged with the pre-existing 'com.microsoft' entries
    # already present in that file.
    # NOTE(review): indentation reconstructed from a flattened paste — the
    # assumed scope of the errstate block and of the `if schema.inputs/
    # outputs:` guards should be confirmed against the upstream file.
    json_root = []
    # numpy is only needed to silence warnings raised while the example
    # snippets execute; __import__ keeps it a lazy, local dependency.
    numpy = __import__('numpy')
    with numpy.errstate(all='ignore'):
        snippets = onnx.backend.test.case.collect_snippets()
        all_schemas_with_history = onnx.defs.get_all_schemas_with_history()
        for schema in all_schemas_with_history:
            json_schema = {}
            json_schema['name'] = schema.name
            # Empty domain is the default 'ai.onnx' namespace.
            json_schema['module'] = schema.domain if schema.domain else 'ai.onnx'
            json_schema['version'] = schema.since_version
            # Only non-COMMON operators carry an explicit status tag.
            if schema.support_level != onnx.defs.OpSchema.SupportType.COMMON:
                json_schema['status'] = schema.support_level.name.lower()
            json_schema['description'] = _format_description(schema.doc.lstrip())
            if schema.attributes:
                _update_attributes(json_schema, schema)
            if schema.inputs:
                _update_inputs(json_schema, schema.inputs)
                json_schema['min_input'] = schema.min_input
                json_schema['max_input'] = schema.max_input
            if schema.outputs:
                _update_outputs(json_schema, schema.outputs)
                json_schema['min_output'] = schema.min_output
                json_schema['max_output'] = schema.max_output
            # Human-readable arity ranges, only when the bounds differ.
            if schema.min_input != schema.max_input:
                json_schema['inputs_range'] = _format_range(schema.min_input) + ' - ' \
                    + _format_range(schema.max_input)
            if schema.min_output != schema.max_output:
                json_schema['outputs_range'] = _format_range(schema.min_output) + ' - ' \
                    + _format_range(schema.max_output)
            if schema.type_constraints:
                _update_type_constraints(json_schema, schema.type_constraints)
            if schema.name in snippets:
                _update_snippets(json_schema, snippets[schema.name])
            if schema.name in categories:
                json_schema['category'] = categories[schema.name]
            json_root.append(json_schema)
    # Stable ordering: by name, then by zero-padded version number.
    json_root = sorted(json_root, key=lambda item: item['name'] + ':' + \
        str(item['version'] if 'version' in item else 0).zfill(4))
    root_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
    json_file = os.path.join(root_dir, 'source', 'onnx-metadata.json')
    # Preserve the hand-maintained 'com.microsoft' operators from the
    # existing file and append them to the regenerated list.
    content = _read(json_file)
    items = json.loads(content)
    items = list(filter(lambda item: item['module'] == "com.microsoft", items))
    json_root = json_root + items
    _write(json_file, json.dumps(json_root, indent=2))
  212. def main(): # pylint: disable=missing-function-docstring
  213. _metadata()
  214. if __name__ == '__main__':
  215. main()