|
|
@@ -9,7 +9,7 @@ import sys
|
|
|
|
|
|
root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
|
|
sys.path.append(root_dir)
|
|
|
-sys.pycache_prefix = os.path.join(root_dir, 'dist', 'pycache', 'test', 'backend')
|
|
|
+sys.pycache_prefix = os.path.join(root_dir, 'dist', 'pycache', 'pytorch_script')
|
|
|
|
|
|
source_dir = os.path.join(root_dir, 'source')
|
|
|
third_party_dir = os.path.join(root_dir, 'third_party')
|
|
|
@@ -34,8 +34,7 @@ def _read_metadata():
|
|
|
metadata[key] = value
|
|
|
return metadata
|
|
|
|
|
|
-def _write_metadata(value):
|
|
|
- metadata = list(collections.OrderedDict(sorted(value.items())).values())
|
|
|
+def _write_metadata(metadata):
|
|
|
content = json.dumps(metadata, indent=2, ensure_ascii=False)
|
|
|
content = re.sub(r'\s {8}', ' ', content)
|
|
|
content = re.sub(r',\s {8}', ', ', content)
|
|
|
@@ -88,13 +87,15 @@ known_legacy_schema_definitions = [
|
|
|
def _identifier(schema):
|
|
|
return schema.split('(', 1)[0].strip()
|
|
|
|
|
|
-def _parse_schemas():
|
|
|
- schemas = {}
|
|
|
+def _all_schemas():
|
|
|
torch = __import__('torch')
|
|
|
__import__('torchvision')
|
|
|
__import__('torchaudio')
|
|
|
- all_schemas = list(torch._C._jit_get_all_schemas()) # pylint: disable=protected-access
|
|
|
- for schema in all_schemas:
|
|
|
+ return list(torch._C._jit_get_all_schemas()) # pylint: disable=protected-access
|
|
|
+
|
|
|
+def _parse_schemas():
|
|
|
+ schemas = {}
|
|
|
+ for schema in _all_schemas():
|
|
|
definition = str(schema)
|
|
|
definition = definition.replace('(b|a)', '(a|b)')
|
|
|
key = _identifier(definition)
|
|
|
@@ -131,6 +132,34 @@ def _check_types(types, schemas):
|
|
|
if len(types) > 0:
|
|
|
raise Exception('\n'.join(list(types.keys()))) # pylint: disable=broad-exception-raised
|
|
|
|
|
|
+def _sort_types(types):
|
|
|
+ keys = {}
|
|
|
+ index = 0
|
|
|
+ for schema in _all_schemas():
|
|
|
+ definition = str(schema)
|
|
|
+ key = _identifier(definition)
|
|
|
+ keys[key] = index
|
|
|
+ index += 1
|
|
|
+ classes = collections.OrderedDict()
|
|
|
+ for item in types:
|
|
|
+ name = item['name']
|
|
|
+ if name.find('::') == -1:
|
|
|
+ classes[name] = item
|
|
|
+ else:
|
|
|
+ key = _identifier(name)
|
|
|
+ if not key in keys:
|
|
|
+ keys[key] = index
|
|
|
+ index += 1
|
|
|
+ for key, _ in classes.items():
|
|
|
+ keys[key] = index
|
|
|
+ index += 1
|
|
|
+ def custom_key(x):
|
|
|
+ key = _identifier(x['name'])
|
|
|
+ return keys[key]
|
|
|
+ types = sorted(types, key=custom_key)
|
|
|
+ return types
|
|
|
+
|
|
|
+
|
|
|
def _metadata():
|
|
|
types = _read_metadata()
|
|
|
schemas = _parse_schemas()
|
|
|
@@ -142,6 +171,7 @@ def _metadata():
|
|
|
types[key]['name'] = schema
|
|
|
else:
|
|
|
types[key] = { 'name': schema }
|
|
|
+ types = _sort_types(list(types.values()))
|
|
|
_write_metadata(types)
|
|
|
|
|
|
def main(): # pylint: disable=missing-function-docstring
|