#!/usr/bin/env python
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# NOTE: This requires PyTorch! We do not provide installation scripts to install PyTorch.
# It is up to you to install this dependency if you want to execute this example.
# PyTorch's website should give you clear instructions on this: http://pytorch.org/
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import random
import string
import time

import numpy as np
import torch
from torch.autograd import Variable
from torch.nn.modules.sparse import EmbeddingBag

from fastText import load_model
  22. class FastTextEmbeddingBag(EmbeddingBag):
  23. def __init__(self, model_path):
  24. self.model = load_model(model_path)
  25. input_matrix = self.model.get_input_matrix()
  26. input_matrix_shape = input_matrix.shape
  27. super().__init__(input_matrix_shape[0], input_matrix_shape[1])
  28. self.weight.data.copy_(torch.FloatTensor(input_matrix))
  29. def forward(self, words):
  30. word_subinds = np.empty([0], dtype=np.int64)
  31. word_offsets = [0]
  32. for word in words:
  33. _, subinds = self.model.get_subwords(word)
  34. word_subinds = np.concatenate((word_subinds, subinds))
  35. word_offsets.append(word_offsets[-1] + len(subinds))
  36. word_offsets = word_offsets[:-1]
  37. ind = Variable(torch.LongTensor(word_subinds))
  38. offsets = Variable(torch.LongTensor(word_offsets))
  39. return super().forward(ind, offsets)
  40. def random_word(N):
  41. return ''.join(
  42. random.choices(
  43. string.ascii_uppercase + string.ascii_lowercase + string.digits,
  44. k=N
  45. )
  46. )
  47. if __name__ == "__main__":
  48. ft_emb = FastTextEmbeddingBag("fil9.bin")
  49. model = load_model("fil9.bin")
  50. num_lines = 200
  51. total_seconds = 0.0
  52. total_words = 0
  53. for _ in range(num_lines):
  54. words = [
  55. random_word(random.randint(1, 10))
  56. for _ in range(random.randint(15, 25))
  57. ]
  58. total_words += len(words)
  59. words_average_length = sum([len(word) for word in words]) / len(words)
  60. start = time.clock()
  61. words_emb = ft_emb(words)
  62. total_seconds += (time.clock() - start)
  63. for i in range(len(words)):
  64. word = words[i]
  65. ft_word_emb = model.get_word_vector(word)
  66. py_emb = np.array(words_emb[i].data)
  67. assert (np.isclose(ft_word_emb, py_emb).all())
  68. print(
  69. "Avg. {:2.5f} seconds to build embeddings for {} lines with a total of {} words.".
  70. format(total_seconds, num_lines, total_words)
  71. )