hvd_wrapper.py 783 B

1234567891011121314151617181920212223242526272829303132
# Backend that the module-level size()/rank()/local_rank()/local_size()
# helpers delegate to. Stays None until init() runs, which stores either the
# horovod.tensorflow module or a single-process stub here.
hvd_global_object = None
  2. def init(use_horovod: bool = False):
  3. global hvd_global_object
  4. if use_horovod:
  5. import horovod.tensorflow as hvd
  6. hvd.init()
  7. hvd_global_object = hvd
  8. else:
  9. class _DummyWrapper:
  10. def rank(self): return 0
  11. def size(self): return 1
  12. def local_rank(self): return 0
  13. def local_size(self): return 1
  14. hvd_global_object = _DummyWrapper()
  15. def size():
  16. global hvd_global_object
  17. return hvd_global_object.size()
  18. def rank():
  19. global hvd_global_object
  20. return hvd_global_object.rank()
  21. def local_rank():
  22. global hvd_global_object
  23. return hvd_global_object.local_rank()
  24. def local_size():
  25. global hvd_global_object
  26. return hvd_global_object.local_size()