| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201 |
- from __future__ import unicode_literals
- from __future__ import print_function
- import io
- import json
- import logging
- import pydoc
- import os
- import re
- import sys
def get_support_level(dir):
    """Return the support-level label for an operator's source directory.

    Backslashes are normalized to forward slashes first, so Windows-style
    paths classify the same way. Checks run in priority order:

    * a path containing ``caffe2/caffe2/operators`` -> ``'core'``
    * a ``contrib`` path component                -> ``'contribution'``
    * an ``experiments`` path component           -> ``'experimental'``
    * anything else                               -> ``'default'``
    """
    path = dir.replace('\\', '/')
    if 'caffe2/caffe2/operators' in path:
        return 'core'
    segments = path.split('/')
    # Whole-component matches only, so e.g. 'contribs' does not count.
    for segment, level in (('contrib', 'contribution'), ('experiments', 'experimental')):
        if segment in segments:
            return level
    return 'default'
def update_argument_type(type):
    """Map a caffe2 argument type name onto the metadata type vocabulary.

    Args:
        type: Type name as written in a caffe2 argument's property header,
            e.g. ``'int'``, ``'int[]'`` or ``'List(string)'``.

    Returns:
        The normalized type string (``'int64'``, ``'int32'``, ``'int64[]'``,
        ``'float32'``, ``'string'``, ``'string[]'`` or ``'boolean'``).

    Raises:
        Exception: if ``type`` is not one of the recognized names.
    """
    known_types = (
        (('int', 'int64_t'), 'int64'),
        (('int32_t',), 'int32'),
        (('[int]', 'int[]'), 'int64[]'),
        (('float',), 'float32'),
        (('string',), 'string'),
        (('List(string)',), 'string[]'),
        (('bool',), 'boolean'),
    )
    for spellings, normalized in known_types:
        if type in spellings:
            return normalized
    raise Exception('Unknown argument type ' + str(type))
def update_argument_default(value, type):
    """Convert a documented default-value string into a typed Python value.

    Args:
        value: Default text taken from an argument's property header.
        type: Normalized argument type, as returned by ``update_argument_type``.

    Returns:
        ``int`` for ``'int64'``/``'int32'``; ``float`` for ``'float32'`` (any
        trailing ``~`` is stripped before parsing); ``bool`` for ``'boolean'``;
        ``str`` with surrounding double quotes removed for ``'string'``.

    Raises:
        Exception: for a ``'boolean'`` default that is neither ``'True'`` nor
            ``'False'``, or for any other ``type``.
    """
    # 'int32' is also a type update_argument_type can produce (from 'int32_t')
    # and its defaults parse exactly like 'int64' ones.
    if type in ('int64', 'int32'):
        return int(value)
    if type == 'float32':
        return float(value.rstrip('~'))
    if type == 'boolean':
        if value == 'True':
            return True
        if value == 'False':
            return False
        # Report the bad value instead of blaming the (valid) type.
        raise Exception('Unknown boolean default ' + str(value))
    if type == 'string':
        return value.strip('"')
    raise Exception('Unknown argument type ' + str(type))
# Property names found in caffe2 argument headers that carry no metadata.
_IGNORED_ARGUMENT_PROPERTIES = (
    'must be > 1.0',
    "default='NCHW'",
    'type depends on dtype',
    'Required=True',
)

# Declared argument types that have no metadata equivalent and are skipped.
_SKIPPED_ARGUMENT_TYPES = (
    'primitive',
    'int | Tuple(int)',
    '[]',
    'TensorProto_DataType',
    'Tuple(int)',
)


def _split_argument_properties(description):
    """Split a leading ``*(...)*`` property header off an argument description.

    Returns ``(properties, remainder)``: the raw property strings (empty list
    when ``description`` does not start with ``*(``) and the description text
    that follows the header with leading whitespace removed.
    """
    if not description.startswith('*('):
        return [], description
    end = description.find(')*')
    if end != -1:
        header = description[2:end]
        remainder = description[end + 2:].lstrip()
    else:
        # Some headers are closed by a bare ')' rather than ')*'.
        end = description.index(')')
        header = description[2:end]
        remainder = description[end + 1:].lstrip()
    properties = header.split(';')
    # A few headers separate their properties with ',' instead of ';'.
    if len(properties) == 1 and ',' in properties[0]:
        properties = properties[0].split(',')
    return properties, remainder


