# conlleval.py
# Python version of the evaluation script from CoNLL'00-
# Originates from: https://github.com/spyysalo/conlleval.py
# Intentional differences:
# - accept any space as delimiter by default
# - optional file argument (default STDIN)
# - option to set boundary (-b argument)
# - LaTeX output (-l argument) not supported
# - raw tags (-r argument) not supported
# Also adds return_report(), which collects the report as strings instead of printing it.
import codecs
import re
import sys
from collections import defaultdict, namedtuple
# Sentinel delimiter value meaning "split on any run of whitespace".
ANY_SPACE = '<SPACE>'
  15. class FormatError(Exception):
  16. pass
# Result bundle for one entity type (or overall): raw counts plus derived scores.
Metrics = namedtuple('Metrics', 'tp fp fn prec rec fscore')
  18. class EvalCounts(object):
  19. def __init__(self):
  20. self.correct_chunk = 0 # number of correctly identified chunks
  21. self.correct_tags = 0 # number of correct chunk tags
  22. self.found_correct = 0 # number of chunks in corpus
  23. self.found_guessed = 0 # number of identified chunks
  24. self.token_counter = 0 # token counter (ignores sentence breaks)
  25. # counts by type
  26. self.t_correct_chunk = defaultdict(int)
  27. self.t_found_correct = defaultdict(int)
  28. self.t_found_guessed = defaultdict(int)
  29. def parse_args(argv):
  30. import argparse
  31. parser = argparse.ArgumentParser(
  32. description='evaluate tagging results using CoNLL criteria',
  33. formatter_class=argparse.ArgumentDefaultsHelpFormatter
  34. )
  35. arg = parser.add_argument
  36. arg('-b', '--boundary', metavar='STR', default='-X-',
  37. help='sentence boundary')
  38. arg('-d', '--delimiter', metavar='CHAR', default=ANY_SPACE,
  39. help='character delimiting items in input')
  40. arg('-o', '--otag', metavar='CHAR', default='O',
  41. help='alternative outside tag')
  42. arg('file', nargs='?', default=None)
  43. return parser.parse_args(argv)
  44. def parse_tag(t):
  45. m = re.match(r'^([^-]*)-(.*)$', t)
  46. return m.groups() if m else (t, '')
