caffe2-script.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201
  1. from __future__ import unicode_literals
  2. from __future__ import print_function
  3. import io
  4. import json
  5. import logging
  6. import pydoc
  7. import os
  8. import re
  9. import sys
  10. def get_support_level(dir):
  11. dir = dir.replace('\\', '/')
  12. if 'caffe2/caffe2/operators' in dir:
  13. return 'core'
  14. if 'contrib' in dir.split('/'):
  15. return 'contribution'
  16. if 'experiments' in dir.split('/'):
  17. return 'experimental'
  18. return 'default'
  19. def update_argument_type(type):
  20. if type == 'int' or type == 'int64_t':
  21. return 'int64'
  22. if type == 'int32_t':
  23. return 'int32'
  24. elif type == '[int]' or type == 'int[]':
  25. return 'int64[]'
  26. elif type == 'float':
  27. return 'float32'
  28. elif type == 'string':
  29. return 'string'
  30. elif type == 'List(string)':
  31. return 'string[]'
  32. elif type == 'bool':
  33. return 'boolean'
  34. raise Exception('Unknown argument type ' + str(type))
  35. def update_argument_default(value, type):
  36. if type == 'int64':
  37. return int(value)
  38. elif type == 'float32':
  39. return float(value.rstrip('~'))
  40. elif type == 'boolean':
  41. if value == 'True':
  42. return True
  43. if value == 'False':
  44. return False
  45. elif type == 'string':
  46. return value.strip('\"')
  47. raise Exception('Unknown argument type ' + str(type))
  48. def update_argument(schema, arg):
  49. if not 'attributes' in schema:
  50. schema['attributes'] = []
  51. attribute = None
  52. for current_attribute in schema['attributes']:
  53. if 'name' in current_attribute and current_attribute['name'] == arg.name:
  54. attribute = current_attribute
  55. break
  56. if not attribute:
  57. attribute = {}
  58. attribute['name'] = arg.name
  59. schema['attributes'].append(attribute)
  60. description = arg.description.strip()
  61. if description.startswith('*('):
  62. index = description.find(')*')
  63. properties = []
  64. if index != -1:
  65. properties = description[2:index].split(';')
  66. description = description[index+2:].lstrip()
  67. else:
  68. index = description.index(')')
  69. properties = description[2:index].split(';')
  70. description = description[index+1:].lstrip()
  71. if len(properties) == 1 and properties[0].find(',') != -1:
  72. properties = properties[0].split(',')
  73. for property in properties:
  74. parts = property.split(':')
  75. name = parts[0].strip()
  76. if name == 'type':
  77. type = parts[1].strip()
  78. if type == 'primitive' or type == 'int | Tuple(int)' or type == '[]' or type == 'TensorProto_DataType' or type == 'Tuple(int)':
  79. continue
  80. attribute['type'] = update_argument_type(type)
  81. elif name == 'default':
  82. if 'type' in attribute:
  83. type = attribute['type']
  84. default = parts[1].strip()
  85. if default == '2, possible values':
  86. default = '2'
  87. if type == 'float32' and default == '\'NCHW\'':
  88. continue
  89. if type == 'int64[]':
  90. continue
  91. attribute['default'] = update_argument_default(default, type)
  92. elif name == 'optional':
  93. attribute['option'] = 'optional'
  94. elif name == 'must be > 1.0' or name == 'default=\'NCHW\'' or name == 'type depends on dtype' or name == 'Required=True':
  95. continue
  96. elif name == 'List(string)':
  97. attribute['type'] = 'string[]'
  98. else:
  99. raise Exception('Unknown property ' + str(parts[0].strip()))
  100. attribute['description'] = description
  101. if not arg.required:
  102. attribute['option'] = 'optional'
  103. return
  104. def update_input(schema, input_desc):
  105. input_name = input_desc[0]
  106. description = input_desc[1]
  107. if not 'inputs' in schema:
  108. schema['inputs'] = []
  109. input_arg = None
  110. for current_input in schema['inputs']:
  111. if 'name' in current_input and current_input['name'] == input_name:
  112. input_arg = current_input
  113. break
  114. if not input_arg:
  115. input_arg = {}
  116. input_arg['name'] = input_name
  117. schema['inputs'].append(input_arg)
  118. input_arg['description'] = description
  119. if len(input_desc) > 2:
  120. return
  121. def update_output(operator_name, schema, output_desc):
  122. output_name = output_desc[0]
  123. description = output_desc[1]
  124. if not 'outputs' in schema:
  125. schema['outputs'] = []
  126. output_arg = None
  127. for current_output in schema['outputs']:
  128. if 'name' in current_output and current_output['name'] == output_name:
  129. output_arg = current_output
  130. break
  131. if not output_arg:
  132. output_arg = {}
  133. output_arg['name'] = output_name
  134. schema['outputs'].append(output_arg)
  135. output_arg['description'] = description
  136. if len(output_desc) > 2:
  137. return
  138. class Caffe2Filter(logging.Filter):
  139. def filter(self, record):
  140. return record.getMessage().startswith('WARNING:root:This caffe2 python run does not have GPU support.')
  141. def metadata():
  142. logging.getLogger('').addFilter(Caffe2Filter())
  143. import caffe2.python.core
  144. json_file = os.path.join(os.path.dirname(__file__), '../src/caffe2-metadata.json')
  145. json_data = open(json_file).read()
  146. json_root = json.loads(json_data)
  147. schema_map = {}
  148. for entry in json_root:
  149. operator_name = entry['name']
  150. schema = entry['schema']
  151. schema_map[operator_name] = schema
  152. for operator_name in caffe2.python.core._GetRegisteredOperators():
  153. op_schema = caffe2.python.workspace.C.OpSchema.get(operator_name)
  154. if op_schema:
  155. if operator_name == 'Crash':
  156. continue
  157. if operator_name in schema_map:
  158. schema = schema_map[operator_name]
  159. else:
  160. schema = {}
  161. entry = { 'name': operator_name, 'schema': schema }
  162. schema_map[operator_name] = entry
  163. json_root.append(entry)
  164. schema['description'] = op_schema.doc
  165. for arg in op_schema.args:
  166. update_argument(schema, arg)
  167. for input_desc in op_schema.input_desc:
  168. update_input(schema, input_desc)
  169. for output_desc in op_schema.output_desc:
  170. update_output(operator_name, schema, output_desc)
  171. schema['support_level'] = get_support_level(os.path.dirname(op_schema.file))
  172. with io.open(json_file, 'w', newline='') as fout:
  173. json_data = json.dumps(json_root, sort_keys=True, indent=2)
  174. for line in json_data.splitlines():
  175. line = line.rstrip()
  176. if sys.version_info[0] < 3:
  177. line = unicode(line)
  178. fout.write(line)
  179. fout.write('\n')
  180. if __name__ == '__main__':
  181. command_table = { 'metadata': metadata }
  182. command = sys.argv[1];
  183. command_table[command]()