def _apply_argument_property(attribute, prop):
    """Apply one raw ``name: value`` property string to ``attribute`` in place."""
    parts = prop.split(':')
    name = parts[0].strip()
    # Only the text between the first and second ':' is the value; the
    # '2, possible values' special case below relies on that truncation.
    value = parts[1].strip() if len(parts) > 1 else None
    if not name:
        # Empty entries (e.g. left by a trailing ';' or ',') carry nothing.
        return
    if name == 'type':
        if value is None:
            raise Exception('Missing value for property ' + name)
        if value in _SKIPPED_ARGUMENT_TYPES:
            return
        attribute['type'] = update_argument_type(value)
    elif name == 'default':
        # A default can only be converted once the attribute's type is known.
        if 'type' not in attribute:
            return
        arg_type = attribute['type']
        if value is None:
            raise Exception('Missing value for property ' + name)
        default = value
        if default == '2, possible values':
            default = '2'
        if arg_type == 'float32' and default == "'NCHW'":
            return
        if arg_type == 'int64[]':
            return
        attribute['default'] = update_argument_default(default, arg_type)
    elif name == 'optional':
        attribute['option'] = 'optional'
    elif name in _IGNORED_ARGUMENT_PROPERTIES:
        return
    elif name == 'List(string)':
        attribute['type'] = 'string[]'
    else:
        raise Exception('Unknown property ' + name)


def update_argument(schema, arg):
    """Merge one caffe2 operator argument into ``schema['attributes']``.

    The attribute whose ``name`` equals ``arg.name`` is updated in place, or a
    new one is appended. If the stripped ``arg.description`` starts with a
    ``*(...)*`` property header, each property in it is applied (``type``,
    ``default``, ``optional``, ``List(string)``; a fixed set of informational
    properties is ignored) and the header is removed from the stored
    description. Arguments whose ``arg.required`` is falsy are marked
    ``'option': 'optional'``.

    Args:
        schema: Operator schema dict; mutated in place.
        arg: caffe2 argument object exposing ``name``, ``description`` and
            ``required``.

    Raises:
        Exception: on an unknown property, a ``type``/``default`` property
            without a value, or a type/default that cannot be converted.
    """
    attributes = schema.setdefault('attributes', [])
    attribute = None
    for candidate in attributes:
        if 'name' in candidate and candidate['name'] == arg.name:
            attribute = candidate
            break
    if attribute is None:
        attribute = {'name': arg.name}
        attributes.append(attribute)
    properties, description = _split_argument_properties(arg.description.strip())
    for prop in properties:
        _apply_argument_property(attribute, prop)
    attribute['description'] = description
    if not arg.required:
        attribute['option'] = 'optional'
def update_input(schema, input_desc):
    """Merge one operator input description into ``schema['inputs']``.

    The input whose ``name`` matches is updated in place, or a new
    ``{'name': ...}`` entry is appended; its ``description`` is then set.

    Args:
        schema: Operator schema dict; mutated in place.
        input_desc: Sequence whose first two items are the input name and its
            description. Any further items are ignored.
    """
    input_name = input_desc[0]
    description = input_desc[1]
    inputs = schema.setdefault('inputs', [])
    for current in inputs:
        if 'name' in current and current['name'] == input_name:
            input_arg = current
            break
    else:
        input_arg = {'name': input_name}
        inputs.append(input_arg)
    input_arg['description'] = description
def update_output(operator_name, schema, output_desc):
    """Merge one operator output description into ``schema['outputs']``.

    The output whose ``name`` matches is updated in place, or a new
    ``{'name': ...}`` entry is appended; its ``description`` is then set.

    Args:
        operator_name: Unused; kept so existing call sites keep working.
        schema: Operator schema dict; mutated in place.
        output_desc: Sequence whose first two items are the output name and
            its description. Any further items are ignored.
    """
    output_name = output_desc[0]
    description = output_desc[1]
    outputs = schema.setdefault('outputs', [])
    for current in outputs:
        if 'name' in current and current['name'] == output_name:
            output_arg = current
            break
    else:
        output_arg = {'name': output_name}
        outputs.append(output_arg)
    output_arg['description'] = description
