# onnx-metadata.py — generates ONNX operator metadata JSON for the viewer.
  1. #!/usr/bin/env python
  2. from __future__ import unicode_literals
  3. import json
  4. import io
  5. import sys
  6. from onnx import defs
  7. from onnx.defs import OpSchema
  8. from onnx.backend.test.case import collect_snippets
# Example code snippets for each operator, keyed by op name, gathered
# from the ONNX backend test cases at import time.
snippets = collect_snippets()
  10. categories = {
  11. 'Constant': 'Constant',
  12. 'Conv': 'Layer',
  13. 'ConvTranspose': 'Layer',
  14. 'FC': 'Layer',
  15. 'RNN': 'Layer',
  16. 'LSTM': 'Layer',
  17. 'GRU': 'Layer',
  18. 'Gemm': 'Layer',
  19. 'Dropout': 'Dropout',
  20. 'Elu': 'Activation',
  21. 'HardSigmoid': 'Activation',
  22. 'LeakyRelu': 'Activation',
  23. 'PRelu': 'Activation',
  24. 'ThresholdedRelu': 'Activation',
  25. 'Relu': 'Activation',
  26. 'Selu': 'Activation',
  27. 'Sigmoid': 'Activation',
  28. 'Tanh': 'Activation',
  29. 'LogSoftmax': 'Activation',
  30. 'Softmax': 'Activation',
  31. 'Softplus': 'Activation',
  32. 'Softsign': 'Activation',
  33. 'BatchNormalization': 'Normalization',
  34. 'InstanceNormalization': 'Normalization',
  35. 'LpNormalization': 'Normalization',
  36. 'LRN': 'Normalization',
  37. 'Flatten': 'Shape',
  38. 'Reshape': 'Shape',
  39. 'Transpose': 'Shape',
  40. 'Tile': 'Shape',
  41. 'Xor': 'Logic',
  42. 'Not': 'Logic',
  43. 'Or': 'Logic',
  44. 'Less': 'Logic',
  45. 'And': 'Logic',
  46. 'Greater': 'Logic',
  47. 'Equal': 'Logic',
  48. 'AveragePool': 'Pool',
  49. 'GlobalAveragePool': 'Pool',
  50. 'GlobalLpPool': 'Pool',
  51. 'GlobalMaxPool': 'Pool',
  52. 'LpPool': 'Pool',
  53. 'MaxPool': 'Pool',
  54. 'MaxRoiPool': 'Pool',
  55. 'Concat': 'Tensor',
  56. 'Slice': 'Tensor',
  57. 'Split': 'Tensor',
  58. 'Pad': 'Tensor',
  59. 'ImageScaler': 'Data',
  60. 'Crop': 'Data',
  61. 'Gather': 'Transform',
  62. 'Unsqueeze': 'Transform',
  63. 'Squeeze': 'Transform',
  64. }
  65. def generate_json_attr_type(type):
  66. assert isinstance(type, OpSchema.AttrType)
  67. s = str(type)
  68. s = s[s.rfind('.')+1:].lower()
  69. if s[-1] == 's':
  70. s = s[0:-1] + '[]'
  71. return s
  72. def generate_json_attr_default_value(attr_value):
  73. if not str(attr_value):
  74. return None
  75. if attr_value.HasField('i'):
  76. return attr_value.i
  77. if attr_value.HasField('s'):
  78. return attr_value.s.decode('utf8')
  79. if attr_value.HasField('f'):
  80. return attr_value.f
  81. return None
  82. def generate_json_support_level_name(support_level):
  83. assert isinstance(support_level, OpSchema.SupportType)
  84. s = str(support_level)
  85. return s[s.rfind('.')+1:].lower()
  86. def generate_json_types(types):
  87. r = []
  88. for type in types:
  89. r.append(type)
  90. r = sorted(r)
  91. return r
  92. def generate_json(schemas, json_file):
  93. json_root = []
  94. for schema in schemas:
  95. json_schema = {}
  96. if schema.domain:
  97. json_schema['domain'] = schema.domain
  98. else:
  99. json_schema['domain'] = 'ai.onnx'
  100. json_schema['since_version'] = schema.since_version
  101. json_schema['support_level'] = generate_json_support_level_name(schema.support_level)
  102. if schema.doc:
  103. json_schema['description'] = schema.doc.lstrip()
  104. if schema.inputs:
  105. json_schema['inputs'] = []
  106. for input in schema.inputs:
  107. json_input = {}
  108. json_input['name'] = input.name
  109. json_input['description'] = input.description
  110. json_input['type'] = input.typeStr
  111. if input.option == OpSchema.FormalParameterOption.Optional:
  112. json_input['option'] = 'optional'
  113. elif input.option == OpSchema.FormalParameterOption.Variadic:
  114. json_input['option'] = 'variadic'
  115. json_schema['inputs'].append(json_input)
  116. json_schema['min_input'] = schema.min_input
  117. json_schema['max_input'] = schema.max_input
  118. if schema.outputs:
  119. json_schema['outputs'] = []
  120. for output in schema.outputs:
  121. json_output = {}
  122. json_output['name'] = output.name
  123. json_output['description'] = output.description
  124. json_output['type'] = output.typeStr
  125. if output.option == OpSchema.FormalParameterOption.Optional:
  126. json_output['option'] = 'optional'
  127. elif output.option == OpSchema.FormalParameterOption.Variadic:
  128. json_output['option'] = 'variadic'
  129. json_schema['outputs'].append(json_output)
  130. json_schema['min_output'] = schema.min_output
  131. json_schema['max_output'] = schema.max_output
  132. if schema.attributes:
  133. json_schema['attributes'] = []
  134. for _, attribute in sorted(schema.attributes.items()):
  135. json_attribute = {}
  136. json_attribute['name'] = attribute.name
  137. json_attribute['description'] = attribute.description
  138. json_attribute['type'] = generate_json_attr_type(attribute.type)
  139. json_attribute['required'] = attribute.required
  140. default_value = generate_json_attr_default_value(attribute.default_value)
  141. if default_value:
  142. json_attribute['default'] = default_value
  143. json_schema['attributes'].append(json_attribute)
  144. if schema.type_constraints:
  145. json_schema['type_constraints'] = []
  146. for type_constraint in schema.type_constraints:
  147. json_schema['type_constraints'].append({
  148. 'description': type_constraint.description,
  149. 'type_param_str': type_constraint.type_param_str,
  150. 'allowed_type_strs': type_constraint.allowed_type_strs
  151. })
  152. if schema.name in snippets:
  153. json_schema['examples'] = []
  154. for summary, code in sorted(snippets[schema.name]):
  155. json_schema['examples'].append({
  156. 'summary': summary,
  157. 'code': code
  158. })
  159. if schema.name in categories:
  160. json_schema['category'] = categories[schema.name]
  161. json_root.append({
  162. 'name': schema.name,
  163. 'schema': json_schema
  164. })
  165. with io.open(json_file, 'w', newline='') as fout:
  166. json_root = json.dumps(json_root, sort_keys=True, indent=2)
  167. for line in json_root.splitlines():
  168. line = line.rstrip()
  169. if sys.version_info[0] < 3:
  170. line = unicode(line)
  171. fout.write(line)
  172. fout.write('\n')
  173. if __name__ == '__main__':
  174. schemas = defs.get_all_schemas_with_history()
  175. schemas = sorted(schemas, key=lambda schema: schema.name)
  176. generate_json(schemas, '../../src/onnx-metadata.json')