test_script.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607
  1. # Copyright (c) 2017-present, Facebook, Inc.
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the BSD-style license found in the
  5. # LICENSE file in the root directory of this source tree. An additional grant
  6. # of patent rights can be found in the PATENTS file in the same directory.
  7. from __future__ import absolute_import
  8. from __future__ import absolute_import
  9. from __future__ import division
  10. from __future__ import print_function
  11. from __future__ import unicode_literals
  12. from fastText import train_supervised
  13. from fastText import train_unsupervised
  14. from fastText import load_model
  15. from fastText import tokenize
  16. import random
  17. import sys
  18. import os
  19. import subprocess
  20. import multiprocessing
  21. import numpy as np
  22. import unittest
  23. import tempfile
  24. import math
  25. from scipy import stats
  26. def compat_splitting(line):
  27. return line.decode('utf8').split()
  28. def similarity(v1, v2):
  29. n1 = np.linalg.norm(v1)
  30. n2 = np.linalg.norm(v2)
  31. return np.dot(v1, v2) / n1 / n2
  32. def read_vectors(model_path):
  33. vectors = {}
  34. with open(model_path, 'rb') as fin:
  35. for _, line in enumerate(fin):
  36. try:
  37. tab = compat_splitting(line)
  38. vec = np.array(tab[1:], dtype=float)
  39. word = tab[0]
  40. if np.linalg.norm(vec) == 0:
  41. continue
  42. if word not in vectors:
  43. vectors[word] = vec
  44. except ValueError:
  45. continue
  46. except UnicodeDecodeError:
  47. continue
  48. return vectors
  49. def compute_similarity(model_path, data_path, vectors=None):
  50. if not vectors:
  51. vectors = read_vectors(model_path)
  52. mysim = []
  53. gold = []
  54. drop = 0.0
  55. nwords = 0.0
  56. with open(data_path, 'rb') as fin:
  57. for line in fin:
  58. tline = compat_splitting(line)
  59. word1 = tline[0].lower()
  60. word2 = tline[1].lower()
  61. nwords = nwords + 1.0
  62. if (word1 in vectors) and (word2 in vectors):
  63. v1 = vectors[word1]
  64. v2 = vectors[word2]
  65. d = similarity(v1, v2)
  66. mysim.append(d)
  67. gold.append(float(tline[2]))
  68. else:
  69. drop = drop + 1.0
  70. corr = stats.spearmanr(mysim, gold)
  71. dataset = os.path.basename(data_path)
  72. correlation = corr[0] * 100
  73. oov = math.ceil(drop / nwords * 100.0)
  74. return dataset, correlation, oov
  75. def get_random_unicode(length):
  76. # See: https://stackoverflow.com/questions/1477294/generate-random-utf-8-string-in-python
  77. try:
  78. get_char = unichr
  79. except NameError:
  80. get_char = chr
  81. # Update this to include code point ranges to be sampled
  82. include_ranges = [
  83. (0x0021, 0x0021),
  84. (0x0023, 0x0026),
  85. (0x0028, 0x007E),
  86. (0x00A1, 0x00AC),
  87. (0x00AE, 0x00FF),
  88. (0x0100, 0x017F),
  89. (0x0180, 0x024F),
  90. (0x2C60, 0x2C7F),
  91. (0x16A0, 0x16F0),
  92. (0x0370, 0x0377),
  93. (0x037A, 0x037E),
  94. (0x0384, 0x038A),
  95. (0x038C, 0x038C),
  96. ]
  97. alphabet = [
  98. get_char(code_point)
  99. for current_range in include_ranges
  100. for code_point in range(current_range[0], current_range[1] + 1)
  101. ]
  102. return ''.join(random.choice(alphabet) for i in range(length))
  103. def get_random_words(N, a, b):
  104. words = []
  105. for _ in range(N):
  106. length = random.randint(a, b)
  107. words.append(get_random_unicode(length))
  108. return words
  109. class TestFastTextPy(unittest.TestCase):
  110. @classmethod
  111. def eprint(cls, *args, **kwargs):
  112. print(*args, file=sys.stderr, **kwargs)
  113. @classmethod
  114. def num_thread(cls):
  115. return multiprocessing.cpu_count() - 1
  116. @classmethod
  117. def build_paths(cls, train, test, output):
  118. train = os.path.join(cls.data_dir, train)
  119. test = os.path.join(cls.data_dir, test)
  120. output = os.path.join(cls.result_dir, output)
  121. return train, test, output
  122. @classmethod
  123. def build_train_args(cls, params, mode, train, output):
  124. args = [cls.bin, mode, "-input", train, "-output", output]
  125. return args + params.split(' ')
  126. @classmethod
  127. def get_train_output(cls, train_args):
  128. cls.eprint("Executing: " + ' '.join(train_args))
  129. return subprocess.check_output(train_args).decode('utf-8')
  130. @classmethod
  131. def get_path_size(cls, path):
  132. path_size = subprocess.check_output(["stat", "-c", "%s",
  133. path]).decode('utf-8')
  134. path_size = int(path_size)
  135. return path_size
  136. @classmethod
  137. def default_test_args(cls, model, test, quantize=False):
  138. return [cls.bin, "test", model, test]
  139. @classmethod
  140. def get_test_output(cls, test_args):
  141. cls.eprint("Executing: " + ' '.join(test_args))
  142. test_output = subprocess.check_output(test_args)
  143. test_output = test_output.decode('utf-8')
  144. cls.eprint("Test output:\n" + test_output)
  145. return list(
  146. map(lambda x: x.split('\t')[1], test_output.split('\n')[:-1])
  147. )
  148. @classmethod
  149. def train_generic_classifier(cls, train, output):
  150. thread = cls.num_thread()
  151. cls.eprint("Using {} threads".format(thread))
  152. sup_params = (
  153. "-dim 10 -lr 0.1 -wordNgrams 2 -minCount 1 -bucket 10000000 "
  154. "-epoch 5 -thread {}".format(thread)
  155. )
  156. mode = 'supervised'
  157. cls.get_train_output(
  158. cls.build_train_args(sup_params, mode, train, output)
  159. )
  160. @classmethod
  161. def train_generic_embeddings(cls, train, output):
  162. thread = cls.num_thread()
  163. cls.eprint("Using {} threads".format(thread))
  164. unsup_params = (
  165. "-thread {} -lr 0.025 -dim 100 -ws 5 -epoch 1 -minCount 5 "
  166. "-neg 5 -loss ns -bucket 2000000 -minn 3 -maxn 6 -t 1e-4 "
  167. "-lrUpdateRate 100".format(thread)
  168. )
  169. mode = 'cbow'
  170. cls.get_train_output(
  171. cls.build_train_args(unsup_params, mode, train, output)
  172. )
  173. def get_predictions_from_list(self, output, words, k):
  174. args = [self.bin, "predict-prob", output + '.bin', '-', str(k)]
  175. self.eprint("Executing: " + ' '.join(args))
  176. p = subprocess.Popen(
  177. args, stdin=subprocess.PIPE, stdout=subprocess.PIPE
  178. )
  179. test_text = ""
  180. if words:
  181. test_text = '\n'.join(words) + '\n'
  182. test_text = test_text.encode('utf-8')
  183. stdout, stderr = p.communicate(test_text)
  184. stdout = stdout.decode('utf-8')
  185. return stdout, stderr, p.returncode
  186. def get_word_vectors_from_list(self, output, words):
  187. args = [self.bin, "print-word-vectors", output + '.bin']
  188. self.eprint("Executing: " + ' '.join(args))
  189. p = subprocess.Popen(
  190. args, stdin=subprocess.PIPE, stdout=subprocess.PIPE
  191. )
  192. test_text = '\n'.join(words).encode('utf-8')
  193. stdout, stderr = p.communicate(test_text)
  194. return stdout
  195. class TestFastTextPyUnit(TestFastTextPy):
  196. @classmethod
  197. def setUpClass(cls):
  198. cls.bin = os.environ['FASTTEXT_BIN']
  199. cls.data_dir = os.environ['FASTTEXT_DATA']
  200. cls.result_dir = tempfile.mkdtemp()
  201. train, _, output = cls.build_paths("fil9", "rw/rw.txt", "fil9")
  202. cls.train_generic_embeddings(train, output)
  203. cls.output = output
  204. train, _, output_sup = cls.build_paths(
  205. "dbpedia.train", "dbpedia.test", "dbpedia"
  206. )
  207. cls.train_generic_classifier(train, output_sup)
  208. cls.output_sup = output_sup
  209. @classmethod
  210. def tearDownClass(cls):
  211. pass
  212. # shutil.rmtree(cls.result_dir)
  213. # Check if get_word_vector aligns with vectors from stdin
  214. def test_getvector(self):
  215. f = load_model(self.output + '.bin')
  216. words, _ = f.get_words(include_freq=True)
  217. words += get_random_words(100, 1, 100)
  218. ftbin_vectors = self.get_word_vectors_from_list(self.output, words)
  219. ftbin_vectors = ftbin_vectors.decode('utf-8').split('\n')[:-1]
  220. for v in ftbin_vectors:
  221. word = v.split(' ')[0]
  222. vector = v.split(' ')[1:-1]
  223. vector = np.array(list(map(float, vector)))
  224. pvec = f.get_word_vector(word)
  225. # The fasttext cli returns floats with 5 digits,
  226. # but we use the full 6 digits.
  227. self.assertTrue(np.allclose(vector, pvec, rtol=1e-04))
  228. def test_predict(self):
  229. # TODO: I went a little crazy here as an exercise for
  230. # a rigorous test case. This could be turned into
  231. # a few utility functions.
  232. f = load_model(self.output_sup + '.bin')
  233. def _test(N, min_length, max_length, k, add_vocab=0):
  234. words = get_random_words(N, min_length, max_length)
  235. if add_vocab > 0:
  236. vocab, _ = f.get_words(include_freq=True)
  237. for _ in range(add_vocab):
  238. ind = random.randint(0, len(vocab))
  239. words += [vocab[ind]]
  240. all_labels = []
  241. all_probs = []
  242. ii = 0
  243. gotError = False
  244. for w in words:
  245. try:
  246. labels, probs = f.predict(w, k)
  247. except ValueError:
  248. gotError = True
  249. continue
  250. all_labels.append(labels)
  251. all_probs.append(probs)
  252. ii += 1
  253. preds, _, retcode = self.get_predictions_from_list(
  254. self.output_sup, words, k
  255. )
  256. if gotError and retcode == 0:
  257. self.eprint(
  258. "Didn't get error. Make sure your compiled "
  259. "binary kept the assert statements"
  260. )
  261. self.assertTrue(False)
  262. else:
  263. return
  264. preds = preds.split('\n')[:-1]
  265. self.assertEqual(len(preds), len(all_labels))
  266. for i in range(len(preds)):
  267. labels = preds[i].split()
  268. probs = np.array(list(map(float, labels[1::2])))
  269. labels = np.array(labels[::2])
  270. self.assertTrue(np.allclose(probs, all_probs[i], rtol=1e-04))
  271. self.assertTrue(np.array_equal(labels, all_labels[i]))
  272. _test(0, 0, 0, 0)
  273. _test(1, 0, 0, 0)
  274. _test(10, 0, 0, 0)
  275. _test(1, 1, 1, 0)
  276. _test(1, 1, 1, 1)
  277. _test(1, 2, 3, 0)
  278. _test(1, 2, 3, 1)
  279. _test(10, 1, 1, 1)
  280. _test(1, 1, 1, 0, add_vocab=10)
  281. _test(1, 1, 1, 1, add_vocab=10)
  282. _test(1, 2, 3, 0, add_vocab=10)
  283. _test(1, 2, 3, 1, add_vocab=10)
  284. reach = 10
  285. for _ in range(10):
  286. N = random.randint(0, reach)
  287. init = random.randint(0, reach)
  288. offset = random.randint(0, reach)
  289. k = random.randint(0, reach)
  290. _test(N, init, init + offset, k)
  291. def test_vocab(self):
  292. f = load_model(self.output + '.bin')
  293. words, freq = f.get_words(include_freq=True)
  294. self.eprint(
  295. "There is no way to access words from the cli yet. "
  296. "Therefore there can be no rigorous test."
  297. )
  298. def test_subwords(self):
  299. f = load_model(self.output + '.bin')
  300. words, _ = f.get_words(include_freq=True)
  301. words += get_random_words(10, 1, 10)
  302. for w in words:
  303. f.get_subwords(w)
  304. self.eprint(
  305. "There is no way to access words from the cli yet. "
  306. "Therefore there can be no test."
  307. )
  308. def test_tokenize(self):
  309. train, _, _ = self.build_paths("fil9", "rw/rw.txt", "fil9")
  310. with open(train, 'r') as f:
  311. _ = tokenize(f.read())
  312. def test_dimension(self):
  313. f = load_model(self.output + '.bin')
  314. f.get_dimension()
  315. def test_subword_vector(self):
  316. f = load_model(self.output + '.bin')
  317. words, _ = f.get_words(include_freq=True)
  318. words += get_random_words(10000, 1, 200)
  319. input_matrix = f.get_input_matrix()
  320. for word in words:
  321. # Universal api to get word vector
  322. vec1 = f.get_word_vector(word)
  323. # Build word vector from subwords
  324. subwords, subinds = f.get_subwords(word)
  325. subvectors = list(map(lambda x: f.get_input_vector(x), subinds))
  326. subvectors = np.stack(subvectors)
  327. vec2 = np.sum((subvectors / len(subwords)), 0)
  328. # Build word vector from subinds
  329. vec3 = np.sum(input_matrix[subinds] / len(subinds), 0)
  330. # Build word vectors from word and subword ids
  331. wid = f.get_word_id(word)
  332. if wid >= 0:
  333. swids = list(map(lambda x: f.get_subword_id(x), subwords[1:]))
  334. swids.append(wid)
  335. else:
  336. swids = list(map(lambda x: f.get_subword_id(x), subwords))
  337. swids = np.array(swids)
  338. vec4 = np.sum(input_matrix[swids] / len(swids), 0)
  339. self.assertTrue(np.isclose(vec1, vec2, atol=1e-5, rtol=0).all())
  340. self.assertTrue(np.isclose(vec2, vec3, atol=1e-5, rtol=0).all())
  341. self.assertTrue(np.isclose(vec3, vec4, atol=1e-5, rtol=0).all())
  342. self.assertTrue(np.isclose(vec4, vec1, atol=1e-5, rtol=0).all())
  343. # TODO: Compare with .vec file
  344. def test_get_words(self):
  345. f = load_model(self.output + '.bin')
  346. words1, freq1 = f.get_words(include_freq=True)
  347. words2 = f.get_words(include_freq=False)
  348. self.assertEqual(len(words1), len(words2))
  349. self.assertEqual(len(words1), len(freq1))
  350. f = load_model(self.output_sup + '.bin')
  351. words1, freq1 = f.get_words(include_freq=True)
  352. words2 = f.get_words(include_freq=False)
  353. self.assertEqual(len(words1), len(words2))
  354. self.assertEqual(len(words1), len(freq1))
  355. # TODO: Compare with .vec file for unsup
  356. def test_get_labels(self):
  357. f = load_model(self.output + '.bin')
  358. labels1, freq1 = f.get_labels(include_freq=True)
  359. labels2 = f.get_labels(include_freq=False)
  360. words2 = f.get_words(include_freq=False)
  361. self.assertEqual(len(labels1), len(labels2))
  362. self.assertEqual(len(labels1), len(freq1))
  363. self.assertEqual(len(labels1), len(words2))
  364. for w1, w2 in zip(labels2, words2):
  365. self.assertEqual(w1, w2)
  366. f = load_model(self.output_sup + '.bin')
  367. labels1, freq1 = f.get_labels(include_freq=True)
  368. labels2 = f.get_labels(include_freq=False)
  369. self.assertEqual(len(labels1), len(labels2))
  370. self.assertEqual(len(labels1), len(freq1))
  371. def test_exercise_is_quant(self):
  372. f = load_model(self.output + '.bin')
  373. gotError = False
  374. try:
  375. f.quantize()
  376. except ValueError:
  377. gotError = True
  378. self.assertTrue(gotError)
  379. f = load_model(self.output_sup + '.bin')
  380. self.assertTrue(not f.is_quantized())
  381. f.quantize()
  382. self.assertTrue(f.is_quantized())
  383. def test_newline_predict_sentence(self):
  384. f = load_model(self.output_sup + '.bin')
  385. sentence = get_random_words(1, 1000, 2000)[0]
  386. f.predict(sentence, k=5)
  387. sentence += "\n"
  388. gotError = False
  389. try:
  390. f.predict(sentence, k=5)
  391. except ValueError:
  392. gotError = True
  393. self.assertTrue(gotError)
  394. f = load_model(self.output + '.bin')
  395. sentence = get_random_words(1, 1000, 2000)[0]
  396. f.get_sentence_vector(sentence)
  397. sentence += "\n"
  398. gotError = False
  399. try:
  400. f.get_sentence_vector(sentence)
  401. except ValueError:
  402. gotError = True
  403. self.assertTrue(gotError)
  404. class TestFastTextPyIntegration(TestFastTextPy):
  405. @classmethod
  406. def setUpClass(cls):
  407. cls.bin = os.environ['FASTTEXT_BIN']
  408. cls.data_dir = os.environ['FASTTEXT_DATA']
  409. cls.result_dir = tempfile.mkdtemp()
  410. def test_unsup1(self):
  411. train, test, output = self.build_paths("fil9", "rw/rw.txt", "fil9")
  412. model = train_unsupervised(
  413. input=train,
  414. model="skipgram",
  415. lr=0.025,
  416. dim=100,
  417. ws=5,
  418. epoch=1,
  419. minCount=5,
  420. neg=5,
  421. loss="ns",
  422. bucket=2000000,
  423. minn=3,
  424. maxn=6,
  425. t=1e-4,
  426. lrUpdateRate=100,
  427. thread=self.num_thread(),
  428. )
  429. model.save_model(output)
  430. path_size = self.get_path_size(output)
  431. vectors = {}
  432. with open(test, 'r') as test_f:
  433. for line in test_f:
  434. query0 = line.split()[0].strip()
  435. query1 = line.split()[1].strip()
  436. vector0 = model.get_word_vector(query0)
  437. vector1 = model.get_word_vector(query1)
  438. vectors[query0] = vector0
  439. vectors[query1] = vector1
  440. dataset, correlation, oov = compute_similarity(None, test, vectors)
  441. correlation = np.around(correlation)
  442. self.assertTrue(
  443. correlation >= 41, "Correlation: Want: 41 Is: " + str(correlation)
  444. )
  445. self.assertEqual(oov, 0.0, "Oov: Want: 0 Is: " + str(oov))
  446. self.assertEqual(
  447. path_size, 978480868, "Size: Want: 978480868 Is: " + str(path_size)
  448. )
  449. def gen_sup_test(lr, dataset, n, p1, r1, p1_q, r1_q, size, quant_size):
  450. def sup_test(self):
  451. def check(
  452. output_local, test_local, n_local, p1_local, r1_local, size_local,
  453. lessthan
  454. ):
  455. test_args = self.default_test_args(output_local, test_local)
  456. test_output = self.get_test_output(test_args)
  457. self.assertEqual(
  458. str(test_output[0]),
  459. str(n_local),
  460. "N: Want: " + str(n_local) + " Is: " + str(test_output[0])
  461. )
  462. self.assertTrue(
  463. float(test_output[1]) >= float(p1_local),
  464. "p1: Want: " + str(p1_local) + " Is: " + str(test_output[1])
  465. )
  466. self.assertTrue(
  467. float(test_output[2]) >= float(r1_local),
  468. "r1: Want: " + str(r1_local) + " Is: " + str(test_output[2])
  469. )
  470. path_size = self.get_path_size(output_local)
  471. if lessthan:
  472. self.assertTrue(
  473. path_size <= size_local, "Size: Want at most: " +
  474. str(size_local) + " Is: " + str(path_size)
  475. )
  476. else:
  477. self.assertTrue(
  478. path_size == size_local,
  479. "Size: Want: " + str(size_local) + " Is: " + str(path_size)
  480. )
  481. train, test, output = self.build_paths(
  482. dataset + ".train", dataset + ".test", dataset
  483. )
  484. model = train_supervised(
  485. input=train,
  486. dim=10,
  487. lr=lr,
  488. wordNgrams=2,
  489. minCount=1,
  490. bucket=10000000,
  491. epoch=5,
  492. thread=self.num_thread()
  493. )
  494. model.save_model(output)
  495. check(output, test, n, p1, r1, size, False)
  496. # Exercising
  497. model.predict("hello world")
  498. model.quantize(input=train, retrain=True, cutoff=100000, qnorm=True)
  499. model.save_model(output + ".ftz")
  500. # Exercising
  501. model.predict("hello world")
  502. check(output + ".ftz", test, n, p1_q, r1_q, quant_size, True)
  503. return sup_test
  504. if __name__ == "__main__":
  505. sup_job_lr = [0.25, 0.5, 0.5, 0.1, 0.1, 0.1, 0.05, 0.05]
  506. sup_job_n = [7600, 60000, 70000, 38000, 50000, 60000, 650000, 400000]
  507. sup_job_p1 = [0.921, 0.968, 0.984, 0.956, 0.638, 0.723, 0.603, 0.946]
  508. sup_job_r1 = [0.921, 0.968, 0.984, 0.956, 0.638, 0.723, 0.603, 0.946]
  509. sup_job_quant_p1 = [0.918, 0.965, 0.984, 0.953, 0.629, 0.707, 0.58, 0.940]
  510. sup_job_quant_r1 = [0.918, 0.965, 0.984, 0.953, 0.629, 0.707, 0.58, 0.940]
  511. sup_job_size = [
  512. 405607193, 421445471, 447481878, 427867393, 431292576, 517549567,
  513. 483742593, 493604598
  514. ]
  515. sup_job_quant_size = [
  516. 405607193, 421445471, 447481878, 427867393, 431292576, 517549567,
  517. 483742593, 493604598
  518. ]
  519. sup_job_quant_size = [
  520. 1600000, 1457000, 1690000, 1550000, 1567896, 1655000, 1600000, 1575010
  521. ]
  522. # Yelp_review_full can be a bit flaky
  523. sup_job_dataset = [
  524. "ag_news", "sogou_news", "dbpedia", "yelp_review_polarity",
  525. "yelp_review_full", "yahoo_answers", "amazon_review_full",
  526. "amazon_review_polarity"
  527. ]
  528. sup_job_args = [
  529. sup_job_lr, sup_job_dataset, sup_job_n, sup_job_p1, sup_job_r1,
  530. sup_job_quant_p1, sup_job_quant_r1, sup_job_size, sup_job_quant_size
  531. ]
  532. for lr, dataset, n, p1, r1, p1_q, r1_q, size, quant_size in zip(
  533. *sup_job_args
  534. ):
  535. setattr(
  536. TestFastTextPyIntegration, "test_" + dataset,
  537. gen_sup_test(lr, dataset, n, p1, r1, p1_q, r1_q, size, quant_size)
  538. )
  539. unittest.main()