class Caffe2Filter(logging.Filter):
    """Logging filter that drops caffe2's "no GPU support" warning.

    ``logging.Filter.filter`` must return a true value to KEEP a record, so
    this returns False only for the GPU-support warning and True otherwise.
    ``record.getMessage()`` is the bare message (no ``WARNING:root:`` formatter
    prefix), but a leading prefix is tolerated in case a message embeds it.
    """

    _PREFIX = 'WARNING:root:'
    _GPU_WARNING = 'This caffe2 python run does not have GPU support.'

    def filter(self, record):
        message = record.getMessage()
        if message.startswith(self._PREFIX):
            message = message[len(self._PREFIX):]
        return not message.startswith(self._GPU_WARNING)
def metadata():
    """Refresh ``../src/caffe2-metadata.json`` (relative to this script) from caffe2.

    The file is a JSON list of ``{'name': ..., 'schema': ...}`` entries. For
    every registered caffe2 operator that has an ``OpSchema`` (``Crash`` is
    skipped), the operator's doc, arguments, inputs, outputs and support level
    are merged into its schema. An entry is appended for operators not yet in
    the file. The file is then rewritten as sorted, 2-space-indented JSON with
    trailing whitespace stripped and ``\\n`` line endings.
    """
    logging.getLogger('').addFilter(Caffe2Filter())

    import caffe2.python.core
    # OpSchema is reached through caffe2.python.workspace below; import it
    # explicitly instead of relying on caffe2.python.core importing it.
    import caffe2.python.workspace

    json_file = os.path.join(os.path.dirname(__file__), '../src/caffe2-metadata.json')
    # 'with' closes the handle deterministically (it used to be leaked).
    with io.open(json_file, 'r', encoding='utf-8') as fin:
        json_root = json.load(fin)

    # operator name -> schema dict (shared with the entries in json_root).
    schema_map = {}
    for entry in json_root:
        schema_map[entry['name']] = entry['schema']

    for operator_name in caffe2.python.core._GetRegisteredOperators():
        op_schema = caffe2.python.workspace.C.OpSchema.get(operator_name)
        if not op_schema or operator_name == 'Crash':
            continue
        if operator_name in schema_map:
            schema = schema_map[operator_name]
        else:
            schema = {}
            json_root.append({'name': operator_name, 'schema': schema})
            # Map to the schema (not its wrapping entry), matching the
            # entries loaded from the file above.
            schema_map[operator_name] = schema
        schema['description'] = op_schema.doc
        for arg in op_schema.args:
            update_argument(schema, arg)
        for input_desc in op_schema.input_desc:
            update_input(schema, input_desc)
        for output_desc in op_schema.output_desc:
            update_output(operator_name, schema, output_desc)
        schema['support_level'] = get_support_level(os.path.dirname(op_schema.file))

    # Serialize before opening for write, so a serialization error cannot
    # leave the metadata file truncated.
    json_data = json.dumps(json_root, sort_keys=True, indent=2)
    with io.open(json_file, 'w', encoding='utf-8', newline='') as fout:
        for line in json_data.splitlines():
            line = line.rstrip()
            if sys.version_info[0] < 3:
                # Python 2: json.dumps returns a byte str, io.open needs unicode.
                line = unicode(line)  # noqa: F821
            fout.write(line)
            fout.write('\n')
if __name__ == '__main__':
    # Dispatch table of sub-commands; usage: <script> metadata
    command_table = {'metadata': metadata}
    command = sys.argv[1] if len(sys.argv) > 1 else None
    if command not in command_table:
        # Report usage instead of crashing with IndexError/KeyError.
        print('usage: ' + os.path.basename(sys.argv[0]) + ' {' + '|'.join(sorted(command_table)) + '}',
              file=sys.stderr)
        sys.exit(1)
    command_table[command]()
|