# preprocess.py — parse exported QQ chat-history text files into a
# word-segmented parallel source/target corpus for chatbot training.
  1. import re
  2. import tqdm
  3. import jieba
  4. re_charset = re.compile("[^\u4e00-\u9fa5^.^a-z^A-Z^0-9]")
  5. re_message = re.compile(r'([0-9]{4}-[0-9]{2}-[0-9]{2} [0-9]{2}:[0-9]{2}:[0-9]{2}) (.*)(\(.+\)|<.+>)')
  6. def read_qq_history_file(filename):
  7. message = []
  8. # get line count
  9. count = -1
  10. for count, line in enumerate(open(filename, 'r', encoding='utf-8')):
  11. pass
  12. count += 1
  13. # read data
  14. with open(filename, 'r', encoding='utf-8') as f:
  15. for i in range(8):
  16. header = f.readline()
  17. print(header)
  18. cur_msg = None
  19. line = f.readline()
  20. with tqdm.tqdm(total=count, ascii=True) as pbar:
  21. while line:
  22. if line.strip() == '':
  23. line = f.readline()
  24. continue
  25. msg_header = re_message.match(line.strip())
  26. if msg_header:
  27. if cur_msg is not None:
  28. message.append(cur_msg)
  29. cur_msg = {
  30. 'time': msg_header.group(1),
  31. 'user': msg_header.group(2),
  32. 'id': msg_header.group(3),
  33. 'data': ''
  34. }
  35. else:
  36. cur_msg['data'] += line
  37. line = f.readline()
  38. pbar.update()
  39. return message
  40. def filter_msg(messages):
  41. with tqdm.tqdm(total=len(messages), ascii=True) as pbar:
  42. for each_msg in messages:
  43. each_msg['data'] = each_msg['data'].replace("\n", '')
  44. each_msg['data'] = each_msg['data'].replace(r'[图片]', '')
  45. each_msg['data'] = each_msg['data'].replace(r'[表情]', '')
  46. each_msg['data'] = re.sub(r'(http|https|ftp)://[0-9a-zA-Z~./_\-]+', '', each_msg['data'])
  47. each_msg['data'] = re.sub(r'@.+ ', '', each_msg['data'])
  48. each_msg['data'] = re.sub(r'.+加入本群', '', each_msg['data'])
  49. each_msg['data'] = re.sub(r'.+被管理员禁言[0-9]{1,2}(分钟|天)', '', each_msg['data'])
  50. each_msg['data'] = re.sub(r'.+被管理员解除禁言', '', each_msg['data'])
  51. each_msg['data'] = re.sub(r'.+撤回了一条消息', '', each_msg['data'])
  52. each_msg['data'] = re.sub(r'\[礼物\] .+成为.+的守护者', '', each_msg['data'])
  53. each_msg['data'] = re.sub(r'\[送礼物\] 为.+', '', each_msg['data'])
  54. each_msg['data'] = re.sub(r'\[QQ红包\]我发了一个.*', '', each_msg['data'])
  55. each_msg['data'] = re.sub(r'\[动作消息\].+', '', each_msg['data'])
  56. each_msg['data'] = re.sub(r'\[闪照\].+', '', each_msg['data'])
  57. # each_msg['data'] = re_charset.sub("", each_msg['data'])
  58. pbar.update()
  59. def generate_dataset(messages, output_path_source, output_path_target):
  60. prev_msg = None
  61. with open(output_path_source, 'w', encoding='utf-8') as fs:
  62. with open(output_path_target, 'w', encoding='utf-8') as ft:
  63. with tqdm.tqdm(total=len(messages), ascii=True) as pbar:
  64. for each_msg in messages:
  65. if each_msg['data'].strip() == '':
  66. continue
  67. if prev_msg is not None:
  68. # filter conditions
  69. if prev_msg['data'].strip() != each_msg['data'].strip() \
  70. and len(each_msg['data']) < 64 and len(prev_msg['data']) < 64 \
  71. and prev_msg['user'] != each_msg['user'] \
  72. and not each_msg['data'].startswith('安安子') and not prev_msg['data'].startswith('安安子') \
  73. and each_msg['id'] == '(794424922)':
  74. prev_seg_list = jieba.cut(prev_msg['data'], cut_all=False)
  75. cur_seg_list = jieba.cut(each_msg['data'], cut_all=False)
  76. prev_seg_list = " ".join(prev_seg_list)
  77. cur_seg_list = " ".join(cur_seg_list)
  78. fs.write(prev_msg['id'] + ' : ' + prev_seg_list + '\n')
  79. ft.write(cur_seg_list + '\n')
  80. prev_msg = each_msg
  81. pbar.update()
  82. if __name__ == "__main__":
  83. print('read message')
  84. msg = read_qq_history_file('data/Octoon 开发组.txt')
  85. print('filter message')
  86. filter_msg(msg)
  87. print('write to file')
  88. generate_dataset(msg, 'data/valid_source.txt', 'data/valid_target.txt')
  89. print('read message')
  90. msg = read_qq_history_file('data/ISOIEC C++ China Unofficial.txt')
  91. print('filter message')
  92. filter_msg(msg)
  93. print('write to file')
  94. generate_dataset(msg, 'data/train_source.txt', 'data/train_target.txt')