gen_csv.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. from data.feature_spec import FeatureSpec
  2. from data.outbrain.defaults import ONEHOT_CHANNEL, MULTIHOT_CHANNEL, LABEL_CHANNEL, NUMERICAL_CHANNEL, \
  3. MAP_FEATURE_CHANNEL
  4. from argparse import ArgumentParser
  5. import pandas as pd
  6. import os
  7. import numpy as np
  8. def parse_args():
  9. parser = ArgumentParser()
  10. parser.add_argument('--feature_spec_in', type=str, default='feature_spec.yaml',
  11. help='Name of the input feature specification file')
  12. parser.add_argument('--output', type=str, default='/data')
  13. parser.add_argument('--size', type=int, default=1000,
  14. help='The desired number of rows in the output csv file')
  15. return parser.parse_args()
  16. def main():
  17. #this generator supports the following feature types:
  18. #onehot categorical
  19. #numerical
  20. #label
  21. #multihot categorical
  22. args = parse_args()
  23. dataset_size = args.size
  24. fspec_in = FeatureSpec.from_yaml(args.feature_spec_in)
  25. fspec_in.base_directory = args.output
  26. #prepare shapes for one-hot categorical features
  27. onehot_features = fspec_in.get_names_by_channel(ONEHOT_CHANNEL)
  28. onehot_cardinalities: dict = fspec_in.get_cardinalities(onehot_features)
  29. multihot_features = fspec_in.get_names_by_channel(MULTIHOT_CHANNEL)
  30. multihot_cardinalities: dict = fspec_in.get_cardinalities(multihot_features)
  31. multihot_hotnesses: dict = fspec_in.get_multihot_hotnesses(multihot_features)
  32. input_label_feature_name = fspec_in.get_names_by_channel(LABEL_CHANNEL)[0]
  33. numerical_names_set = set(fspec_in.get_names_by_channel(NUMERICAL_CHANNEL))
  34. map_channel_features = fspec_in.get_names_by_channel(MAP_FEATURE_CHANNEL)
  35. map_feature = None
  36. if len(map_channel_features)>0:
  37. map_feature=map_channel_features[0]
  38. for mapping_name, mapping in fspec_in.source_spec.items():
  39. for chunk in mapping:
  40. assert chunk['type'] == 'csv', "Only csv files supported in this generator"
  41. assert len(chunk['files']) == 1, "Only one file per chunk supported in this generator"
  42. path_to_save = os.path.join(fspec_in.base_directory, chunk['files'][0])
  43. data = {}
  44. for name in chunk['features']:
  45. if name == input_label_feature_name:
  46. data[name]=np.random.randint(0, 2, size=dataset_size)
  47. elif name in numerical_names_set:
  48. data[name]=np.random.rand(dataset_size)
  49. elif name in set(onehot_features):
  50. local_cardinality = onehot_cardinalities[name]
  51. data[name]=np.random.randint(0, local_cardinality, size=dataset_size)
  52. elif name in set(multihot_features):
  53. local_cardinality = multihot_cardinalities[name]
  54. local_hotness = multihot_hotnesses[name]
  55. data[name]=np.random.randint(0, local_cardinality, size=(dataset_size, local_hotness)).tolist()
  56. elif name == map_feature:
  57. raise NotImplementedError("Cannot generate datasets with map feature enabled")
  58. # TODO add a parameter that specifies max repeats and generate
  59. else:
  60. raise ValueError(f"Cannot generate for unused features. Unknown feature: {name}")
  61. # Columns in the csv appear in the order they are listed in the source spec for a given chunk
  62. column_order = chunk['files']
  63. df = pd.DataFrame(data)
  64. os.makedirs(os.path.dirname(path_to_save), exist_ok=True)
  65. df.to_csv(path_to_save, columns=column_order, index=False, header=False)
# Run the generator only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()