1
0

test_script.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629
  1. # Copyright (c) 2017-present, Facebook, Inc.
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the MIT license found in the
  5. # LICENSE file in the root directory of this source tree.
  6. from __future__ import absolute_import
  7. from __future__ import division
  8. from __future__ import print_function
  9. from __future__ import unicode_literals
  10. from fasttext import train_supervised
  11. from fasttext import train_unsupervised
  12. from fasttext import util
  13. import fasttext
  14. import os
  15. import subprocess
  16. import unittest
  17. import tempfile
  18. import random
  19. import sys
  20. import copy
  21. import numpy as np
  22. try:
  23. import unicode
  24. except ImportError:
  25. pass
  26. from fasttext.tests.test_configurations import get_supervised_models
def eprint(cls, *args, **kwargs):
    # Print to stderr instead of stdout (keeps test chatter out of results).
    # NOTE(review): *cls* is accepted but never used — presumably kept so the
    # function can be bound as a method; confirm before removing.
    print(*args, file=sys.stderr, **kwargs)
  29. def get_random_unicode(length):
  30. # See: https://stackoverflow.com/questions/1477294/generate-random-utf-8-string-in-python
  31. try:
  32. get_char = unichr
  33. except NameError:
  34. get_char = chr
  35. # Update this to include code point ranges to be sampled
  36. include_ranges = [
  37. (0x0021, 0x0021),
  38. (0x0023, 0x0026),
  39. (0x0028, 0x007E),
  40. (0x00A1, 0x00AC),
  41. (0x00AE, 0x00FF),
  42. (0x0100, 0x017F),
  43. (0x0180, 0x024F),
  44. (0x2C60, 0x2C7F),
  45. (0x16A0, 0x16F0),
  46. (0x0370, 0x0377),
  47. (0x037A, 0x037E),
  48. (0x0384, 0x038A),
  49. (0x038C, 0x038C),
  50. ]
  51. alphabet = [
  52. get_char(code_point)
  53. for current_range in include_ranges
  54. for code_point in range(current_range[0], current_range[1] + 1)
  55. ]
  56. return ''.join(random.choice(alphabet) for i in range(length))
  57. def get_random_words(N, a=1, b=20, unique=True):
  58. words = []
  59. while (len(words) < N):
  60. length = random.randint(a, b)
  61. word = get_random_unicode(length)
  62. if unique and word not in words:
  63. words.append(word)
  64. else:
  65. words.append(word)
  66. return words
  67. def get_random_data(
  68. num_lines=100,
  69. max_vocab_size=100,
  70. min_words_line=0,
  71. max_words_line=20,
  72. min_len_word=1,
  73. max_len_word=10,
  74. unique_words=True,
  75. ):
  76. random_words = get_random_words(
  77. max_vocab_size, min_len_word, max_len_word, unique=unique_words
  78. )
  79. lines = []
  80. for _ in range(num_lines):
  81. line = []
  82. line_length = random.randint(min_words_line, max_words_line)
  83. for _ in range(line_length):
  84. i = random.randint(0, max_vocab_size - 1)
  85. line.append(random_words[i])
  86. line = " ".join(line)
  87. lines.append(line)
  88. return lines
  89. def default_kwargs(kwargs):
  90. default = {"thread": 1, "epoch": 1, "minCount": 1, "bucket": 1000}
  91. for k, v in default.items():
  92. if k not in kwargs:
  93. kwargs[k] = v
  94. return kwargs
  95. def build_unsupervised_model(data, kwargs):
  96. kwargs = default_kwargs(kwargs)
  97. with tempfile.NamedTemporaryFile(delete=False) as tmpf:
  98. for line in data:
  99. tmpf.write((line + "\n").encode("UTF-8"))
  100. tmpf.flush()
  101. model = train_unsupervised(input=tmpf.name, **kwargs)
  102. return model
  103. def build_supervised_model(data, kwargs):
  104. kwargs = default_kwargs(kwargs)
  105. with tempfile.NamedTemporaryFile(delete=False) as tmpf:
  106. for line in data:
  107. line = "__label__" + line.strip() + "\n"
  108. tmpf.write(line.encode("UTF-8"))
  109. tmpf.flush()
  110. model = train_supervised(input=tmpf.name, **kwargs)
  111. return model
  112. def read_labels(data_file):
  113. labels = []
  114. lines = []
  115. with open(data_file, 'r') as f:
  116. for line in f:
  117. labels_line = []
  118. words_line = []
  119. try:
  120. line = unicode(line, "UTF-8").split()
  121. except NameError:
  122. line = line.split()
  123. for word in line:
  124. if word.startswith("__label__"):
  125. labels_line.append(word)
  126. else:
  127. words_line.append(word)
  128. labels.append(labels_line)
  129. lines.append(" ".join(words_line))
  130. return lines, labels
