# onnx-metadata.py
  1. #!/usr/bin/env python
  2. from __future__ import unicode_literals
  3. import json
  4. import io
  5. import sys
  6. from onnx import defs
  7. from onnx.defs import OpSchema
  8. from onnx.backend.test.case import collect_snippets
  9. snippets = collect_snippets()
  10. categories = {
  11. 'Constant': 'Constant',
  12. 'Conv': 'Layer',
  13. 'ConvTranspose': 'Layer',
  14. 'FC': 'Layer',
  15. 'RNN': 'Layer',
  16. 'LSTM': 'Layer',
  17. 'GRU': 'Layer',
  18. 'Dropout': 'Dropout',
  19. 'Elu': 'Activation',
  20. 'HardSigmoid': 'Activation',
  21. 'LeakyRelu': 'Activation',
  22. 'PRelu': 'Activation',
  23. 'ThresholdedRelu': 'Activation',
  24. 'Relu': 'Activation',
  25. 'Selu': 'Activation',
  26. 'Sigmoid': 'Activation',
  27. 'Tanh': 'Activation',
  28. 'LogSoftmax': 'Activation',
  29. 'Softmax': 'Activation',
  30. 'Softplus': 'Activation',
  31. 'Softsign': 'Activation',
  32. 'BatchNormalization': 'Normalization',
  33. 'InstanceNormalization': 'Normalization',
  34. 'LpNormalization': 'Normalization',
  35. 'LRN': 'Normalization',
  36. 'Flatten': 'Shape',
  37. 'Reshape': 'Shape',
  38. 'Transpose': 'Shape',
  39. 'Tile': 'Shape',
  40. 'Xor': 'Logic',
  41. 'Not': 'Logic',
  42. 'Or': 'Logic',
  43. 'Less': 'Logic',
  44. 'And': 'Logic',
  45. 'Greater': 'Logic',
  46. 'Equal': 'Logic',
  47. 'AveragePool': 'Pool',
  48. 'GlobalAveragePool': 'Pool',
  49. 'GlobalLpPool': 'Pool',
  50. 'GlobalMaxPool': 'Pool',
  51. 'LpPool': 'Pool',
  52. 'MaxPool': 'Pool',
  53. 'MaxRoiPool': 'Pool',
  54. 'Concat': 'Tensor',
  55. 'Slice': 'Tensor',
  56. 'Split': 'Tensor',
  57. 'Pad': 'Tensor',
  58. 'ImageScaler': 'Data',
  59. 'Crop': 'Data',
  60. # 'Gemm': '',
  61. # 'MatMul': '',
  62. # 'Hardmax':
  63. # 'Log':
  64. # 'Max':
  65. # 'Div': 'Basic',
  66. # 'Ceil': 'Basic',
  67. # 'Exp': 'Basic',
  68. # 'Floor': 'Basic',
  69. # 'Sqrt': 'Basic',
  70. # 'Sub': 'Basic',
  71. # 'Sum': 'Basic',
  72. # 'Min': 'Basic',
  73. # 'Mul': 'Basic',
  74. # 'Neg': 'Basic',
  75. # 'Abs': 'Basic',
  76. # 'Add': 'Basic',
  77. # 'Pow': 'Basic',
  78. # 'ArgMax':
  79. # 'ArgMin':
  80. # 'Cast':
  81. # 'Clip':
  82. # 'DepthToSpace':
  83. # 'Mean':
  84. # 'Pad':
  85. # 'RandomNormal':
  86. # 'RandomNormalLike':
  87. # 'RandomUniform':
  88. # 'RandomUniformLike':
  89. # 'Reciprocal':
  90. # 'ReduceL1':
  91. # 'ReduceL2':
  92. # 'ReduceLogSum':
  93. # 'ReduceLogSumExp':
  94. # 'ReduceMax':
  95. # 'ReduceMean':
  96. # 'ReduceMin':
  97. # 'ReduceProd':
  98. # 'ReduceSum':
  99. # 'ReduceSumSquare':
  100. # 'SpaceToDepth':
  101. # 'Squeeze':
  102. # 'Tile':
  103. # 'Gather':
  104. }
  105. def generate_json_attr_type(type):
  106. assert isinstance(type, OpSchema.AttrType)
  107. s = str(type)
  108. s = s[s.rfind('.')+1:].lower()
  109. if s[-1] == 's':
  110. s = s[0:-1] + '[]'
  111. return s
  112. def generate_json_support_level_name(support_level):
  113. assert isinstance(support_level, OpSchema.SupportType)
  114. s = str(support_level)
  115. return s[s.rfind('.')+1:].lower()
  116. def generate_json_types(types):
  117. r = []
  118. for type in types:
  119. r.append(type)
  120. r = sorted(r)
  121. return r
  122. def generate_json(schemas, json_file):
  123. json_root = []
  124. for schema in schemas:
  125. json_schema = {}
  126. if schema.domain:
  127. json_schema['domain'] = schema.domain
  128. else:
  129. json_schema['domain'] = 'ai.onnx'
  130. json_schema['since_version'] = schema.since_version
  131. json_schema['support_level'] = generate_json_support_level_name(schema.support_level)
  132. if schema.doc:
  133. json_schema['description'] = schema.doc.lstrip();
  134. if schema.inputs:
  135. json_schema['inputs'] = []
  136. for input in schema.inputs:
  137. json_input = {}
  138. json_input['name'] = input.name
  139. json_input['description'] = input.description
  140. json_input['type'] = input.typeStr
  141. if input.option == OpSchema.FormalParameterOption.Optional:
  142. json_input['option'] = 'optional'
  143. elif input.option == OpSchema.FormalParameterOption.Variadic:
  144. json_input['option'] = 'variadic'
  145. json_schema['inputs'].append(json_input)
  146. json_schema['min_input'] = schema.min_input;
  147. json_schema['max_input'] = schema.max_input;
  148. if schema.outputs:
  149. json_schema['outputs'] = []
  150. for output in schema.outputs:
  151. json_output = {}
  152. json_output['name'] = output.name
  153. json_output['description'] = output.description
  154. json_output['type'] = output.typeStr
  155. if output.option == OpSchema.FormalParameterOption.Optional:
  156. json_output['option'] = 'optional'
  157. elif output.option == OpSchema.FormalParameterOption.Variadic:
  158. json_output['option'] = 'variadic'
  159. json_schema['outputs'].append(json_output)
  160. json_schema['min_output'] = schema.min_output;
  161. json_schema['max_output'] = schema.max_output;
  162. if schema.attributes:
  163. json_schema['attributes'] = []
  164. for _, attribute in sorted(schema.attributes.items()):
  165. json_schema['attributes'].append({
  166. 'name' : attribute.name,
  167. 'description': attribute.description,
  168. 'type': generate_json_attr_type(attribute.type),
  169. 'required': attribute.required })
  170. if schema.type_constraints:
  171. json_schema['type_constraints'] = []
  172. for type_constraint in schema.type_constraints:
  173. json_schema['type_constraints'].append({
  174. 'description': type_constraint.description,
  175. 'type_param_str': type_constraint.type_param_str,
  176. 'allowed_type_strs': type_constraint.allowed_type_strs
  177. })
  178. if schema.name in snippets:
  179. json_schema['examples'] = []
  180. for summary, code in sorted(snippets[schema.name]):
  181. json_schema['examples'].append({
  182. 'summary': summary,
  183. 'code': code
  184. })
  185. if schema.name in categories:
  186. json_schema['category'] = categories[schema.name]
  187. json_root.append({
  188. 'name': schema.name,
  189. 'schema': json_schema
  190. })
  191. with io.open(json_file, 'w', newline='') as fout:
  192. json_root = json.dumps(json_root, sort_keys=True, indent=2)
  193. for line in json_root.splitlines():
  194. line = line.rstrip()
  195. if sys.version_info[0] < 3:
  196. line = unicode(line)
  197. fout.write(line)
  198. fout.write('\n')
  199. if __name__ == '__main__':
  200. schemas = sorted(defs.get_all_schemas_with_history(), key=lambda schema: schema.name)
  201. generate_json(schemas, '../src/onnx-metadata.json')