preprocess.py 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. import re
  2. import tqdm
  3. re_message = re.compile(r'([0-9]{4}-[0-9]{2}-[0-9]{2} [0-9]{2}:[0-9]{2}:[0-9]{2}) (.*)')
  4. def read_qq_history_file(filename):
  5. message = []
  6. # get line count
  7. count = -1
  8. for count, line in enumerate(open(filename, 'r', encoding='utf-8')):
  9. pass
  10. count += 1
  11. # read data
  12. with open(filename, 'r', encoding='utf-8') as f:
  13. for i in range(8):
  14. header = f.readline()
  15. print(header)
  16. cur_msg = None
  17. line = f.readline()
  18. with tqdm.tqdm(total=count, ascii=True) as pbar:
  19. while line:
  20. if line.strip() == '':
  21. line = f.readline()
  22. continue
  23. msg_header = re_message.match(line.strip())
  24. if msg_header:
  25. if cur_msg is not None:
  26. message.append(cur_msg)
  27. cur_msg = {
  28. 'time': msg_header.group(1),
  29. 'user': msg_header.group(2),
  30. 'data': ''
  31. }
  32. else:
  33. cur_msg['data'] += line
  34. line = f.readline()
  35. pbar.update()
  36. return message
  37. def filter_msg(messages):
  38. with tqdm.tqdm(total=len(messages), ascii=True) as pbar:
  39. for each_msg in messages:
  40. each_msg['data'] = each_msg['data'].replace("\n", '')
  41. each_msg['data'] = each_msg['data'].replace(r'[图片]', '')
  42. each_msg['data'] = each_msg['data'].replace(r'[表情]', '')
  43. each_msg['data'] = re.sub(r'(http|https|ftp)://[0-9a-zA-Z~./_\-]+', '', each_msg['data'])
  44. each_msg['data'] = re.sub(r'@.+ ', '', each_msg['data'])
  45. pbar.update()
  46. def generate_dataset(messages, output_path_source, output_path_target):
  47. prev_msg = None
  48. with open(output_path_source, 'w', encoding='utf-8') as fs:
  49. with open(output_path_target, 'w', encoding='utf-8') as ft:
  50. with tqdm.tqdm(total=len(messages), ascii=True) as pbar:
  51. for each_msg in messages:
  52. if each_msg['data'].strip() == '':
  53. continue
  54. if prev_msg is not None:
  55. fs.write(prev_msg['data'] + '\n')
  56. ft.write(each_msg['data'] + '\n')
  57. prev_msg = each_msg
  58. pbar.update()
  59. if __name__ == "__main__":
  60. '''
  61. print('read message')
  62. msg = read_qq_history_file('data/Octoon 开发组.txt')
  63. print('filter message')
  64. filter_msg(msg)
  65. print('write to file')
  66. generate_dataset(msg, 'data/octoon_source.txt', 'data/octoon_target.txt')
  67. '''
  68. print('read message')
  69. msg = read_qq_history_file('data/ISOIEC C++ China Unofficial.txt')
  70. print('filter message')
  71. filter_msg(msg)
  72. print('write to file')
  73. generate_dataset(msg, 'data/train_source.txt', 'data/train_target.txt')