class TestFastTextUnitPy(unittest.TestCase):
    """Template unit tests for the fastText Python bindings.

    Methods named ``gen_test_*`` are not collected directly: ``gen_unit_tests``
    wraps each of them with several hyper-parameter settings and attaches the
    resulting closures to this class as real ``test_*`` methods.
    """

    # TODO: Unit test copy behavior of fasttext

    def gen_test_get_vector(self, kwargs):
        # Confirm if no subwords, OOV is zero, confirm min=10 means words < 10 get zeros
        # Exercises get_word_vector on vocabulary words and random OOV words.
        f = build_unsupervised_model(get_random_data(100), kwargs)
        words, _ = f.get_words(include_freq=True)
        words += get_random_words(100)
        for word in words:
            f.get_word_vector(word)

    def gen_test_multi_get_line(self, kwargs):
        # Batched get_line (list argument) must agree with per-line calls.
        data = get_random_data(100)
        model1 = build_supervised_model(data, kwargs)
        model2 = build_unsupervised_model(data, kwargs)
        lines1 = []
        lines2 = []
        for line in data:
            words, labels = model1.get_line(line)
            lines1.append(words)
            self.assertEqual(len(labels), 0)
            words, labels = model2.get_line(line)
            lines2.append(words)
            self.assertEqual(len(labels), 0)
        all_lines1, all_labels1 = model1.get_line(data)
        all_lines2, all_labels2 = model2.get_line(data)
        self.assertEqual(lines1, all_lines1)
        self.assertEqual(lines2, all_lines2)
        # The raw data carries no "__label__" tokens, so every label list
        # must come back empty.
        for labels in all_labels1:
            self.assertEqual(len(labels), 0)
        for labels in all_labels2:
            self.assertEqual(len(labels), 0)

    def gen_test_supervised_util_test(self, kwargs):
        # util.test on explicit predictions must match model.test on the file.
        def check(data):
            third = int(len(data) / 3)
            train_data = data[:2 * third]
            valid_data = data[third:]
            with tempfile.NamedTemporaryFile(
                delete=False
            ) as tmpf, tempfile.NamedTemporaryFile(delete=False) as tmpf2:
                for line in train_data:
                    tmpf.write(
                        ("__label__" + line.strip() + "\n").encode("UTF-8")
                    )
                tmpf.flush()
                for line in valid_data:
                    tmpf2.write(
                        ("__label__" + line.strip() + "\n").encode("UTF-8")
                    )
                tmpf2.flush()
                model = train_supervised(input=tmpf.name, **kwargs)
                true_labels = []
                all_words = []
                with open(tmpf2.name, 'r') as fid:
                    for line in fid:
                        if sys.version_info < (3, 0):
                            line = line.decode("UTF-8")
                        if len(line.strip()) == 0:
                            continue
                        words, labels = model.get_line(line.strip())
                        if len(labels) == 0:
                            continue
                        all_words.append(" ".join(words))
                        true_labels += [labels]
                predictions, _ = model.predict(all_words)
                p, r = util.test(predictions, true_labels)
                N = len(predictions)
                Nt, pt, rt = model.test(tmpf2.name)
                self.assertEqual(N, Nt)
                self.assertEqual(p, pt)
                self.assertEqual(r, rt)

        # Need at least one word to have a label and a word to prevent error
        check(get_random_data(100, min_words_line=2))

    def gen_test_supervised_predict(self, kwargs):
        # Confirm number of labels, confirm labels for easy dataset
        # Confirm 1 label and 0 label dataset
        f = build_supervised_model(get_random_data(100), kwargs)
        words = get_random_words(100)
        for k in [1, 2, 5]:
            for w in words:
                labels, probs = f.predict(w, k)
            data = get_random_data(100)
            for line in data:
                labels, probs = f.predict(line, k)

    def gen_test_supervised_multiline_predict(self, kwargs):
        # Confirm number of labels, confirm labels for easy dataset
        # Confirm 1 label and 0 label dataset
        def check_predict(f):
            # Batched predict (list argument) must equal the concatenation
            # of per-item predict results, for several k.
            for k in [1, 2, 5]:
                words = get_random_words(10)
                agg_labels = []
                agg_probs = []
                for w in words:
                    labels, probs = f.predict(w, k)
                    agg_labels += [labels]
                    agg_probs += [probs]
                all_labels1, all_probs1 = f.predict(words, k)
                data = get_random_data(10)
                for line in data:
                    labels, probs = f.predict(line, k)
                    agg_labels += [labels]
                    agg_probs += [probs]
                all_labels2, all_probs2 = f.predict(data, k)
                all_labels = list(all_labels1) + list(all_labels2)
                all_probs = list(all_probs1) + list(all_probs2)
                for label1, label2 in zip(all_labels, agg_labels):
                    self.assertEqual(list(label1), list(label2))
                for prob1, prob2 in zip(all_probs, agg_probs):
                    self.assertEqual(list(prob1), list(prob2))

        check_predict(build_supervised_model(get_random_data(100), kwargs))
        check_predict(
            build_supervised_model(
                get_random_data(100, min_words_line=1), kwargs
            )
        )

    def gen_test_vocab(self, kwargs):
        # Confirm empty dataset, confirm all label dataset
        data = get_random_data(100)
        # Word frequencies computed in pure Python, as ground truth.
        words_python = {}
        for line in data:
            line_words = line.split()
            for w in line_words:
                if w not in words_python:
                    words_python[w] = 0
                words_python[w] += 1
        f = build_unsupervised_model(data, kwargs)
        words, freqs = f.get_words(include_freq=True)
        foundEOS = False
        for word, freq in zip(words, freqs):
            if word == fasttext.EOS:
                foundEOS = True
            else:
                self.assertEqual(words_python[word], freq)
        # EOS is special to fasttext, but still part of the vocab
        self.assertEqual(len(words_python), len(words) - 1)
        self.assertTrue(foundEOS)
        # Should cause "Empty vocabulary" error.
        data = get_random_data(0)
        gotError = False
        try:
            build_unsupervised_model(data, kwargs)
        except ValueError:
            gotError = True
        self.assertTrue(gotError)

    def gen_test_subwords(self, kwargs):
        # Define expected behavior
        # Smoke test: get_subwords must not fail on vocab or OOV words.
        f = build_unsupervised_model(get_random_data(100), kwargs)
        words, _ = f.get_words(include_freq=True)
        words += get_random_words(10, 1, 10)
        for w in words:
            f.get_subwords(w)

    def gen_test_tokenize(self, kwargs):
        # Tokenization splits on whitespace and maps "\n" to the EOS token.
        self.assertEqual(["asdf", "asdb"], fasttext.tokenize("asdf asdb"))
        self.assertEqual(["asdf"], fasttext.tokenize("asdf"))
        self.assertEqual([fasttext.EOS], fasttext.tokenize("\n"))
        self.assertEqual(["asdf", fasttext.EOS], fasttext.tokenize("asdf\n"))
        self.assertEqual([], fasttext.tokenize(""))
        self.assertEqual([], fasttext.tokenize(" "))
        # An empty string is not a token (it's just whitespace)
        # So the minimum length must be 1
        words = get_random_words(100, 1, 20)
        self.assertEqual(words, fasttext.tokenize(" ".join(words)))

    def gen_test_unsupervised_dimension(self, kwargs):
        # get_dimension must reflect the "dim" training argument.
        if "dim" in kwargs:
            f = build_unsupervised_model(get_random_data(100), kwargs)
            self.assertEqual(f.get_dimension(), kwargs["dim"])

    def gen_test_supervised_dimension(self, kwargs):
        # get_dimension must reflect the "dim" training argument.
        if "dim" in kwargs:
            f = build_supervised_model(get_random_data(100), kwargs)
            self.assertEqual(f.get_dimension(), kwargs["dim"])

    def gen_test_subword_vector(self, kwargs):
        # A word vector must equal the average of its subword vectors,
        # computed four independent ways.
        f = build_unsupervised_model(get_random_data(100), kwargs)
        words, _ = f.get_words(include_freq=True)
        words += get_random_words(100, 1, 20)
        input_matrix = f.get_input_matrix()
        for word in words:
            # Universal API to get word vector
            vec1 = f.get_word_vector(word)
            # Build word vector from subwords
            subwords, subinds = f.get_subwords(word)
            subvectors = list(map(lambda x: f.get_input_vector(x), subinds))
            if len(subvectors) == 0:
                vec2 = np.zeros((f.get_dimension(), ))
            else:
                subvectors = np.vstack(subvectors)
                vec2 = np.sum((subvectors / len(subwords)), 0)
            # Build word vector from subinds
            if len(subinds) == 0:
                vec3 = np.zeros((f.get_dimension(), ))
            else:
                vec3 = np.sum(input_matrix[subinds] / len(subinds), 0)
            # Build word vectors from word and subword ids
            wid = f.get_word_id(word)
            if wid >= 0:
                # In-vocab: subwords[0] is the word itself, so replace it
                # with the word id and map the rest through get_subword_id.
                swids = list(map(lambda x: f.get_subword_id(x), subwords[1:]))
                swids.append(wid)
            else:
                swids = list(map(lambda x: f.get_subword_id(x), subwords))
            if len(swids) == 0:
                vec4 = np.zeros((f.get_dimension(), ))
            else:
                swids = np.array(swids)
                vec4 = np.sum(input_matrix[swids] / len(swids), 0)
            self.assertTrue(np.isclose(vec1, vec2, atol=1e-5, rtol=0).all())
            self.assertTrue(np.isclose(vec2, vec3, atol=1e-5, rtol=0).all())
            self.assertTrue(np.isclose(vec3, vec4, atol=1e-5, rtol=0).all())
            self.assertTrue(np.isclose(vec4, vec1, atol=1e-5, rtol=0).all())

    def gen_test_unsupervised_get_words(self, kwargs):
        # Check more corner cases of 0 vocab, empty file etc.
        f = build_unsupervised_model(get_random_data(100), kwargs)
        words1, freq1 = f.get_words(include_freq=True)
        words2 = f.get_words(include_freq=False)
        self.assertEqual(len(words1), len(words2))
        self.assertEqual(len(words1), len(freq1))

    def gen_test_supervised_get_words(self, kwargs):
        # get_words with and without frequencies must be consistent.
        f = build_supervised_model(get_random_data(100), kwargs)
        words1, freq1 = f.get_words(include_freq=True)
        words2 = f.get_words(include_freq=False)
        self.assertEqual(len(words1), len(words2))
        self.assertEqual(len(words1), len(freq1))

    def gen_test_unsupervised_get_labels(self, kwargs):
        # Unsupervised models have no labels: get_labels falls back to words.
        f = build_unsupervised_model(get_random_data(100), kwargs)
        labels1, freq1 = f.get_labels(include_freq=True)
        labels2 = f.get_labels(include_freq=False)
        words2 = f.get_words(include_freq=False)
        self.assertEqual(len(labels1), len(labels2))
        self.assertEqual(len(labels1), len(freq1))
        self.assertEqual(len(labels1), len(words2))
        for w1, w2 in zip(labels2, words2):
            self.assertEqual(w1, w2)

    def gen_test_supervised_get_labels(self, kwargs):
        # get_labels with and without frequencies must be consistent.
        f = build_supervised_model(get_random_data(100), kwargs)
        labels1, freq1 = f.get_labels(include_freq=True)
        labels2 = f.get_labels(include_freq=False)
        self.assertEqual(len(labels1), len(labels2))
        self.assertEqual(len(labels1), len(freq1))

    def gen_test_unsupervised_exercise_is_quant(self, kwargs):
        # Quantizing an unsupervised model is unsupported and must raise.
        f = build_unsupervised_model(get_random_data(100), kwargs)
        gotError = False
        try:
            f.quantize()
        except ValueError:
            gotError = True
        self.assertTrue(gotError)

    def gen_test_supervised_exercise_is_quant(self, kwargs):
        # is_quantized flips from False to True after quantize().
        f = build_supervised_model(
            get_random_data(1000, max_vocab_size=1000), kwargs
        )
        self.assertTrue(not f.is_quantized())
        f.quantize()
        self.assertTrue(f.is_quantized())

    def gen_test_newline_predict_sentence(self, kwargs):
        # A trailing newline in the input must raise for both predict and
        # get_sentence_vector (newline means multiple sentences).
        f = build_supervised_model(get_random_data(100), kwargs)
        sentence = " ".join(get_random_words(20))
        f.predict(sentence, k=5)
        sentence += "\n"
        gotError = False
        try:
            f.predict(sentence, k=5)
        except ValueError:
            gotError = True
        self.assertTrue(gotError)
        f = build_supervised_model(get_random_data(100), kwargs)
        sentence = " ".join(get_random_words(20))
        f.get_sentence_vector(sentence)
        sentence += "\n"
        gotError = False
        try:
            f.get_sentence_vector(sentence)
        except ValueError:
            gotError = True
        self.assertTrue(gotError)
