train_supervised.py 1.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243
  1. #!/usr/bin/env python
  2. # Copyright (c) 2017-present, Facebook, Inc.
  3. # All rights reserved.
  4. #
  5. # This source code is licensed under the MIT license found in the
  6. # LICENSE file in the root directory of this source tree.
  7. from __future__ import absolute_import
  8. from __future__ import division
  9. from __future__ import print_function
  10. from __future__ import unicode_literals
  11. import os
  12. from fasttext import train_supervised
  13. def print_results(N, p, r):
  14. print("N\t" + str(N))
  15. print("P@{}\t{:.3f}".format(1, p))
  16. print("R@{}\t{:.3f}".format(1, r))
  17. if __name__ == "__main__":
  18. train_data = os.path.join(os.getenv("DATADIR", ''), 'cooking.train')
  19. valid_data = os.path.join(os.getenv("DATADIR", ''), 'cooking.valid')
  20. # train_supervised uses the same arguments and defaults as the fastText cli
  21. model = train_supervised(
  22. input=train_data, epoch=25, lr=1.0, wordNgrams=2, verbose=2, minCount=1
  23. )
  24. print_results(*model.test(valid_data))
  25. model = train_supervised(
  26. input=train_data, epoch=25, lr=1.0, wordNgrams=2, verbose=2, minCount=1,
  27. loss="hs"
  28. )
  29. print_results(*model.test(valid_data))
  30. model.save_model("cooking.bin")
  31. model.quantize(input=train_data, qnorm=True, retrain=True, cutoff=100000)
  32. print_results(*model.test(valid_data))
  33. model.save_model("cooking.ftz")