def evaluate(iterable, options=None):
    """Tally chunk- and token-level statistics over CoNLL-format lines.

    *iterable* yields lines whose last two whitespace- (or delimiter-)
    separated columns are the gold tag and the guessed tag, in that order.
    Returns an EvalCounts with the accumulated tallies.
    Raises FormatError on inconsistent column counts or too few columns.
    """
    if options is None:
        options = parse_args([])    # use defaults
    counts = EvalCounts()
    num_features = None       # number of features per line
    in_correct = False        # currently processed chunk is correct until now
    last_correct = 'O'        # previous chunk tag in corpus
    last_correct_type = ''    # type of previously identified chunk tag
    last_guessed = 'O'        # previously identified chunk tag
    last_guessed_type = ''    # type of previous chunk tag in corpus

    for i, line in enumerate(iterable):
        line = line.rstrip('\r\n')
        # print(line)
        if options.delimiter == ANY_SPACE:
            features = line.split()
        else:
            features = line.split(options.delimiter)

        # first data line fixes the expected column count; blank lines are exempt
        if num_features is None:
            num_features = len(features)
        elif num_features != len(features) and len(features) != 0:
            raise FormatError('unexpected number of features: %d (%d) at line %d\n%s' %
                              (len(features), num_features, i, line))

        if len(features) == 0 or features[0] == options.boundary:
            # sentence boundary: substitute an outside token on both sides
            features = [options.boundary, 'O', 'O']
        if len(features) < 3:
            raise FormatError('unexpected number of features in line %s' % line)

        # last column is the guess, second-to-last the gold tag
        guessed, guessed_type = parse_tag(features.pop())
        correct, correct_type = parse_tag(features.pop())
        first_item = features.pop(0)

        if first_item == options.boundary:
            guessed = 'O'

        end_correct = end_of_chunk(last_correct, correct,
                                   last_correct_type, correct_type)
        end_guessed = end_of_chunk(last_guessed, guessed,
                                   last_guessed_type, guessed_type)
        start_correct = start_of_chunk(last_correct, correct,
                                       last_correct_type, correct_type)
        start_guessed = start_of_chunk(last_guessed, guessed,
                                       last_guessed_type, guessed_type)

        if in_correct:
            if (end_correct and end_guessed and
                last_guessed_type == last_correct_type):
                # the chunk that just ended matched exactly in both streams
                in_correct = False
                counts.correct_chunk += 1
                counts.t_correct_chunk[last_correct_type] += 1
            elif (end_correct != end_guessed or guessed_type != correct_type):
                # boundaries or types diverged: the current chunk cannot match
                in_correct = False

        if start_correct and start_guessed and guessed_type == correct_type:
            in_correct = True

        if start_correct:
            counts.found_correct += 1
            counts.t_found_correct[correct_type] += 1
        if start_guessed:
            counts.found_guessed += 1
            counts.t_found_guessed[guessed_type] += 1
        if first_item != options.boundary:
            if correct == guessed and guessed_type == correct_type:
                counts.correct_tags += 1
            counts.token_counter += 1

        last_guessed = guessed
        last_correct = correct
        last_guessed_type = guessed_type
        last_correct_type = correct_type

    # flush a chunk that is still open at end of input
    if in_correct:
        counts.correct_chunk += 1
        counts.t_correct_chunk[last_correct_type] += 1

    return counts
  114. def uniq(iterable):
  115. seen = set()
  116. return [i for i in iterable if not (i in seen or seen.add(i))]
  117. def calculate_metrics(correct, guessed, total):
  118. tp, fp, fn = correct, guessed-correct, total-correct
  119. p = 0 if tp + fp == 0 else 1.*tp / (tp + fp)
  120. r = 0 if tp + fn == 0 else 1.*tp / (tp + fn)
  121. f = 0 if p + r == 0 else 2 * p * r / (p + r)
  122. return Metrics(tp, fp, fn, p, r, f)
  123. def metrics(counts):
  124. c = counts
  125. overall = calculate_metrics(
  126. c.correct_chunk, c.found_guessed, c.found_correct
  127. )
  128. by_type = {}
  129. for t in uniq(list(c.t_found_correct) + list(c.t_found_guessed)):
  130. by_type[t] = calculate_metrics(
  131. c.t_correct_chunk[t], c.t_found_guessed[t], c.t_found_correct[t]
  132. )
  133. return overall, by_type
  134. def report(counts, out=None):
  135. if out is None:
  136. out = sys.stdout
  137. overall, by_type = metrics(counts)
  138. c = counts
  139. out.write('processed %d tokens with %d phrases; ' %
  140. (c.token_counter, c.found_correct))
  141. out.write('found: %d phrases; correct: %d.\n' %
  142. (c.found_guessed, c.correct_chunk))
  143. if c.token_counter > 0:
  144. out.write('accuracy: %6.2f%%; ' %
  145. (100.*c.correct_tags/c.token_counter))
  146. out.write('precision: %6.2f%%; ' % (100.*overall.prec))
  147. out.write('recall: %6.2f%%; ' % (100.*overall.rec))
  148. out.write('FB1: %6.2f\n' % (100.*overall.fscore))
  149. for i, m in sorted(by_type.items()):
  150. out.write('%17s: ' % i)
  151. out.write('precision: %6.2f%%; ' % (100.*m.prec))
  152. out.write('recall: %6.2f%%; ' % (100.*m.rec))
  153. out.write('FB1: %6.2f %d\n' % (100.*m.fscore, c.t_found_guessed[i]))
  154. def report_notprint(counts):
  155. overall, by_type = metrics(counts)
  156. c = counts
  157. final_report = []
  158. line = []
  159. line.append('processed %d tokens with %d phrases; ' %
  160. (c.token_counter, c.found_correct))
  161. line.append('found: %d phrases; correct: %d.\n' %
  162. (c.found_guessed, c.correct_chunk))
  163. final_report.append("".join(line))
  164. if c.token_counter > 0:
  165. line = []
  166. line.append('accuracy: %6.2f%%; ' %
  167. (100.*c.correct_tags/c.token_counter))
  168. line.append('precision: %6.2f%%; ' % (100.*overall.prec))
  169. line.append('recall: %6.2f%%; ' % (100.*overall.rec))
  170. line.append('FB1: %6.2f\n' % (100.*overall.fscore))
  171. final_report.append("".join(line))
  172. for i, m in sorted(by_type.items()):
  173. line = []
  174. line.append('%17s: ' % i)
  175. line.append('precision: %6.2f%%; ' % (100.*m.prec))
  176. line.append('recall: %6.2f%%; ' % (100.*m.rec))
  177. line.append('FB1: %6.2f %d\n' % (100.*m.fscore, c.t_found_guessed[i]))
  178. final_report.append("".join(line))
  179. return final_report
  180. def end_of_chunk(prev_tag, tag, prev_type, type_):
  181. # check if a chunk ended between the previous and current word
  182. # arguments: previous and current chunk tags, previous and current types
  183. chunk_end = False
  184. if prev_tag == 'E': chunk_end = True
  185. if prev_tag == 'S': chunk_end = True
  186. if prev_tag == 'B' and tag == 'B': chunk_end = True
  187. if prev_tag == 'B' and tag == 'S': chunk_end = True
  188. if prev_tag == 'B' and tag == 'O': chunk_end = True
  189. if prev_tag == 'I' and tag == 'B': chunk_end = True
  190. if prev_tag == 'I' and tag == 'S': chunk_end = True
  191. if prev_tag == 'I' and tag == 'O': chunk_end = True
  192. if prev_tag != 'O' and prev_tag != '.' and prev_type != type_:
  193. chunk_end = True
  194. # these chunks are assumed to have length 1
  195. if prev_tag == ']': chunk_end = True
  196. if prev_tag == '[': chunk_end = True
  197. return chunk_end
  198. def start_of_chunk(prev_tag, tag, prev_type, type_):
  199. # check if a chunk started between the previous and current word
  200. # arguments: previous and current chunk tags, previous and current types
  201. chunk_start = False
  202. if tag == 'B': chunk_start = True
  203. if tag == 'S': chunk_start = True
  204. if prev_tag == 'E' and tag == 'E': chunk_start = True
  205. if prev_tag == 'E' and tag == 'I': chunk_start = True
  206. if prev_tag == 'S' and tag == 'E': chunk_start = True
  207. if prev_tag == 'S' and tag == 'I': chunk_start = True
  208. if prev_tag == 'O' and tag == 'E': chunk_start = True
  209. if prev_tag == 'O' and tag == 'I': chunk_start = True
  210. if tag != 'O' and tag != '.' and prev_type != type_:
  211. chunk_start = True
  212. # these chunks are assumed to have length 1
  213. if tag == '[': chunk_start = True
  214. if tag == ']': chunk_start = True
  215. return chunk_start
  216. def main(argv):
  217. args = parse_args(argv[1:])
  218. if args.file is None:
  219. counts = evaluate(sys.stdin, args)
  220. else:
  221. with open(args.file) as f:
  222. counts = evaluate(f, args)
  223. report(counts)
  224. def return_report(input_file):
  225. with open(input_file, "r") as f:
  226. counts = evaluate(f)
  227. return report_notprint(counts)
  228. if __name__ == '__main__':
  229. # sys.exit(main(sys.argv))
  230. return_report('/home/pengy6/data/sentence_similarity/data/cdr/test1/wanli_result2/label_test.txt')