# onnx-operator-json.py (4.6 KB)
  1. #!/usr/bin/env python
  2. from __future__ import unicode_literals
  3. import json
  4. import io
  5. import sys
  6. from onnx import defs
  7. from onnx.defs import OpSchema
  8. from onnx.backend.test.case.node import collect_snippets
  9. SNIPPETS = collect_snippets()
  10. def generate_json_attr_type(type):
  11. assert isinstance(type, OpSchema.AttrType)
  12. s = str(type)
  13. s = s[s.rfind('.')+1:].lower()
  14. if s[-1] == 's':
  15. s = 'list of ' + s
  16. return s
  17. def generate_json_support_level_name(support_level):
  18. assert isinstance(support_level, OpSchema.SupportType)
  19. s = str(support_level)
  20. return s[s.rfind('.')+1:].lower()
  21. def generate_json_types(types):
  22. r = []
  23. for type in types:
  24. r.append(type)
  25. r = sorted(r)
  26. return r
  27. def generate_json(schemas, json_file):
  28. json_root = []
  29. for schema in schemas:
  30. json_schema = {}
  31. if schema.domain:
  32. json_schema['domain'] = schema.domain
  33. else:
  34. json_schema['domain'] = 'ai.onnx'
  35. json_schema['since_version'] = schema.since_version
  36. json_schema['support_level'] = generate_json_support_level_name(schema.support_level)
  37. if schema.doc:
  38. json_schema['description'] = schema.doc.lstrip();
  39. if schema.inputs:
  40. json_schema['inputs'] = []
  41. for input in schema.inputs:
  42. option = ''
  43. if input.option == OpSchema.FormalParameterOption.Optional:
  44. option = 'optional'
  45. elif input.option == OpSchema.FormalParameterOption.Variadic:
  46. option = 'variadic'
  47. json_schema['inputs'].append({
  48. 'name': input.name,
  49. 'description': input.description,
  50. 'option': option,
  51. 'typeStr': input.typeStr,
  52. 'types': generate_json_types(input.types) })
  53. json_schema['min_input'] = schema.min_input;
  54. json_schema['max_input'] = schema.max_input;
  55. if schema.outputs:
  56. json_schema['outputs'] = []
  57. for output in schema.outputs:
  58. option = ''
  59. if output.option == OpSchema.FormalParameterOption.Optional:
  60. option = 'optional'
  61. elif output.option == OpSchema.FormalParameterOption.Variadic:
  62. option = 'variadic'
  63. json_schema['outputs'].append({
  64. 'name': output.name,
  65. 'description': output.description,
  66. 'option': option,
  67. 'typeStr': output.typeStr,
  68. 'types': generate_json_types(output.types) })
  69. json_schema['min_output'] = schema.min_output;
  70. json_schema['max_output'] = schema.max_output;
  71. if schema.attributes:
  72. json_schema['attributes'] = []
  73. for _, attribute in sorted(schema.attributes.items()):
  74. json_schema['attributes'].append({
  75. 'name' : attribute.name,
  76. 'description': attribute.description,
  77. 'type': generate_json_attr_type(attribute.type),
  78. 'required': attribute.required })
  79. if schema.type_constraints:
  80. json_schema["type_constraints"] = []
  81. for type_constraint in schema.type_constraints:
  82. json_schema['type_constraints'].append({
  83. 'description': type_constraint.description,
  84. 'type_param_str': type_constraint.type_param_str,
  85. 'allowed_type_strs': type_constraint.allowed_type_strs
  86. })
  87. if schema.name in SNIPPETS:
  88. json_schema['snippets'] = []
  89. for summary, code in sorted(SNIPPETS[schema.name]):
  90. json_schema['snippets'].append({
  91. 'summary': summary,
  92. 'code': code
  93. })
  94. json_root.append({
  95. "name": schema.name,
  96. "schema": json_schema
  97. })
  98. with io.open(json_file, 'w', newline='') as fout:
  99. json_root = json.dumps(json_root, sort_keys=True, indent=2)
  100. for line in json_root.splitlines():
  101. line = line.rstrip()
  102. if sys.version_info[0] < 3:
  103. line = unicode(line)
  104. fout.write(line)
  105. fout.write('\n')
  106. if __name__ == '__main__':
  107. schemas = sorted(defs.get_all_schemas_with_history(), key=lambda schema: schema.name)
  108. generate_json(schemas, '../src/onnx-operator.json')
  109. # print(schema.name + "|" + schema.domain + "|" + str(schema.since_version))
  110. # sorted_ops = sorted(
  111. # (int(schema.support_level), op_type, schema)
  112. # for (op_type, schema) in defs.get_all_schemas().items())