Explorar el Código

Update scikit-learn script

Lutz Roeder hace 3 años
padre
commit
41fb0fde8b
Se han modificado 2 ficheros con 38 adiciones y 65 borrados
  1. 1 1
      tools/sklearn
  2. 37 64
      tools/sklearn_metadata.py

+ 1 - 1
tools/sklearn

@@ -47,7 +47,7 @@ metadata() {
     [[ $(grep -U $'\x0D' ./source/sklearn-metadata.json) ]] && crlf=1
     venv
     export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
-    ${python} ./tools/sklearn_script.py metadata
+    ${python} ./tools/sklearn_metadata.py
     deactivate
     if [[ -n ${crlf} ]]; then
         unix2dos --quiet --newfile ./source/sklearn-metadata.json ./source/sklearn-metadata.json

+ 37 - 64
tools/sklearn_script.py → tools/sklearn_metadata.py

@@ -32,68 +32,43 @@ def _update_description(schema, lines):
         schema['description'] = '\n'.join(lines)
 
 def _attribute_value(attribute_type, attribute_value):
+    if attribute_type in ('float32', 'int32', 'boolean', 'string'):
+        if attribute_value in ("'auto'", '"auto"') or attribute_type == 'string':
+            return attribute_value.strip("'").strip('"')
+    if attribute_type in (None, 'float32', 'int32') and attribute_value == 'None':
+        return None
     if attribute_type == 'float32':
-        if attribute_value == 'None':
-            return None
-        if attribute_value != "'auto'":
-            return float(attribute_value)
-        return attribute_value.strip("'").strip('"')
+        return float(attribute_value)
     if attribute_type == 'int32':
-        if attribute_value == 'None':
-            return None
-        if attribute_value in ("'auto'", '"auto"'):
-            return attribute_value.strip("'").strip('"')
         return int(attribute_value)
-    if attribute_type == 'string':
-        return attribute_value.strip("'").strip('"')
     if attribute_type == 'boolean':
-        if attribute_value == 'True':
-            return True
-        if attribute_value == 'False':
-            return False
-        if attribute_value == "'auto'":
-            return attribute_value.strip("'").strip('"')
+        if attribute_value in ('True', 'False'):
+            return attribute_value == 'True'
         raise Exception("Unknown boolean default value '" + str(attribute_value) + "'.")
     if attribute_type:
         raise Exception("Unknown default type '" + attribute_type + "'.")
-    if attribute_value == 'None':
-        return None
     return attribute_value.strip("'")
 
-def _update_attribute(schema, name, description, attribute_type, optional, default_value):
-    attribute = None
-    if not 'attributes' in schema:
-        schema['attributes'] = []
-    for current_attribute in schema['attributes']:
-        if 'name' in current_attribute and current_attribute['name'] == name:
-            attribute = current_attribute
-            break
+def _find_attribute(schema, name):
+    schema.setdefault('attributes', [])
+    attribute = next((_ for _ in schema['attributes'] if _['name'] == name), None)
     if not attribute:
-        attribute = {}
-        attribute['name'] = name
+        attribute = { 'name': name }
         schema['attributes'].append(attribute)
-    attribute['description'] = description
-    if attribute_type:
-        attribute['type'] = attribute_type
-    if optional:
-        attribute['optional'] = True
-    if default_value:
-        attribute['default'] = _attribute_value(attribute_type, default_value)
+    return attribute
 
 def _update_attributes(schema, lines):
-    i = 0
-    while i < len(lines):
-        line = lines[i]
+    while len(lines) > 0:
+        line = lines.pop(0)
         line = re.sub(r',\s+', ', ', line)
-        if line.endswith('.'):
-            line = line[0:-1]
-        colon = line.find(':')
-        if colon == -1:
+        line = line.rstrip('.')
+        name, line = line.split(':', 1)
+        if not line:
             raise Exception("Expected ':' in parameter.")
-        name = line[0:colon].strip(' ')
-        line = line[colon + 1:].strip(' ')
+        name = name.strip(' ')
+        line = line.strip(' ')
         attribute_type = None
-        type_map = {
+        attribute_types = {
             'float': 'float32',
             'boolean': 'boolean',
             'bool': 'boolean',
@@ -165,7 +140,6 @@ def _update_attributes(schema, lines):
                 end = line.find('},')
                 if end == -1:
                     raise Exception("Expected '}' in parameter.")
-                # attribute_type = line[0:end + 1]
                 line = line[end + 2:].strip(' ')
         elif line.startswith("'"):
             while line.startswith("'"):
@@ -173,7 +147,7 @@ def _update_attributes(schema, lines):
                 if end == -1:
                     raise Exception("Expected \' in parameter.")
                 line = line[end + 2:].strip(' ')
-        elif line in type_map:
+        elif line in attribute_types:
             attribute_type = line
             line = ''
         elif line.startswith('int, RandomState instance or None,'):
@@ -184,7 +158,7 @@ def _update_attributes(schema, lines):
             line = ''
         else:
             space = line.find(' {')
-            if space != -1 and line[0:space] in type_map and line[space:].find('}') != -1:
+            if space != -1 and line[0:space] in attribute_types and line[space:].find('}') != -1:
                 attribute_type = line[0:space]
                 end = line[space:].find('}')
                 line = line[space+end+1:]
@@ -196,11 +170,7 @@ def _update_attributes(schema, lines):
                         raise Exception("Expected ',' in parameter.")
                 attribute_type = line[0:comma]
                 line = line[comma + 1:].strip(' ')
-        attribute_type = type_map.get(attribute_type, None)
-        # elif type == "{dict, 'balanced'}":
-        #    v = 'map'
-        # else:
-        #    raise Exception("Unknown attribute type '" + attribute_type + "'.")
+        attribute_type = attribute_types.get(attribute_type, None)
         optional = False
         default = None
         while len(line.strip(' ')) > 0:
@@ -231,13 +201,18 @@ def _update_attributes(schema, lines):
                 if comma == -1:
                     raise Exception("Expected ',' in parameter.")
                 line = line[comma+1:]
-        i = i + 1
-        attribute_lines = []
-        while i < len(lines) and (len(lines[i].strip(' ')) == 0 or lines[i].startswith('        ')):
-            attribute_lines.append(lines[i].lstrip(' '))
-            i = i + 1
-        description = '\n'.join(attribute_lines)
-        _update_attribute(schema, name, description, attribute_type, optional, default)
+        description = []
+        while len(lines) > 0 and (len(lines[0].strip(' ')) == 0 or lines[0].startswith('        ')):
+            line = lines.pop(0).lstrip(' ')
+            description.append(line)
+        attribute = _find_attribute(schema, name)
+        attribute['description'] = '\n'.join(description)
+        if attribute_type:
+            attribute['type'] = attribute_type
+        if optional:
+            attribute['optional'] = True
+        if default:
+            attribute['default'] = _attribute_value(attribute_type, default)
 
 def _metadata():
     json_file = os.path.join(os.path.dirname(__file__), '../source/sklearn-metadata.json')
@@ -271,9 +246,7 @@ def _metadata():
         file.write(json.dumps(json_root, sort_keys=False, indent=2))
 
 def main(): # pylint: disable=missing-function-docstring
-    command_table = { 'metadata': _metadata }
-    command = sys.argv[1]
-    command_table[command]()
+    _metadata()
 
 if __name__ == '__main__':
     main()