|
|
@@ -32,68 +32,43 @@ def _update_description(schema, lines):
|
|
|
schema['description'] = '\n'.join(lines)
|
|
|
|
|
|
def _attribute_value(attribute_type, attribute_value):
|
|
|
+ if attribute_type in ('float32', 'int32', 'boolean', 'string'):
|
|
|
+ if attribute_value in ("'auto'", '"auto"') or attribute_type == 'string':
|
|
|
+ return attribute_value.strip("'").strip('"')
|
|
|
+ if attribute_type in (None, 'float32', 'int32') and attribute_value == 'None':
|
|
|
+ return None
|
|
|
if attribute_type == 'float32':
|
|
|
- if attribute_value == 'None':
|
|
|
- return None
|
|
|
- if attribute_value != "'auto'":
|
|
|
- return float(attribute_value)
|
|
|
- return attribute_value.strip("'").strip('"')
|
|
|
+ return float(attribute_value)
|
|
|
if attribute_type == 'int32':
|
|
|
- if attribute_value == 'None':
|
|
|
- return None
|
|
|
- if attribute_value in ("'auto'", '"auto"'):
|
|
|
- return attribute_value.strip("'").strip('"')
|
|
|
return int(attribute_value)
|
|
|
- if attribute_type == 'string':
|
|
|
- return attribute_value.strip("'").strip('"')
|
|
|
if attribute_type == 'boolean':
|
|
|
- if attribute_value == 'True':
|
|
|
- return True
|
|
|
- if attribute_value == 'False':
|
|
|
- return False
|
|
|
- if attribute_value == "'auto'":
|
|
|
- return attribute_value.strip("'").strip('"')
|
|
|
+ if attribute_value in ('True', 'False'):
|
|
|
+ return attribute_value == 'True'
|
|
|
raise Exception("Unknown boolean default value '" + str(attribute_value) + "'.")
|
|
|
if attribute_type:
|
|
|
raise Exception("Unknown default type '" + attribute_type + "'.")
|
|
|
- if attribute_value == 'None':
|
|
|
- return None
|
|
|
return attribute_value.strip("'")
|
|
|
|
|
|
-def _update_attribute(schema, name, description, attribute_type, optional, default_value):
|
|
|
- attribute = None
|
|
|
- if not 'attributes' in schema:
|
|
|
- schema['attributes'] = []
|
|
|
- for current_attribute in schema['attributes']:
|
|
|
- if 'name' in current_attribute and current_attribute['name'] == name:
|
|
|
- attribute = current_attribute
|
|
|
- break
|
|
|
+def _find_attribute(schema, name):
|
|
|
+ schema.setdefault('attributes', [])
|
|
|
+ attribute = next((_ for _ in schema['attributes'] if _['name'] == name), None)
|
|
|
if not attribute:
|
|
|
- attribute = {}
|
|
|
- attribute['name'] = name
|
|
|
+ attribute = { 'name': name }
|
|
|
schema['attributes'].append(attribute)
|
|
|
- attribute['description'] = description
|
|
|
- if attribute_type:
|
|
|
- attribute['type'] = attribute_type
|
|
|
- if optional:
|
|
|
- attribute['optional'] = True
|
|
|
- if default_value:
|
|
|
- attribute['default'] = _attribute_value(attribute_type, default_value)
|
|
|
+ return attribute
|
|
|
|
|
|
def _update_attributes(schema, lines):
|
|
|
- i = 0
|
|
|
- while i < len(lines):
|
|
|
- line = lines[i]
|
|
|
+ while len(lines) > 0:
|
|
|
+ line = lines.pop(0)
|
|
|
line = re.sub(r',\s+', ', ', line)
|
|
|
- if line.endswith('.'):
|
|
|
- line = line[0:-1]
|
|
|
- colon = line.find(':')
|
|
|
- if colon == -1:
|
|
|
+ line = line.rstrip('.')
|
|
|
+ name, line = line.split(':', 1)
|
|
|
+ if not line:
|
|
|
raise Exception("Expected ':' in parameter.")
|
|
|
- name = line[0:colon].strip(' ')
|
|
|
- line = line[colon + 1:].strip(' ')
|
|
|
+ name = name.strip(' ')
|
|
|
+ line = line.strip(' ')
|
|
|
attribute_type = None
|
|
|
- type_map = {
|
|
|
+ attribute_types = {
|
|
|
'float': 'float32',
|
|
|
'boolean': 'boolean',
|
|
|
'bool': 'boolean',
|
|
|
@@ -165,7 +140,6 @@ def _update_attributes(schema, lines):
|
|
|
end = line.find('},')
|
|
|
if end == -1:
|
|
|
raise Exception("Expected '}' in parameter.")
|
|
|
- # attribute_type = line[0:end + 1]
|
|
|
line = line[end + 2:].strip(' ')
|
|
|
elif line.startswith("'"):
|
|
|
while line.startswith("'"):
|
|
|
@@ -173,7 +147,7 @@ def _update_attributes(schema, lines):
|
|
|
if end == -1:
|
|
|
raise Exception("Expected \' in parameter.")
|
|
|
line = line[end + 2:].strip(' ')
|
|
|
- elif line in type_map:
|
|
|
+ elif line in attribute_types:
|
|
|
attribute_type = line
|
|
|
line = ''
|
|
|
elif line.startswith('int, RandomState instance or None,'):
|
|
|
@@ -184,7 +158,7 @@ def _update_attributes(schema, lines):
|
|
|
line = ''
|
|
|
else:
|
|
|
space = line.find(' {')
|
|
|
- if space != -1 and line[0:space] in type_map and line[space:].find('}') != -1:
|
|
|
+ if space != -1 and line[0:space] in attribute_types and line[space:].find('}') != -1:
|
|
|
attribute_type = line[0:space]
|
|
|
end = line[space:].find('}')
|
|
|
line = line[space+end+1:]
|
|
|
@@ -196,11 +170,7 @@ def _update_attributes(schema, lines):
|
|
|
raise Exception("Expected ',' in parameter.")
|
|
|
attribute_type = line[0:comma]
|
|
|
line = line[comma + 1:].strip(' ')
|
|
|
- attribute_type = type_map.get(attribute_type, None)
|
|
|
- # elif type == "{dict, 'balanced'}":
|
|
|
- # v = 'map'
|
|
|
- # else:
|
|
|
- # raise Exception("Unknown attribute type '" + attribute_type + "'.")
|
|
|
+ attribute_type = attribute_types.get(attribute_type, None)
|
|
|
optional = False
|
|
|
default = None
|
|
|
while len(line.strip(' ')) > 0:
|
|
|
@@ -231,13 +201,18 @@ def _update_attributes(schema, lines):
|
|
|
if comma == -1:
|
|
|
raise Exception("Expected ',' in parameter.")
|
|
|
line = line[comma+1:]
|
|
|
- i = i + 1
|
|
|
- attribute_lines = []
|
|
|
- while i < len(lines) and (len(lines[i].strip(' ')) == 0 or lines[i].startswith(' ')):
|
|
|
- attribute_lines.append(lines[i].lstrip(' '))
|
|
|
- i = i + 1
|
|
|
- description = '\n'.join(attribute_lines)
|
|
|
- _update_attribute(schema, name, description, attribute_type, optional, default)
|
|
|
+ description = []
|
|
|
+ while len(lines) > 0 and (len(lines[0].strip(' ')) == 0 or lines[0].startswith(' ')):
|
|
|
+ line = lines.pop(0).lstrip(' ')
|
|
|
+ description.append(line)
|
|
|
+ attribute = _find_attribute(schema, name)
|
|
|
+ attribute['description'] = '\n'.join(description)
|
|
|
+ if attribute_type:
|
|
|
+ attribute['type'] = attribute_type
|
|
|
+ if optional:
|
|
|
+ attribute['optional'] = True
|
|
|
+ if default:
|
|
|
+ attribute['default'] = _attribute_value(attribute_type, default)
|
|
|
|
|
|
def _metadata():
|
|
|
json_file = os.path.join(os.path.dirname(__file__), '../source/sklearn-metadata.json')
|
|
|
@@ -271,9 +246,7 @@ def _metadata():
|
|
|
file.write(json.dumps(json_root, sort_keys=False, indent=2))
|
|
|
|
|
|
def main(): # pylint: disable=missing-function-docstring
|
|
|
- command_table = { 'metadata': _metadata }
|
|
|
- command = sys.argv[1]
|
|
|
- command_table[command]()
|
|
|
+ _metadata()
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
main()
|