caffe2-metadata.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
  1. #!/usr/bin/env python
  2. from __future__ import unicode_literals
  3. from __future__ import print_function
  4. import io
  5. import json
  6. import pydoc
  7. import os
  8. import re
  9. import sys
  10. import caffe2.python.core
  11. json_file = '../src/caffe2-metadata.json'
  12. json_data = open(json_file).read()
  13. json_root = json.loads(json_data)
  14. def get_support_level(dir):
  15. if 'caffe2/caffe2/operators' in dir:
  16. return 'core'
  17. if 'contrib' in dir.split('/'):
  18. return 'contribution'
  19. if 'experiments' in dir.split('/'):
  20. return 'experimental'
  21. return 'default'
  22. def update_argument(schema, arg):
  23. if not 'attributes' in schema:
  24. schema['attributes'] = []
  25. attribute = None
  26. for current_attribute in schema['attributes']:
  27. if 'name' in current_attribute and current_attribute['name'] == arg.name:
  28. attribute = current_attribute
  29. break
  30. if not attribute:
  31. attribute = {}
  32. attribute['name'] = arg.name
  33. schema['attributes'].append(attribute)
  34. attribute['description'] = arg.description
  35. if not arg.required:
  36. attribute['option'] = 'optional'
  37. return
  38. def update_input(schema, input_desc):
  39. name = input_desc[0]
  40. description = input_desc[1]
  41. if not 'inputs' in schema:
  42. schema['inputs'] = []
  43. input_arg = None
  44. for current_input in schema['inputs']:
  45. if 'name' in current_input and current_input['name'] == name:
  46. input_arg = current_input
  47. break
  48. if not input_arg:
  49. input_arg = {}
  50. input_arg['name'] = name
  51. schema['inputs'].append(input_arg)
  52. input_arg['description'] = description
  53. if len(input_desc) > 2:
  54. return
  55. def update_output(schema, output_desc):
  56. name = output_desc[0]
  57. description = output_desc[1]
  58. if not 'outputs' in schema:
  59. schema['outputs'] = []
  60. output_arg = None
  61. for current_output in schema['outputs']:
  62. if 'name' in current_output and current_output['name'] == name:
  63. output_arg = current_output
  64. break
  65. if not output_arg:
  66. output_arg = {}
  67. output_arg['name'] = name
  68. schema['outputs'].append(output_arg)
  69. output_arg['description'] = description
  70. if len(output_desc) > 2:
  71. return
  72. schema_map = {}
  73. for entry in json_root:
  74. name = entry['name']
  75. schema = entry['schema']
  76. schema_map[name] = schema
  77. for name in caffe2.python.core._GetRegisteredOperators():
  78. op_schema = caffe2.python.workspace.C.OpSchema.get(name)
  79. if op_schema:
  80. if name in schema_map:
  81. schema = schema_map[name]
  82. else:
  83. schema = {}
  84. schema_map[name] = { 'name': name, 'schema': schema }
  85. schema['description'] = op_schema.doc
  86. for arg in op_schema.args:
  87. update_argument(schema, arg)
  88. for input_desc in op_schema.input_desc:
  89. update_input(schema, input_desc)
  90. for output_desc in op_schema.output_desc:
  91. update_output(schema, output_desc)
  92. schema['support_level'] = get_support_level(os.path.dirname(op_schema.file))
  93. with io.open(json_file, 'w', newline='') as fout:
  94. json_data = json.dumps(json_root, sort_keys=True, indent=2)
  95. for line in json_data.splitlines():
  96. line = line.rstrip()
  97. if sys.version_info[0] < 3:
  98. line = unicode(line)
  99. fout.write(line)
  100. fout.write('\n')