util.py 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. # Copyright (c) 2017-present, Facebook, Inc.
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the MIT license found in the
  5. # LICENSE file in the root directory of this source tree.
  6. # NOTE: The purpose of this file is not to accumulate all useful utility
  7. # functions. This file should contain very commonly used and requested functions
  8. # (such as test). If you think you have a function at that level, please create
  9. # an issue and we will happily review your suggestion. This file is also not supposed
  10. # to pull in dependencies outside of numpy/scipy without very good reasons. For
  11. # example, this file should not use sklearn and matplotlib to produce a t-sne
  12. # plot of word embeddings or such.
  13. from __future__ import absolute_import
  14. from __future__ import division
  15. from __future__ import print_function
  16. from __future__ import unicode_literals
  17. import numpy as np
  18. # TODO: Add example on reproducing model.test with util.test and model.get_line
  19. def test(predictions, labels, k=1):
  20. """
  21. Return precision and recall modeled after fasttext's test
  22. """
  23. precision = 0.0
  24. nexamples = 0
  25. nlabels = 0
  26. for prediction, labels in zip(predictions, labels):
  27. for p in prediction:
  28. if p in labels:
  29. precision += 1
  30. nexamples += 1
  31. nlabels += len(labels)
  32. return (precision / (k * nexamples), precision / nlabels)
  33. def find_nearest_neighbor(query, vectors, ban_set, cossims=None):
  34. """
  35. query is a 1d numpy array corresponding to the vector to which you want to
  36. find the closest vector
  37. vectors is a 2d numpy array corresponding to the vectors you want to consider
  38. ban_set is a set of indicies within vectors you want to ignore for nearest match
  39. cossims is a 1d numpy array of size len(vectors), which can be passed for efficiency
  40. returns the index of the closest match to query within vectors
  41. """
  42. if cossims is None:
  43. cossims = np.matmul(vectors, query, out=cossims)
  44. else:
  45. np.matmul(vectors, query, out=cossims)
  46. rank = len(cossims) - 1
  47. result_i = np.argpartition(cossims, rank)[rank]
  48. while result_i in ban_set:
  49. rank -= 1
  50. result_i = np.argpartition(cossims, rank)[rank]
  51. return result_i