  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. # Copyright (c) 2017-present, Facebook, Inc.
  4. # All rights reserved.
  5. #
  6. # This source code is licensed under the MIT license found in the
  7. # LICENSE file in the root directory of this source tree.
  8. # NOTE: The purpose of this file is not to accumulate all useful utility
  9. # functions. This file should contain very commonly used and requested functions
  10. # (such as test). If you think you have a function at that level, please create
  11. # an issue and we will happily review your suggestion. This file is also not supposed
  12. # to pull in dependencies outside of numpy/scipy without very good reasons. For
  13. # example, this file should not use sklearn and matplotlib to produce a t-sne
  14. # plot of word embeddings or such.
  15. from __future__ import absolute_import
  16. from __future__ import division
  17. from __future__ import print_function
  18. from __future__ import unicode_literals
  19. import numpy as np
  20. import sys
  21. import shutil
  22. import os
  23. import gzip
  24. try:
  25. from urllib.request import urlopen
  26. except ImportError:
  27. from urllib2 import urlopen
  28. valid_lang_ids = {"af", "sq", "als", "am", "ar", "an", "hy", "as", "ast",
  29. "az", "ba", "eu", "bar", "be", "bn", "bh", "bpy", "bs",
  30. "br", "bg", "my", "ca", "ceb", "bcl", "ce", "zh", "cv",
  31. "co", "hr", "cs", "da", "dv", "nl", "pa", "arz", "eml",
  32. "en", "myv", "eo", "et", "hif", "fi", "fr", "gl", "ka",
  33. "de", "gom", "el", "gu", "ht", "he", "mrj", "hi", "hu",
  34. "is", "io", "ilo", "id", "ia", "ga", "it", "ja", "jv",
  35. "kn", "pam", "kk", "km", "ky", "ko", "ku", "ckb", "la",
  36. "lv", "li", "lt", "lmo", "nds", "lb", "mk", "mai", "mg",
  37. "ms", "ml", "mt", "gv", "mr", "mzn", "mhr", "min", "xmf",
  38. "mwl", "mn", "nah", "nap", "ne", "new", "frr", "nso",
  39. "no", "nn", "oc", "or", "os", "pfl", "ps", "fa", "pms",
  40. "pl", "pt", "qu", "ro", "rm", "ru", "sah", "sa", "sc",
  41. "sco", "gd", "sr", "sh", "scn", "sd", "si", "sk", "sl",
  42. "so", "azb", "es", "su", "sw", "sv", "tl", "tg", "ta",
  43. "tt", "te", "th", "bo", "tr", "tk", "uk", "hsb", "ur",
  44. "ug", "uz", "vec", "vi", "vo", "wa", "war", "cy", "vls",
  45. "fy", "pnb", "yi", "yo", "diq", "zea"}
  46. # TODO: Add example on reproducing model.test with util.test and model.get_line
  47. def test(predictions, labels, k=1):
  48. """
  49. Return precision and recall modeled after fasttext's test
  50. """
  51. precision = 0.0
  52. nexamples = 0
  53. nlabels = 0
  54. for prediction, labels in zip(predictions, labels):
  55. for p in prediction:
  56. if p in labels:
  57. precision += 1
  58. nexamples += 1
  59. nlabels += len(labels)
  60. return (precision / (k * nexamples), precision / nlabels)
  61. def find_nearest_neighbor(query, vectors, ban_set, cossims=None):
  62. """
  63. query is a 1d numpy array corresponding to the vector to which you want to
  64. find the closest vector
  65. vectors is a 2d numpy array corresponding to the vectors you want to consider
  66. ban_set is a set of indicies within vectors you want to ignore for nearest match
  67. cossims is a 1d numpy array of size len(vectors), which can be passed for efficiency
  68. returns the index of the closest match to query within vectors
  69. """
  70. if cossims is None:
  71. cossims = np.matmul(vectors, query, out=cossims)
  72. else:
  73. np.matmul(vectors, query, out=cossims)
  74. rank = len(cossims) - 1
  75. result_i = np.argpartition(cossims, rank)[rank]
  76. while result_i in ban_set:
  77. rank -= 1
  78. result_i = np.argpartition(cossims, rank)[rank]
  79. return result_i
  80. def _reduce_matrix(X_orig, dim, eigv):
  81. """
  82. Reduces the dimension of a (m × n) matrix `X_orig` to
  83. to a (m × dim) matrix `X_reduced`
  84. It uses only the first 100000 rows of `X_orig` to do the mapping.
  85. Matrix types are all `np.float32` in order to avoid unncessary copies.
  86. """
  87. if eigv is None:
  88. mapping_size = 100000
  89. X = X_orig[:mapping_size]
  90. X = X - X.mean(axis=0, dtype=np.float32)
  91. C = np.divide(np.matmul(X.T, X), X.shape[0] - 1, dtype=np.float32)
  92. _, U = np.linalg.eig(C)
  93. eigv = U[:, :dim]
  94. X_reduced = np.matmul(X_orig, eigv)
  95. return (X_reduced, eigv)
  96. def reduce_model(ft_model, target_dim):
  97. """
  98. ft_model is an instance of `_FastText` class
  99. This function computes the PCA of the input and the output matrices
  100. and sets the reduced ones.
  101. """
  102. inp_reduced, proj = _reduce_matrix(
  103. ft_model.get_input_matrix(), target_dim, None)
  104. out_reduced, _ = _reduce_matrix(
  105. ft_model.get_output_matrix(), target_dim, proj)
  106. ft_model.set_matrices(inp_reduced, out_reduced)
  107. return ft_model
  108. def _print_progress(downloaded_bytes, total_size):
  109. percent = float(downloaded_bytes) / total_size
  110. bar_size = 50
  111. bar = int(percent * bar_size)
  112. percent = round(percent * 100, 2)
  113. sys.stdout.write(" (%0.2f%%) [" % percent)
  114. sys.stdout.write("=" * bar)
  115. sys.stdout.write(">")
  116. sys.stdout.write(" " * (bar_size - bar))
  117. sys.stdout.write("]\r")
  118. sys.stdout.flush()
  119. if downloaded_bytes >= total_size:
  120. sys.stdout.write('\n')
  121. def _download_file(url, write_file_name, chunk_size=2**13):
  122. print("Downloading %s" % url)
  123. response = urlopen(url)
  124. if hasattr(response, 'getheader'):
  125. file_size = int(response.getheader('Content-Length').strip())
  126. else:
  127. file_size = int(response.info().getheader('Content-Length').strip())
  128. downloaded = 0
  129. download_file_name = write_file_name + ".part"
  130. with open(download_file_name, 'wb') as f:
  131. while True:
  132. chunk = response.read(chunk_size)
  133. downloaded += len(chunk)
  134. if not chunk:
  135. break
  136. f.write(chunk)
  137. _print_progress(downloaded, file_size)
  138. os.rename(download_file_name, write_file_name)
  139. def _download_gz_model(gz_file_name, if_exists):
  140. if os.path.isfile(gz_file_name):
  141. if if_exists == 'ignore':
  142. return True
  143. elif if_exists == 'strict':
  144. print("gzip File exists. Use --overwrite to download anyway.")
  145. return False
  146. elif if_exists == 'overwrite':
  147. pass
  148. url = "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/%s" % gz_file_name
  149. _download_file(url, gz_file_name)
  150. return True
  151. def download_model(lang_id, if_exists='strict', dimension=None):
  152. """
  153. Download pre-trained common-crawl vectors from fastText's website
  154. https://fasttext.cc/docs/en/crawl-vectors.html
  155. """
  156. if lang_id not in valid_lang_ids:
  157. raise Exception("Invalid lang id. Please select among %s" %
  158. repr(valid_lang_ids))
  159. file_name = "cc.%s.300.bin" % lang_id
  160. gz_file_name = "%s.gz" % file_name
  161. if os.path.isfile(file_name):
  162. if if_exists == 'ignore':
  163. return file_name
  164. elif if_exists == 'strict':
  165. print("File exists. Use --overwrite to download anyway.")
  166. return
  167. elif if_exists == 'overwrite':
  168. pass
  169. if _download_gz_model(gz_file_name, if_exists):
  170. with gzip.open(gz_file_name, 'rb') as f:
  171. with open(file_name, 'wb') as f_out:
  172. shutil.copyfileobj(f, f_out)
  173. return file_name