# pytorch_metadata.py
''' TorchScript metadata script '''
import collections
import json
import os
import re
import sys

# Repository root is two directory levels above this file; append it to
# sys.path so the in-repo 'source' package is importable below.
root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(root_dir)
# Keep compiled bytecode out of the source tree.
sys.pycache_prefix = os.path.join(root_dir, 'dist', 'pycache', 'test', 'backend')
# __import__('source.pytorch') returns the top-level 'source' package;
# the 'pytorch' submodule is read off it as an attribute.
pytorch = __import__('source.pytorch').pytorch
source_dir = os.path.join(root_dir, 'source')
third_party_dir = os.path.join(root_dir, 'third_party')
# The metadata file this script regenerates.
metadata_file = os.path.join(source_dir, 'pytorch-metadata.json')
# Checkout of the PyTorch sources that is scanned for operator schemas.
pytorch_source_dir = os.path.join(third_party_dir, 'source', 'pytorch')
  15. def _read(path):
  16. with open(path, 'r', encoding='utf-8') as file:
  17. return file.read()
  18. def _write(path, content):
  19. with open(path, 'w', encoding='utf-8') as file:
  20. file.write(content)
  21. def _read_metadata():
  22. metadata: list[dict[str,object]] = json.loads(_read(metadata_file))
  23. return dict(map(lambda _: ( _['name'], _ ), metadata))
  24. def _write_metadata(value):
  25. metadata = list(collections.OrderedDict(sorted(value.items())).values())
  26. content = json.dumps(metadata, indent=2, ensure_ascii=False)
  27. content = re.sub(r'\s {8}', ' ', content)
  28. content = re.sub(r',\s {8}', ', ', content)
  29. content = re.sub(r'\s {6}}', ' }', content)
  30. _write(metadata_file, content)
# Each entry: (path relative to pytorch_source_dir, compiled pattern whose
# first capture group yields a schema string, optional prefix prepended to
# every match). Only the native_functions.yaml entry carries a prefix.
schema_source_files = [
    ('aten/src/ATen/native/native_functions.yaml',
     re.compile(r'-\s*func:\s*(.*)', re.MULTILINE), 'aten::'),
    ('aten/src/ATen/native/quantized/library.cpp',
     re.compile(r'TORCH_SELECTIVE_SCHEMA\("(.*)"\)', re.MULTILINE)),
    ('aten/src/ATen/native/xnnpack/RegisterOpContextClass.cpp',
     # NOTE(review): unlike the quantized entry, this pattern has no trailing
     # \) after the quote — presumably deliberate for this file; confirm.
     re.compile(r'TORCH_SELECTIVE_SCHEMA\("(.*)"', re.MULTILINE)),
    ('torch/csrc/jit/runtime/register_prim_ops.cpp',
     re.compile(r'(aten::.*->\s*.*)"', re.MULTILINE)),
    ('torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp',
     re.compile(r'(aten::.*->\s*.*)"', re.MULTILINE)),
    ('torch/csrc/jit/runtime/register_special_ops.cpp',
     re.compile(r'(aten::.*->\s*.*)"', re.MULTILINE)),
    ('caffe2/operators/copy_op.cc',
     re.compile(r'(_caffe2::.*->\s*Tensor)', re.MULTILINE)),
    ('caffe2/operators/batch_permutation_op.cc',
     re.compile(r'(_caffe2::.*->\s*Tensor)', re.MULTILINE)),
    ('caffe2/operators/collect_and_distribute_fpn_rpn_proposals_op.cc',
     re.compile(r'"(_caffe2::[\w+]*\([\w"\s\[\],]*\)\s*->\s*\([\w"\s\[\],]*\))"', re.MULTILINE)),
    ('caffe2/operators/box_with_nms_limit_op.cc',
     re.compile(r'"(_caffe2::[\w+]*\([\w"\s\[\],]*\)\s*->\s*\([\w"\s\[\],]*\))"', re.MULTILINE)),
    ('caffe2/operators/bbox_transform_op.cc',
     re.compile(r'"(_caffe2::[\w+]*\([\w"\s\[\],]*\)\s*->\s*\([\w"\s\[\],]*\))"', re.MULTILINE)),
    ('caffe2/operators/generate_proposals_op.cc',
     re.compile(r'"(_caffe2::[\w+]*\([\w"\s\[\],]*\)\s*->\s*\([\w"\s\[\],]*\))"', re.MULTILINE)),
    ('caffe2/operators/roi_align_op.cc',
     re.compile(r'"(_caffe2::[\w+]*\([\w"\s\[\],]*\)\s*->.*)"', re.MULTILINE))
]
# Hand-maintained schemas that the regex scan above does not discover.
known_schema_definitions = [
    'aten::as_tensor(Tensor(a) data, *, ScalarType? dtype=None, Device? device=None) -> Tensor(b|a)', # pylint: disable=line-too-long
    'aten::as_tensor.bool(bool t, *, ScalarType? dtype=None, Device? device=None) -> Tensor',
    'aten::as_tensor.complex(complex t, *, ScalarType? dtype=None, Device? device=None) -> Tensor',
    'aten::as_tensor.float(float t, *, ScalarType? dtype=None, Device? device=None) -> Tensor',
    'aten::as_tensor.int(int t, *, ScalarType? dtype=None, Device? device=None) -> Tensor',
    'aten::as_tensor.list(t[] data, *, ScalarType? dtype=None, Device? device=None) -> Tensor'
]
  67. def _parse_schemas():
  68. schemas = {}
  69. for entry in schema_source_files:
  70. path = os.path.join(pytorch_source_dir, entry[0])
  71. content = _read(path)
  72. for value in entry[1].findall(content):
  73. value = re.sub(r'\n|\r|\s*"', '', value) if value.startswith('_caffe2::') else value
  74. definition = entry[2] + value if len(entry) > 2 else value
  75. schema = pytorch.Schema(definition)
  76. if schema.name in schemas:
  77. raise KeyError()
  78. schemas[schema.name] = schema
  79. for definition in known_schema_definitions:
  80. schema = pytorch.Schema(definition)
  81. schemas[schema.name] = schema
  82. return schemas
  83. def _filter_schemas(schemas, types):
  84. keys = set(map(lambda _: _.split('.')[0], types.keys()))
  85. filtered_schemas = set()
  86. for schema in schemas.values():
  87. for key in keys:
  88. if schema.name == key or schema.name.startswith(key + '.'):
  89. filtered_schemas.add(schema.name)
  90. # filtered_schemas = set(types.keys())
  91. # content = _read('list.csv')
  92. # regex = re.compile(r'Unsupported function \'(.*)\' in', re.MULTILINE)
  93. # matches = set()
  94. # for match in regex.findall(content):
  95. # if match.startswith('torch.'):
  96. # matches.add('aten::' + match[6:])
  97. # if match.startswith('ops.') and len(match.split('.')) > 2:
  98. # matches.add(match[4:].replace('.', '::'))
  99. # for schema in schemas.values():
  100. # for match in matches:
  101. # if schema.name.startswith(match):
  102. # filtered_schemas.add(schema.name)
  103. return dict(filter(lambda _: _[0] in filtered_schemas, schemas.items()))