# Generate a supervised test case
# The returned function will be set as an attribute to a test class
def gen_sup_test(configuration, data_dir):
    def sup_test(self):
        def get_path_size(path):
            # File size in bytes via stat(1).
            # NOTE(review): "-c %s" is the GNU coreutils syntax — presumably
            # these tests only run on Linux; confirm before running elsewhere.
            path_size = subprocess.check_output(["stat", "-c", "%s",
                                                 path]).decode('utf-8')
            path_size = int(path_size)
            return path_size

        def check(model, model_filename, test, lessthan, msg_prefix=""):
            # Validate example count (exact), precision/recall (thresholds)
            # and the saved model's on-disk size (exact, or an upper bound
            # when lessthan is True, as for quantized models).
            N_local_out, p1_local_out, r1_local_out = model.test(test["data"])
            self.assertEqual(
                N_local_out, test["n"], msg_prefix + "N: Want: " +
                str(test["n"]) + " Is: " + str(N_local_out)
            )
            self.assertTrue(
                p1_local_out >= test["p1"], msg_prefix + "p1: Want: " +
                str(test["p1"]) + " Is: " + str(p1_local_out)
            )
            self.assertTrue(
                r1_local_out >= test["r1"], msg_prefix + "r1: Want: " +
                str(test["r1"]) + " Is: " + str(r1_local_out)
            )
            path_size = get_path_size(model_filename)
            size_msg = str(test["size"]) + " Is: " + str(path_size)
            if lessthan:
                self.assertTrue(
                    path_size <= test["size"],
                    msg_prefix + "Size: Want at most: " + size_msg
                )
            else:
                self.assertTrue(
                    path_size == test["size"],
                    msg_prefix + "Size: Want: " + size_msg
                )

        # Resolve input/test data paths relative to data_dir.
        # NOTE(review): this mutates the shared configuration dict in place.
        configuration["args"]["input"] = os.path.join(
            data_dir, configuration["args"]["input"]
        )
        configuration["quant_args"]["input"] = configuration["args"]["input"]
        configuration["test"]["data"] = os.path.join(
            data_dir, configuration["test"]["data"]
        )
        configuration["quant_test"]["data"] = configuration["test"]["data"]
        output = os.path.join(tempfile.mkdtemp(), configuration["dataset"])
        print()
        # Train, save and validate the full-precision model.
        model = train_supervised(**configuration["args"])
        model.save_model(output + ".bin")
        check(
            model,
            output + ".bin",
            configuration["test"],
            False,
            msg_prefix="Supervised: "
        )
        print()
        # Quantize, save and validate the compressed (.ftz) model.
        model.quantize(**configuration["quant_args"])
        model.save_model(output + ".ftz")
        check(
            model,
            output + ".ftz",
            configuration["quant_test"],
            True,
            msg_prefix="Quantized: "
        )

    return sup_test
  466. def gen_unit_tests(verbose=0):
  467. gen_funcs = [
  468. func for func in dir(TestFastTextUnitPy)
  469. if callable(getattr(TestFastTextUnitPy, func))
  470. if func.startswith("gen_test_")
  471. ]
  472. general_settings = [
  473. {
  474. "minn": 2,
  475. "maxn": 4,
  476. }, {
  477. "minn": 0,
  478. "maxn": 0,
  479. "bucket": 0
  480. }, {
  481. "dim": 1
  482. }, {
  483. "dim": 5
  484. }
  485. ]
  486. supervised_settings = [
  487. {
  488. "minn": 2,
  489. "maxn": 4,
  490. }, {
  491. "minn": 0,
  492. "maxn": 0,
  493. "bucket": 0
  494. }, {
  495. "dim": 1
  496. }, {
  497. "dim": 5
  498. }, {
  499. "dim": 5,
  500. "loss": "hs"
  501. }
  502. ]
  503. unsupervised_settings = [
  504. {
  505. "minn": 2,
  506. "maxn": 4,
  507. }, {
  508. "minn": 0,
  509. "maxn": 0,
  510. "bucket": 0
  511. }, {
  512. "dim": 1
  513. }, {
  514. "dim": 5,
  515. "model": "cbow"
  516. }, {
  517. "dim": 5,
  518. "model": "skipgram"
  519. }
  520. ]
  521. for gen_func in gen_funcs:
  522. def build_test(test_name, kwargs=None):
  523. if kwargs is None:
  524. kwargs = {}
  525. kwargs["verbose"] = verbose
  526. def test(self):
  527. return getattr(TestFastTextUnitPy,
  528. "gen_" + test_name)(self, copy.deepcopy(kwargs))
  529. return test
  530. test_name = gen_func[4:]
  531. if "_unsupervised_" in test_name:
  532. for i, setting in enumerate(unsupervised_settings):
  533. setattr(
  534. TestFastTextUnitPy, test_name + "_" + str(i),
  535. build_test(test_name, setting)
  536. )
  537. elif "_supervised_" in test_name:
  538. for i, setting in enumerate(supervised_settings):
  539. setattr(
  540. TestFastTextUnitPy, test_name + "_" + str(i),
  541. build_test(test_name, setting)
  542. )
  543. else:
  544. for i, setting in enumerate(general_settings):
  545. setattr(
  546. TestFastTextUnitPy, test_name + "_" + str(i),
  547. build_test(test_name, setting)
  548. )
  549. return TestFastTextUnitPy
  550. def gen_tests(data_dir, verbose=1):
  551. class TestFastTextPy(unittest.TestCase):
  552. pass
  553. i = 0
  554. for configuration in get_supervised_models(verbose=verbose):
  555. setattr(
  556. TestFastTextPy,
  557. "test_sup_" + str(i) + "_" + configuration["dataset"],
  558. gen_sup_test(configuration, data_dir)
  559. )
  560. i += 1
  561. return TestFastTextPy