# convert_test.py

  1. # Copyright (c) 2018, deepakn94, codyaustun, robieta. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. #
  15. # -----------------------------------------------------------------------
  16. #
  17. # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
  18. #
  19. # Licensed under the Apache License, Version 2.0 (the "License");
  20. # you may not use this file except in compliance with the License.
  21. # You may obtain a copy of the License at
  22. #
  23. # http://www.apache.org/licenses/LICENSE-2.0
  24. #
  25. # Unless required by applicable law or agreed to in writing, software
  26. # distributed under the License is distributed on an "AS IS" BASIS,
  27. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  28. # See the License for the specific language governing permissions and
  29. # limitations under the License.
  30. from argparse import ArgumentParser
  31. import pandas as pd
  32. import numpy as np
  33. from load import implicit_load
  34. from convert import save_feature_spec, _TestNegSampler, TEST_0, TEST_1, TRAIN_0, TRAIN_1
  35. import torch
  36. import os
  37. USER_COLUMN = 'user_id'
  38. ITEM_COLUMN = 'item_id'
  39. def parse_args():
  40. parser = ArgumentParser()
  41. parser.add_argument('--path', type=str, default='/data/ml-20m/ratings.csv',
  42. help='Path to reviews CSV file from MovieLens')
  43. parser.add_argument('--output', type=str, default='/data',
  44. help='Output directory for train and test files')
  45. parser.add_argument('--valid_negative', type=int, default=100,
  46. help='Number of negative samples for each positive test example')
  47. parser.add_argument('--seed', '-s', type=int, default=1,
  48. help='Manually set random seed for torch')
  49. parser.add_argument('--test', type=str, help='select modification to be applied to the set')
  50. return parser.parse_args()
  51. def main():
  52. args = parse_args()
  53. if args.seed is not None:
  54. torch.manual_seed(args.seed)
  55. print("Loading raw data from {}".format(args.path))
  56. df = implicit_load(args.path, sort=False)
  57. if args.test == 'less_user':
  58. to_drop = set(list(df[USER_COLUMN].unique())[-100:])
  59. df = df[~df[USER_COLUMN].isin(to_drop)]
  60. if args.test == 'less_item':
  61. to_drop = set(list(df[ITEM_COLUMN].unique())[-100:])
  62. df = df[~df[ITEM_COLUMN].isin(to_drop)]
  63. if args.test == 'more_user':
  64. sample = df.sample(frac=0.2).copy()
  65. sample[USER_COLUMN] = sample[USER_COLUMN] + 10000000
  66. df = df.append(sample)
  67. users = df[USER_COLUMN]
  68. df = df[users.isin(users[users.duplicated(keep=False)])] # make sure something remains in the train set
  69. if args.test == 'more_item':
  70. sample = df.sample(frac=0.2).copy()
  71. sample[ITEM_COLUMN] = sample[ITEM_COLUMN] + 10000000
  72. df = df.append(sample)
  73. print("Mapping original user and item IDs to new sequential IDs")
  74. df[USER_COLUMN] = pd.factorize(df[USER_COLUMN])[0]
  75. df[ITEM_COLUMN] = pd.factorize(df[ITEM_COLUMN])[0]
  76. user_cardinality = df[USER_COLUMN].max() + 1
  77. item_cardinality = df[ITEM_COLUMN].max() + 1
  78. # Need to sort before popping to get last item
  79. df.sort_values(by='timestamp', inplace=True)
  80. # clean up data
  81. del df['rating'], df['timestamp']
  82. df = df.drop_duplicates() # assuming it keeps order
  83. # Test set is the last interaction for a given user
  84. grouped_sorted = df.groupby(USER_COLUMN, group_keys=False)
  85. test_data = grouped_sorted.tail(1).sort_values(by=USER_COLUMN)
  86. # Train set is all interactions but the last one
  87. train_data = grouped_sorted.apply(lambda x: x.iloc[:-1])
  88. sampler = _TestNegSampler(train_data.values, args.valid_negative)
  89. test_negs = sampler.generate().cuda()
  90. if args.valid_negative > 0:
  91. test_negs = test_negs.reshape(-1, args.valid_negative)
  92. else:
  93. test_negs = test_negs.reshape(test_data.shape[0], 0)
  94. if args.test == 'more_pos':
  95. mask = np.random.rand(len(test_data)) < 0.5
  96. sample = test_data[mask].copy()
  97. sample[ITEM_COLUMN] = sample[ITEM_COLUMN] + 5
  98. test_data = test_data.append(sample)
  99. test_negs_copy = test_negs[mask]
  100. test_negs = torch.cat((test_negs, test_negs_copy), dim=0)
  101. if args.test == 'less_pos':
  102. mask = np.random.rand(len(test_data)) < 0.5
  103. test_data = test_data[mask]
  104. test_negs = test_negs[mask]
  105. # Reshape train set into user,item,label tabular and save
  106. train_ratings = torch.from_numpy(train_data.values).cuda()
  107. train_labels = torch.ones_like(train_ratings[:, 0:1], dtype=torch.float32)
  108. torch.save(train_ratings, os.path.join(args.output, TRAIN_0))
  109. torch.save(train_labels, os.path.join(args.output, TRAIN_1))
  110. # Reshape test set into user,item,label tabular and save
  111. # All users have the same number of items, items for a given user appear consecutively
  112. test_ratings = torch.from_numpy(test_data.values).cuda()
  113. test_users_pos = test_ratings[:, 0:1] # slicing instead of indexing to keep dimensions
  114. test_items_pos = test_ratings[:, 1:2]
  115. test_users = test_users_pos.repeat_interleave(args.valid_negative + 1, dim=0)
  116. test_items = torch.cat((test_items_pos.reshape(-1, 1), test_negs), dim=1).reshape(-1, 1)
  117. positive_labels = torch.ones_like(test_users_pos, dtype=torch.float32)
  118. negative_labels = torch.zeros_like(test_users_pos, dtype=torch.float32).repeat(1, args.valid_negative)
  119. test_labels = torch.cat((positive_labels, negative_labels), dim=1).reshape(-1, 1)
  120. dtypes = {'user': str(test_users.dtype), 'item': str(test_items.dtype), 'label': str(test_labels.dtype)}
  121. test_tensor = torch.cat((test_users, test_items), dim=1)
  122. torch.save(test_tensor, os.path.join(args.output, TEST_0))
  123. torch.save(test_labels, os.path.join(args.output, TEST_1))
  124. if args.test == 'other_names':
  125. dtypes = {'user_2': str(test_users.dtype),
  126. 'item_2': str(test_items.dtype),
  127. 'label_2': str(test_labels.dtype)}
  128. save_feature_spec(user_cardinality=user_cardinality, item_cardinality=item_cardinality, dtypes=dtypes,
  129. test_negative_samples=args.valid_negative, output_path=args.output + '/feature_spec.yaml',
  130. user_feature_name='user_2',
  131. item_feature_name='item_2',
  132. label_feature_name='label_2')
  133. else:
  134. save_feature_spec(user_cardinality=user_cardinality, item_cardinality=item_cardinality, dtypes=dtypes,
  135. test_negative_samples=args.valid_negative, output_path=args.output + '/feature_spec.yaml')
  136. if __name__ == '__main__':
  137. main()