def _check_schemas(schemas): # pylint: disable=unused-argument
    '''Cross-check parsed schemas against torch's live operator registry.

    Currently a no-op: the comparison below requires a torch installation
    and is kept disabled; it prints any schema whose string form disagrees.
    '''
    # import torch
    # for name in dir(torch.ops.aten):
    #     if name.startswith('__') or name == 'name':
    #         continue
    #     packet = getattr(torch.ops.aten, name)
    #     for overload in packet.overloads():
    #         key = 'aten::' + name + ('.' + overload if overload != 'default' else '')
    #         overload_schema = str(getattr(packet, overload)._schema)
    #         if key in schemas:
    #             schema = schemas[key]
    #             if overload_schema != str(schema):
    #                 print(overload_schema)
    #                 print(schema)
    pass
  119. def _check_types(types, schemas):
  120. types = dict(types.items())
  121. for schema in schemas.values():
  122. if schema.name in types:
  123. types.pop(schema.name)
  124. for key in list(types.keys()):
  125. if key.startswith('torch.nn'):
  126. types.pop(key)
  127. if key.startswith('torchvision::') or \
  128. key.startswith('torchaudio::') or \
  129. key.startswith('neuron::'):
  130. types.pop(key)
  131. types.pop('aten::fft')
  132. types.pop('aten::mul.ScalarT')
  133. types.pop('aten::classes._nnapi.Compilation')
  134. if len(types) > 0:
  135. raise Exception('\n'.join(list(types.keys()))) # pylint: disable=broad-exception-raised
  136. def _metadata():
  137. types = _read_metadata()
  138. schemas = _parse_schemas()
  139. _check_types(types, schemas)
  140. _check_schemas(schemas)
  141. filtered_schemas = _filter_schemas(schemas, types)
  142. metadata = pytorch.Metadata(types)
  143. for schema in filtered_schemas.values():
  144. metadata.type(schema)
  145. _write_metadata(types)
def main():
    '''Entry point: rebuild the PyTorch metadata file.'''
    _metadata()

if __name__ == '__main__':
    main()