Lutz Roeder 7 years ago
Parent
Commit
eaa56047ec
7 changed files with 194 additions and 116 deletions
  1. src/pytorch-metadata.json (+49 -0)
  2. src/pytorch.js (+4 -1)
  3. tools/caffe2-script.py (+51 -36)
  4. tools/keras (+1 -1)
  5. tools/keras-script.py (+48 -43)
  6. tools/pytorch (+2 -2)
  7. tools/pytorch-script.py (+39 -33)

+ 49 - 0
src/pytorch-metadata.json

@@ -213,6 +213,13 @@
       "package": "torch.nn.modules.activation"
     }
   },
+  {
+    "name": "MaxPool1d",
+    "schema": {
+      "category": "Pool",
+      "package": "torch.nn.modules.pooling"
+    }
+  },
   {
     "name": "MaxPool2d",
     "schema": {
@@ -242,6 +249,20 @@
       "package": "torch.nn.modules.pooling"
     }
   },
+  {
+    "name": "MaxPool3d",
+    "schema": {
+      "category": "Pool",
+      "package": "torch.nn.modules.pooling"
+    }
+  },
+  {
+    "name": "AvgPool1d",
+    "schema": {
+      "category": "Pool",
+      "package": "torch.nn.modules.pooling"
+    }
+  },
   {
     "name": "AvgPool2d",
     "schema": {
@@ -267,6 +288,34 @@
       "package": "torch.nn.modules.pooling"
     }
   },
+  {
+    "name": "AvgPool3d",
+    "schema": {
+      "category": "Pool",
+      "package": "torch.nn.modules.pooling"
+    }
+  },
+  {
+    "name": "AdaptiveAvgPool1d",
+    "schema": {
+      "category": "Pool",
+      "package": "torch.nn.modules.pooling"
+    }
+  },
+  {
+    "name": "AdaptiveAvgPool2d",
+    "schema": {
+      "category": "Pool",
+      "package": "torch.nn.modules.pooling"
+    }
+  },
+  {
+    "name": "AdaptiveAvgPool3d",
+    "schema": {
+      "category": "Pool",
+      "package": "torch.nn.modules.pooling"
+    }
+  },
   {
     "name": "BatchNorm1d",
     "schema": {

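For context, these entries are consumed as a flat name-to-schema lookup. The following is a minimal sketch of how the added pooling entries can be checked, modeled on the schema_map loop in tools/pytorch-script.py further down; the relative path assumes the tools directory as the working directory:

    import json

    # Load the metadata file and map each operator name to its schema.
    with open('../src/pytorch-metadata.json') as f:
        json_root = json.load(f)

    schema_map = {entry['name']: entry['schema'] for entry in json_root}

    # With the entries added above, every pooling variant resolves to 'Pool'.
    print(schema_map['MaxPool1d']['category'])         # Pool
    print(schema_map['AdaptiveAvgPool3d']['package'])  # torch.nn.modules.pooling
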
+ 4 - 1
src/pytorch.js

@@ -97,7 +97,10 @@ pytorch.ModelFactory = class {
             constructorTable['torch.nn.modules.pooling.AvgPool3d'] = function () {};
             constructorTable['torch.nn.modules.pooling.MaxPool1d'] = function() {};
             constructorTable['torch.nn.modules.pooling.MaxPool2d'] = function () {};
-            constructorTable['torch.nn.modules.pooling.MaxPool2d'] = function() {};
+            constructorTable['torch.nn.modules.pooling.MaxPool3d'] = function() {};
+            constructorTable['torch.nn.modules.pooling.AdaptiveAvgPool1d'] = function() {};
+            constructorTable['torch.nn.modules.pooling.AdaptiveAvgPool2d'] = function() {};
+            constructorTable['torch.nn.modules.pooling.AdaptiveAvgPool3d'] = function() {};
             constructorTable['torch.nn.modules.rnn.LSTM'] = function () {};
             constructorTable['torch.nn.modules.sparse.Embedding'] = function () {};
             constructorTable['torch.nn.modules.upsampling.Upsample'] = function() {};

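Note the removed line: MaxPool2d had been registered twice, so the second assignment silently overwrote the first and MaxPool3d was never registered at all. The same hazard is easy to reproduce in Python, shown here as a hypothetical analogue of the JavaScript constructor table:

    # A dict literal with a duplicate key keeps only the last value, with no error.
    constructor_table = {
        'torch.nn.modules.pooling.MaxPool2d': lambda: None,
        'torch.nn.modules.pooling.MaxPool2d': lambda: None,  # silently wins
    }
    print(len(constructor_table))  # 1 -- one registration was lost
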
+ 51 - 36
tools/caffe2-script.py

@@ -4,15 +4,11 @@ from __future__ import print_function
 
 import io
 import json
+import logging
 import pydoc
 import os
 import re
 import sys
-import caffe2.python.core
-
-json_file = '../src/caffe2-metadata.json'
-json_data = open(json_file).read()
-json_root = json.loads(json_data)
 
 def get_support_level(dir):
     if 'caffe2/caffe2/operators' in dir:
@@ -145,36 +141,55 @@ def update_output(schema, output_desc):
     if len(output_desc) > 2:
         return
 
-schema_map = {}
+class Caffe2Filter(logging.Filter):
+    def filter(self, record):
+        return not record.getMessage().startswith('This caffe2 python run does not have GPU support.')
 
-for entry in json_root:
-    name = entry['name']
-    schema = entry['schema']
-    schema_map[name] = schema
+def metadata():
 
-for name in caffe2.python.core._GetRegisteredOperators():
-    op_schema = caffe2.python.workspace.C.OpSchema.get(name)
-    if op_schema:
-        if name in schema_map:
-            schema = schema_map[name]
-        else:
-            schema = {}
-            schema_map[name] = { 'name': name, 'schema': schema }
-        schema['description'] = op_schema.doc
-        for arg in op_schema.args:
-            update_argument(schema, arg)
-        for input_desc in op_schema.input_desc:
-            update_input(schema, input_desc)
-        if name != 'Int8ConvRelu' and name != 'Int8AveragePoolRelu':
-            for output_desc in op_schema.output_desc:
-                update_output(schema, output_desc)
-        schema['support_level'] = get_support_level(os.path.dirname(op_schema.file))
-
-with io.open(json_file, 'w', newline='') as fout:
-    json_data = json.dumps(json_root, sort_keys=True, indent=2)
-    for line in json_data.splitlines():
-        line = line.rstrip()
-        if sys.version_info[0] < 3:
-            line = unicode(line)
-        fout.write(line)
-        fout.write('\n')
+    logging.getLogger('').addFilter(Caffe2Filter())
+
+    import caffe2.python.core
+
+    json_file = '../src/caffe2-metadata.json'
+    json_data = open(json_file).read()
+    json_root = json.loads(json_data)
+
+    schema_map = {}
+
+    for entry in json_root:
+        name = entry['name']
+        schema = entry['schema']
+        schema_map[name] = schema
+
+    for name in caffe2.python.core._GetRegisteredOperators():
+        op_schema = caffe2.python.workspace.C.OpSchema.get(name)
+        if op_schema:
+            if name in schema_map:
+                schema = schema_map[name]
+            else:
+                schema = {}
+                schema_map[name] = { 'name': name, 'schema': schema }
+            schema['description'] = op_schema.doc
+            for arg in op_schema.args:
+                update_argument(schema, arg)
+            for input_desc in op_schema.input_desc:
+                update_input(schema, input_desc)
+            if name != 'Int8ConvRelu' and name != 'Int8AveragePoolRelu':
+                for output_desc in op_schema.output_desc:
+                    update_output(schema, output_desc)
+            schema['support_level'] = get_support_level(os.path.dirname(op_schema.file))
+
+    with io.open(json_file, 'w', newline='') as fout:
+        json_data = json.dumps(json_root, sort_keys=True, indent=2)
+        for line in json_data.splitlines():
+            line = line.rstrip()
+            if sys.version_info[0] < 3:
+                line = unicode(line)
+            fout.write(line)
+            fout.write('\n')
+
+if __name__ == '__main__':
+    command_table = { 'metadata': metadata }
+    command = sys.argv[1]
+    command_table[command]()

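The logging.Filter subclass above silences Caffe2's import-time GPU warning. Two details of the pattern are easy to get wrong: filter() must return a falsy value to drop a record, and record.getMessage() yields the bare message, without the 'WARNING:root:' prefix that the default handler adds when printing. A standalone sketch of the pattern (names hypothetical):

    import logging

    class GpuWarningFilter(logging.Filter):
        def filter(self, record):
            # False drops the record; True lets it through.
            return not record.getMessage().startswith(
                'This caffe2 python run does not have GPU support.')

    logging.getLogger('').addFilter(GpuWarningFilter())
    logging.warning('This caffe2 python run does not have GPU support. CPU only.')  # dropped
    logging.warning('unrelated warning')  # still printed
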
+ 1 - 1
tools/keras

@@ -42,7 +42,7 @@ metadata() {
     source ${virtualenv}/bin/activate
     echo "Update 'keras-metadata.json'"
     pushd ${tools} > /dev/null
-    ${python} keras-script.py
+    ${python} keras-script.py metadata
     popd > /dev/null
     deactivate
 }

+ 48 - 43
tools/keras-script.py

@@ -9,6 +9,11 @@ import pydoc
 import re
 import sys
 
+stderr = sys.stderr
+sys.stderr = open(os.devnull, 'w')
+import keras
+sys.stderr = stderr
+
 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
 
 def count_leading_spaces(s):
@@ -279,50 +284,50 @@ def update_output(schema, description):
     if entry:
         entry['description'] = description
 
-json_file = '../src/keras-metadata.json'
-json_data = open(json_file).read()
-json_root = json.loads(json_data)
-
-for entry in json_root:
-    name = entry['name']
-    schema = entry['schema']
-    if 'package' in schema:
-        class_name = schema['package'] + '.' + name
-        class_definition = pydoc.locate(class_name)
-        if not class_definition:
-            raise Exception('\'' + class_name + '\' not found.')
-        docstring = class_definition.__doc__
-        if not docstring:
-            raise Exception('\'' + class_name + '\' missing __doc__.')
-        docstring = process_docstring(docstring)
-        headers = split_docstring(docstring)
-        if '' in headers:
-            schema['description'] = '\n'.join(headers[''])
-            del headers['']
-        if 'Arguments' in headers:
-            update_arguments(schema, headers['Arguments'])
-            del headers['Arguments']
-        if 'Input shape' in headers:
-            update_input(schema, '\n'.join(headers['Input shape']))
-            del headers['Input shape']
-        if 'Output shape' in headers:
-            update_output(schema, '\n'.join(headers['Output shape']))
-            del headers['Output shape']
-        if 'Examples' in headers:
-            update_examples(schema, headers['Examples'])
-            del headers['Examples']
-        if 'Example' in headers:
-            update_examples(schema, headers['Example'])
-            del headers['Example']
-        if 'References' in headers:
-            update_references(schema, headers['References'])
-            del headers['References']
-        if 'Raises' in headers:
-            del headers['Raises']
-        if len(headers) > 0:
-            raise Exception('\'' + class_name + '.__doc__\' contains unprocessed headers.')
- 
 def metadata():
+    json_file = '../src/keras-metadata.json'
+    json_data = open(json_file).read()
+    json_root = json.loads(json_data)
+
+    for entry in json_root:
+        name = entry['name']
+        schema = entry['schema']
+        if 'package' in schema:
+            class_name = schema['package'] + '.' + name
+            class_definition = pydoc.locate(class_name)
+            if not class_definition:
+                raise Exception('\'' + class_name + '\' not found.')
+            docstring = class_definition.__doc__
+            if not docstring:
+                raise Exception('\'' + class_name + '\' missing __doc__.')
+            docstring = process_docstring(docstring)
+            headers = split_docstring(docstring)
+            if '' in headers:
+                schema['description'] = '\n'.join(headers[''])
+                del headers['']
+            if 'Arguments' in headers:
+                update_arguments(schema, headers['Arguments'])
+                del headers['Arguments']
+            if 'Input shape' in headers:
+                update_input(schema, '\n'.join(headers['Input shape']))
+                del headers['Input shape']
+            if 'Output shape' in headers:
+                update_output(schema, '\n'.join(headers['Output shape']))
+                del headers['Output shape']
+            if 'Examples' in headers:
+                update_examples(schema, headers['Examples'])
+                del headers['Examples']
+            if 'Example' in headers:
+                update_examples(schema, headers['Example'])
+                del headers['Example']
+            if 'References' in headers:
+                update_references(schema, headers['References'])
+                del headers['References']
+            if 'Raises' in headers:
+                del headers['Raises']
+            if len(headers) > 0:
+                raise Exception('\'' + class_name + '.__doc__\' contains unprocessed headers.')
+
     with io.open(json_file, 'w', newline='') as fout:
         json_data = json.dumps(json_root, sort_keys=True, indent=2)
         for line in json_data.splitlines():

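The stderr swap around 'import keras' hides the backend banner that Keras prints on import. Because the script still supports Python 2 (see the sys.version_info[0] < 3 branch in the sibling scripts), the manual swap is the portable choice; on Python 3.5+ the same idea can be written with contextlib, as in this hypothetical alternative (not what the script uses). It is also worth noting that TF_CPP_MIN_LOG_LEVEL generally only affects TensorFlow's native logging if it is set before TensorFlow is first imported.

    import contextlib
    import os

    with open(os.devnull, 'w') as devnull:
        with contextlib.redirect_stderr(devnull):
            import keras  # the import-time banner on stderr is discarded
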
+ 2 - 2
tools/pytorch

@@ -60,9 +60,9 @@ metadata() {
     source ${virtualenv}/bin/activate
     pushd ${tools} > /dev/null
     echo "Generate 'caffe2-metadata.json'"
-    ${python} caffe2-script.py
+    ${python} caffe2-script.py metadata
     echo "Generate 'pytorch-metadata.json'"
-    ${python} pytorch-script.py
+    ${python} pytorch-script.py metadata
     popd > /dev/null
     deactivate
 }

+ 39 - 33
tools/pytorch-script.py

@@ -9,36 +9,42 @@ import os
 import re
 import sys
 
-json_file = '../src/pytorch-metadata.json'
-json_data = open(json_file).read()
-json_root = json.loads(json_data)
-
-schema_map = {}
-
-for entry in json_root:
-    name = entry['name']
-    schema = entry['schema']
-    schema_map[name] = schema
-
-for entry in json_root:
-    name = entry['name']
-    schema = entry['schema']
-    if 'package' in schema:
-        class_name = schema['package'] + '.' + name
-        # print(class_name)
-        class_definition = pydoc.locate(class_name)
-        if not class_definition:
-            raise Exception('\'' + class_name + '\' not found.')
-        docstring = class_definition.__doc__
-        if not docstring:
-            raise Exception('\'' + class_name + '\' missing __doc__.')
-        # print(docstring)
-
-with io.open(json_file, 'w', newline='') as fout:
-    json_data = json.dumps(json_root, sort_keys=True, indent=2)
-    for line in json_data.splitlines():
-        line = line.rstrip()
-        if sys.version_info[0] < 3:
-            line = unicode(line)
-        fout.write(line)
-        fout.write('\n')
+def metadata():
+    json_file = '../src/pytorch-metadata.json'
+    json_data = open(json_file).read()
+    json_root = json.loads(json_data)
+
+    schema_map = {}
+
+    for entry in json_root:
+        name = entry['name']
+        schema = entry['schema']
+        schema_map[name] = schema
+
+    for entry in json_root:
+        name = entry['name']
+        schema = entry['schema']
+        if 'package' in schema:
+            class_name = schema['package'] + '.' + name
+            # print(class_name)
+            class_definition = pydoc.locate(class_name)
+            if not class_definition:
+                raise Exception('\'' + class_name + '\' not found.')
+            docstring = class_definition.__doc__
+            if not docstring:
+                raise Exception('\'' + class_name + '\' missing __doc__.')
+            # print(docstring)
+
+    with io.open(json_file, 'w', newline='') as fout:
+        json_data = json.dumps(json_root, sort_keys=True, indent=2)
+        for line in json_data.splitlines():
+            line = line.rstrip()
+            if sys.version_info[0] < 3:
+                line = unicode(line)
+            fout.write(line)
+            fout.write('\n')
+
+if __name__ == '__main__':
+    command_table = { 'metadata': metadata }
+    command = sys.argv[1]
+    command_table[command]()
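
The rewritten scripts now dispatch on a subcommand via a dict, which is what lets the shell drivers pass 'metadata' explicitly. As written they assume sys.argv[1] exists and names a known command; below is a slightly hardened variant of the same dispatcher (a sketch, not the shipped code):

    import sys

    def metadata():
        print('regenerating pytorch-metadata.json')  # placeholder body

    if __name__ == '__main__':
        command_table = {'metadata': metadata}
        if len(sys.argv) < 2 or sys.argv[1] not in command_table:
            sys.exit('usage: pytorch-script.py {' + '|'.join(sorted(command_table)) + '}')
        command_table[sys.argv[1]]()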