reduce_model.py 2.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. #
  4. # Copyright (c) 2017-present, Facebook, Inc.
  5. # All rights reserved.
  6. #
  7. # This source code is licensed under the MIT license found in the
  8. # LICENSE file in the root directory of this source tree.
  9. from __future__ import absolute_import
  10. from __future__ import division
  11. from __future__ import print_function
  12. from __future__ import unicode_literals
  13. import argparse
  14. import os
  15. import re
  16. import sys
  17. import fasttext
  18. import fasttext.util
# Parsed command-line namespace; assigned by main() (via `global args`), None until then.
args = None
  20. def eprint(*args, **kwargs):
  21. print(*args, file=sys.stderr, **kwargs)
  22. def guess_target_name(model_file, initial_dim, target_dim):
  23. """
  24. Given a model name with the convention a.<dim>.b, this function
  25. returns the model's name with `target_dim` value.
  26. For example model_file name `cc.en.300.bin` with initial dim 300 becomes
  27. `cc.en.100.bin` when the `target_dim` is 100.
  28. """
  29. prg = re.compile("(.*).%s.(.*)" % initial_dim)
  30. m = prg.match(model_file)
  31. if m:
  32. return "%s.%d.%s" % (m.group(1), target_dim, m.group(2))
  33. sp_ext = os.path.splitext(model_file)
  34. return "%s.%d%s" % (sp_ext[0], target_dim, sp_ext[1])
  35. def command_reduce(model_file, target_dim, if_exists):
  36. """
  37. Given a `model_file`, this function reduces its dimension to `target_dim`
  38. by applying a PCA.
  39. """
  40. eprint("Loading model")
  41. ft = fasttext.load_model(model_file)
  42. initial_dim = ft.get_dimension()
  43. if target_dim >= initial_dim:
  44. raise Exception("Target dimension (%d) should be less than initial dimension (%d)." % (
  45. target_dim, initial_dim))
  46. result_filename = guess_target_name(model_file, initial_dim, target_dim)
  47. if os.path.isfile(result_filename):
  48. if if_exists == 'overwrite':
  49. pass
  50. elif if_exists == 'strict':
  51. raise Exception(
  52. "File already exists. Use --overwrite to overwrite.")
  53. elif if_exists == 'ignore':
  54. return result_filename
  55. eprint("Reducing matrix dimensions")
  56. fasttext.util.reduce_model(ft, target_dim)
  57. eprint("Saving model")
  58. ft.save_model(result_filename)
  59. eprint("%s saved" % result_filename)
  60. return result_filename
  61. def main():
  62. global args
  63. parser = argparse.ArgumentParser(
  64. description='fastText helper tool to reduce model dimensions.')
  65. parser.add_argument("model", type=str,
  66. help="model file to reduce. model.bin")
  67. parser.add_argument("dim", type=int,
  68. help="targeted dimension of word vectors.")
  69. parser.add_argument("--overwrite", action="store_true",
  70. help="overwrite if file exists.")
  71. args = parser.parse_args()
  72. command_reduce(args.model, args.dim, if_exists=(
  73. 'overwrite' if args.overwrite else 'strict'))
# Run the command-line entry point only when executed as a script, not on import.
if __name__ == '__main__':
    main()