metrics.py 1.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859
  1. # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. def __levenshtein(a, b):
  15. """Calculates the Levenshtein distance between two sequences."""
  16. n, m = len(a), len(b)
  17. if n > m:
  18. # Make sure n <= m, to use O(min(n,m)) space
  19. a, b = b, a
  20. n, m = m, n
  21. current = list(range(n + 1))
  22. for i in range(1, m + 1):
  23. previous, current = current, [i] + [0] * n
  24. for j in range(1, n + 1):
  25. add, delete = previous[j] + 1, current[j - 1] + 1
  26. change = previous[j - 1]
  27. if a[j - 1] != b[i - 1]:
  28. change = change + 1
  29. current[j] = min(add, delete, change)
  30. return current[n]
  31. def word_error_rate(hypotheses, references):
  32. """Computes average Word Error Rate (WER) between two text lists."""
  33. scores = 0
  34. words = 0
  35. len_diff = len(references) - len(hypotheses)
  36. if len_diff > 0:
  37. raise ValueError("Uneqal number of hypthoses and references: "
  38. "{0} and {1}".format(len(hypotheses), len(references)))
  39. elif len_diff < 0:
  40. hypotheses = hypotheses[:len_diff]
  41. for h, r in zip(hypotheses, references):
  42. h_list = h.split()
  43. r_list = r.split()
  44. words += len(r_list)
  45. scores += __levenshtein(h_list, r_list)
  46. if words!=0:
  47. wer = 1.0*scores/words
  48. else:
  49. wer = float('inf')
  50. return wer, scores, words