# get_word_vector.py
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from fastText import load_model
from fastText import tokenize
import sys
import time
import tempfile
import argparse
  16. def get_word_vector(data, model):
  17. t1 = time.time()
  18. print("Reading")
  19. with open(data, 'r') as f:
  20. tokens = tokenize(f.read())
  21. t2 = time.time()
  22. print("Read TIME: " + str(t2 - t1))
  23. print("Read NUM : " + str(len(tokens)))
  24. f = load_model(model)
  25. # This is not equivalent to piping the data into
  26. # print-word-vector, because the data is tokenized
  27. # first.
  28. t3 = time.time()
  29. i = 0
  30. for t in tokens:
  31. f.get_word_vector(t)
  32. i += 1
  33. if i % 10000 == 0:
  34. sys.stderr.write("\ri: " + str(float(i / len(tokens))))
  35. sys.stderr.flush()
  36. t4 = time.time()
  37. print("\nVectoring: " + str(t4 - t3))
  38. if __name__ == "__main__":
  39. parser = argparse.ArgumentParser(description='Simple benchmark for get_word_vector.')
  40. parser.add_argument('model', help='A model file to use for benchmarking.')
  41. parser.add_argument('data', help='A data file to use for benchmarking.')
  42. args = parser.parse_args()
  43. get_word_vector(args.data, args.model)