Explorar el Código

Update PyTorch script

Lutz Roeder hace 1 año
padre
commit
082add7125
Se han modificado 3 ficheros con 2474 adiciones y 2399 borrados
  1. 2437 2380
      source/pytorch-metadata.json
  2. 0 12
      source/pytorch.js
  3. 37 7
      tools/pytorch_script.py

La diferencia del archivo ha sido suprimido porque es demasiado grande
+ 2437 - 2380
source/pytorch-metadata.json


+ 0 - 12
source/pytorch.js

@@ -3262,18 +3262,6 @@ pytorch.Execution = class extends python.Execution {
             }
             matches.push(schema);
         }
-        if (matches.length > 1) {
-            const keys = new Map([['IntType', 1], ['FloatType', 2], ['TensorType', 3], ['NumberType', 4]]);
-            matches.sort((a, b) => {
-                let keyA = keys.get(a.arguments[0].real_type.kind()) || 5;
-                let keyB = keys.get(b.arguments[0].real_type.kind()) || 5;
-                if (keyA === keyB && a.arguments.length > 1 && b.arguments.length > 1) {
-                    keyA = keys.get(a.arguments[1].real_type.kind()) || 5;
-                    keyB = keys.get(b.arguments[1].real_type.kind()) || 5;
-                }
-                return keyA - keyB;
-            });
-        }
         if (matches.length === 0) {
             throw new pytorch.Error(`Unknown function '${op_name}'.`);
         }

+ 37 - 7
tools/pytorch_script.py

@@ -9,7 +9,7 @@ import sys
 
 root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
 sys.path.append(root_dir)
-sys.pycache_prefix = os.path.join(root_dir, 'dist', 'pycache', 'test', 'backend')
+sys.pycache_prefix = os.path.join(root_dir, 'dist', 'pycache', 'pytorch_script')
 
 source_dir = os.path.join(root_dir, 'source')
 third_party_dir = os.path.join(root_dir, 'third_party')
@@ -34,8 +34,7 @@ def _read_metadata():
         metadata[key] = value
     return metadata
 
-def _write_metadata(value):
-    metadata = list(collections.OrderedDict(sorted(value.items())).values())
+def _write_metadata(metadata):
     content = json.dumps(metadata, indent=2, ensure_ascii=False)
     content = re.sub(r'\s {8}', ' ', content)
     content = re.sub(r',\s {8}', ', ', content)
@@ -88,13 +87,15 @@ known_legacy_schema_definitions = [
 def _identifier(schema):
     return schema.split('(', 1)[0].strip()
 
-def _parse_schemas():
-    schemas = {}
+def _all_schemas():
     torch = __import__('torch')
     __import__('torchvision')
     __import__('torchaudio')
-    all_schemas = list(torch._C._jit_get_all_schemas()) # pylint: disable=protected-access
-    for schema in all_schemas:
+    return list(torch._C._jit_get_all_schemas()) # pylint: disable=protected-access
+
+def _parse_schemas():
+    schemas = {}
+    for schema in _all_schemas():
         definition = str(schema)
         definition = definition.replace('(b|a)', '(a|b)')
         key = _identifier(definition)
@@ -131,6 +132,34 @@ def _check_types(types, schemas):
     if len(types) > 0:
         raise Exception('\n'.join(list(types.keys()))) # pylint: disable=broad-exception-raised
 
+def _sort_types(types):
+    keys = {}
+    index = 0
+    for schema in _all_schemas():
+        definition = str(schema)
+        key = _identifier(definition)
+        keys[key] = index
+        index += 1
+    classes = collections.OrderedDict()
+    for item in types:
+        name = item['name']
+        if name.find('::') == -1:
+            classes[name] = item
+        else:
+            key = _identifier(name)
+            if not key in keys:
+                keys[key] = index
+                index += 1
+    for key, _ in classes.items():
+        keys[key] = index
+        index += 1
+    def custom_key(x):
+        key = _identifier(x['name'])
+        return keys[key]
+    types = sorted(types, key=custom_key)
+    return types
+
+
 def _metadata():
     types = _read_metadata()
     schemas = _parse_schemas()
@@ -142,6 +171,7 @@ def _metadata():
             types[key]['name'] = schema
         else:
             types[key] = { 'name': schema }
+    types = _sort_types(list(types.values()))
     _write_metadata(types)
 
 def main(): # pylint: disable=missing-function-docstring

Algunos archivos no se mostraron porque demasiados archivos cambiaron en este cambio