FastText.py 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341
  1. # Copyright (c) 2017-present, Facebook, Inc.
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the BSD-style license found in the
  5. # LICENSE file in the root directory of this source tree. An additional grant
  6. # of patent rights can be found in the PATENTS file in the same directory.
  7. from __future__ import absolute_import
  8. from __future__ import division
  9. from __future__ import print_function
  10. from __future__ import unicode_literals
  11. import fasttext_pybind as fasttext
  12. import numpy as np
  13. loss_name = fasttext.loss_name
  14. model_name = fasttext.model_name
  15. EOS = "</s>"
  16. BOW = "<"
  17. EOW = ">"
  18. class _FastText():
  19. """
  20. This class defines the API to inspect models and should not be used to
  21. create objects. It will be returned by functions such as load_model or
  22. train.
  23. In general this API assumes to be given only unicode for Python2 and the
  24. Python3 equvalent called str for any string-like arguments. All unicode
  25. strings are then encoded as UTF-8 and fed to the fastText C++ API.
  26. """
  27. def __init__(self, model=None):
  28. self.f = fasttext.fasttext()
  29. if model is not None:
  30. self.f.loadModel(model)
  31. def is_quantized(self):
  32. return self.f.isQuant()
  33. def get_dimension(self):
  34. """Get the dimension (size) of a lookup vector (hidden layer)."""
  35. a = self.f.getArgs()
  36. return a.dim
  37. def get_word_vector(self, word):
  38. """Get the vector representation of word."""
  39. dim = self.get_dimension()
  40. b = fasttext.Vector(dim)
  41. self.f.getWordVector(b, word)
  42. return np.array(b)
  43. def get_sentence_vector(self, text):
  44. """
  45. Given a string, get a single vector represenation. This function
  46. assumes to be given a single line of text. We split words on
  47. whitespace (space, newline, tab, vertical tab) and the control
  48. characters carriage return, formfeed and the null character.
  49. """
  50. if text.find('\n') != -1:
  51. raise ValueError(
  52. "predict processes one line at a time (remove \'\\n\')"
  53. )
  54. text += "\n"
  55. dim = self.get_dimension()
  56. b = fasttext.Vector(dim)
  57. self.f.getSentenceVector(b, text)
  58. return np.array(b)
  59. def get_word_id(self, word):
  60. """
  61. Given a word, get the word id within the dictionary.
  62. Returns -1 if word is not in the dictionary.
  63. """
  64. return self.f.getWordId(word)
  65. def get_subword_id(self, subword):
  66. """
  67. Given a subword, return the index (within input matrix) it hashes to.
  68. """
  69. return self.f.getSubwordId(subword)
  70. def get_subwords(self, word):
  71. """
  72. Given a word, get the subwords and their indicies.
  73. """
  74. pair = self.f.getSubwords(word)
  75. return pair[0], np.array(pair[1])
  76. def get_input_vector(self, ind):
  77. """
  78. Given an index, get the corresponding vector of the Input Matrix.
  79. """
  80. dim = self.get_dimension()
  81. b = fasttext.Vector(dim)
  82. self.f.getInputVector(b, ind)
  83. return np.array(b)
  84. # Process one line only!
  85. def predict(self, text, k=1):
  86. """
  87. Given a string, get a list of labels and a list of
  88. corresponding probabilities. k controls the number
  89. of returned labels. A choice of 5, will return the 5
  90. most probable labels. By default this returns only
  91. the most likely label and probability.
  92. This function assumes to be given
  93. a single line of text. We split words on whitespace (space,
  94. newline, tab, vertical tab) and the control characters carriage
  95. return, formfeed and the null character.
  96. If the model is not supervised, this function will throw a ValueError.
  97. """
  98. if text.find('\n') != -1:
  99. raise ValueError(
  100. "predict processes one line at a time (remove \'\\n\')"
  101. )
  102. text += "\n"
  103. pairs = self.f.predict(text, k)
  104. probs, labels = zip(*pairs)
  105. probs = np.exp(np.array(probs))
  106. return labels, probs
  107. def get_input_matrix(self):
  108. """
  109. Get a copy of the full input matrix of a Model. This only
  110. works if the model is not quantized.
  111. """
  112. if self.f.isQuant():
  113. raise ValueError("Can't get quantized Matrix")
  114. return np.array(self.f.getInputMatrix())
  115. def get_output_matrix(self):
  116. """
  117. Get a copy of the full output matrix of a Model. This only
  118. works if the model is not quantized.
  119. """
  120. if self.f.isQuant():
  121. raise ValueError("Can't get quantized Matrix")
  122. return np.array(self.f.getOutputMatrix())
  123. def get_words(self, include_freq=False):
  124. """
  125. Get the entire list of words of the dictionary optionally
  126. including the frequency of the individual words. This
  127. does not include any subwords. For that please consult
  128. the function get_subwords.
  129. """
  130. pair = self.f.getVocab()
  131. if include_freq:
  132. return (pair[0], np.array(pair[1]))
  133. else:
  134. return pair[0]
  135. def get_labels(self, include_freq=False):
  136. """
  137. Get the entire list of labels of the dictionary optionally
  138. including the frequency of the individual labels. Unsupervised
  139. models use words as labels, which is why get_labels
  140. will call and return get_words for this type of
  141. model.
  142. """
  143. a = self.f.getArgs()
  144. if a.model == model_name.supervised:
  145. pair = self.f.getLabels()
  146. if include_freq:
  147. return (pair[0], np.array(pair[1]))
  148. else:
  149. return pair[0]
  150. else:
  151. return self.get_words(include_freq)
  152. def save_model(self, path):
  153. """Save the model to the given path"""
  154. self.f.saveModel(path)
  155. def quantize(
  156. self,
  157. input="",
  158. qout=False,
  159. cutoff=0,
  160. retrain=False,
  161. epoch=None,
  162. lr=None,
  163. thread=None,
  164. verbose=None,
  165. dsub=2,
  166. qnorm=False
  167. ):
  168. """
  169. Quantize the model reducing the size of the model and
  170. it's memory footprint.
  171. """
  172. a = self.f.getArgs()
  173. if not epoch:
  174. epoch = a.epoch
  175. if not lr:
  176. lr = a.lr
  177. if not thread:
  178. thread = a.thread
  179. if not verbose:
  180. verbose = a.verbose
  181. self.f.quantize(
  182. input, qout, cutoff, retrain, epoch, lr, thread, verbose, dsub,
  183. qnorm
  184. )
  185. # TODO:
  186. # Not supported:
  187. # - pretrained vectors
  188. def _parse_model_string(string):
  189. if string == "cbow":
  190. return model_name.cbow
  191. if string == "skipgram":
  192. return model_name.skipgram
  193. if string == "supervised":
  194. return model_name.supervised
  195. else:
  196. raise ValueError("Unrecognized model name")
  197. def _parse_loss_string(string):
  198. if string == "ns":
  199. return loss_name.ns
  200. if string == "hs":
  201. return loss_name.hs
  202. if string == "softmax":
  203. return loss_name.softmax
  204. else:
  205. raise ValueError("Unrecognized loss name")
  206. def _build_args(args):
  207. args["model"] = _parse_model_string(args["model"])
  208. args["loss"] = _parse_loss_string(args["loss"])
  209. a = fasttext.args()
  210. for (k, v) in args.items():
  211. setattr(a, k, v)
  212. a.output = "" # User should use save_model
  213. a.pretrainedVectors = "" # Unsupported
  214. a.saveOutput = 0 # Never use this
  215. if a.wordNgrams <= 1 and a.maxn == 0:
  216. a.bucket = 0
  217. return a
  218. def tokenize(text):
  219. """Given a string of text, tokenize it and return a list of tokens"""
  220. f = fasttext.fasttext()
  221. return f.tokenize(text)
  222. def load_model(path):
  223. """Load a model given a filepath and return a model object."""
  224. return _FastText(path)
  225. def train_supervised(
  226. input,
  227. lr=0.1,
  228. dim=100,
  229. ws=5,
  230. epoch=5,
  231. minCount=1,
  232. minCountLabel=0,
  233. minn=0,
  234. maxn=0,
  235. neg=5,
  236. wordNgrams=1,
  237. loss="softmax",
  238. bucket=2000000,
  239. thread=12,
  240. lrUpdateRate=100,
  241. t=1e-4,
  242. label="__label__",
  243. verbose=2,
  244. pretrainedVectors="",
  245. ):
  246. """
  247. Train a supervised model and return a model object.
  248. input must be a filepath. The input text does not need to be tokenized
  249. as per the tokenize function, but it must be preprocessed and encoded
  250. as UTF-8. You might want to consult standard preprocessing scripts such
  251. as tokenizer.perl mentioned here: http://www.statmt.org/wmt07/baseline.html
  252. The input file must must contain at least one label per line. For an
  253. example consult the example datasets which are part of the fastText
  254. repository such as the dataset pulled by classification-example.sh.
  255. """
  256. model = "supervised"
  257. a = _build_args(locals())
  258. ft = _FastText()
  259. fasttext.train(ft.f, a)
  260. return ft
  261. def train_unsupervised(
  262. input,
  263. model="skipgram",
  264. lr=0.05,
  265. dim=100,
  266. ws=5,
  267. epoch=5,
  268. minCount=5,
  269. minCountLabel=0,
  270. minn=3,
  271. maxn=6,
  272. neg=5,
  273. wordNgrams=1,
  274. loss="ns",
  275. bucket=2000000,
  276. thread=12,
  277. lrUpdateRate=100,
  278. t=1e-4,
  279. label="__label__",
  280. verbose=2,
  281. pretrainedVectors="",
  282. ):
  283. """
  284. Train an unsupervised model and return a model object.
  285. input must be a filepath. The input text does not need to be tokenized
  286. as per the tokenize function, but it must be preprocessed and encoded
  287. as UTF-8. You might want to consult standard preprocessing scripts such
  288. as tokenizer.perl mentioned here: http://www.statmt.org/wmt07/baseline.html
  289. The input fiel must not contain any labels or use the specified label prefix
  290. unless it is ok for those words to be ignored. For an example consult the
  291. dataset pulled by the example script word-vector-example.sh, which is
  292. part of the fastText repository.
  293. """
  294. a = _build_args(locals())
  295. ft = _FastText()
  296. fasttext.train(ft.f, a)
  297. return ft