# train.py
# BSD 3-Clause License
# Copyright (c) 2018-2020, NVIDIA Corporation
# All rights reserved.
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# * Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""https://github.com/NVIDIA/tacotron2"""
import os
from numpy import finfo
import torch
from tacotron2.distributed import apply_gradient_allreduce
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import DataLoader
from tacotron2.model import Tacotron2

# Run on the GPU when one is available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  34. def reduce_tensor(tensor, n_gpus):
  35. rt = tensor.clone()
  36. dist.all_reduce(rt, op=dist.reduce_op.SUM)
  37. rt /= n_gpus
  38. return rt
  39. def init_distributed(hparams, n_gpus, rank, group_name):
  40. assert torch.cuda.is_available(), "Distributed mode requires CUDA."
  41. print("Initializing Distributed")
  42. # Set cuda device so everything is done on the right GPU.
  43. torch.cuda.set_device(rank % torch.cuda.device_count())
  44. # Initialize distributed communication
  45. dist.init_process_group(
  46. backend=hparams.dist_backend, init_method=hparams.dist_url,
  47. world_size=n_gpus, rank=rank, group_name=group_name)
  48. print("Done initializing distributed")
  49. def load_model(hparams):
  50. model = Tacotron2(hparams).to(device)
  51. if hparams.fp16_run:
  52. model.decoder.attention_layer.score_mask_value = finfo('float16').min
  53. if hparams.distributed_run:
  54. model = apply_gradient_allreduce(model)
  55. return model
  56. def warm_start_model(checkpoint_path, model, ignore_layers):
  57. assert os.path.isfile(checkpoint_path)
  58. print("Warm starting model from checkpoint '{}'".format(checkpoint_path))
  59. checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
  60. model_dict = checkpoint_dict['state_dict']
  61. if len(ignore_layers) > 0:
  62. model_dict = {k: v for k, v in model_dict.items()
  63. if k not in ignore_layers}
  64. dummy_dict = model.state_dict()
  65. dummy_dict.update(model_dict)
  66. model_dict = dummy_dict
  67. model.load_state_dict(model_dict)
  68. return model
  69. def load_checkpoint(checkpoint_path, model, optimizer):
  70. assert os.path.isfile(checkpoint_path)
  71. print("Loading checkpoint '{}'".format(checkpoint_path))
  72. checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
  73. model.load_state_dict(checkpoint_dict['state_dict'])
  74. optimizer.load_state_dict(checkpoint_dict['optimizer'])
  75. learning_rate = checkpoint_dict['learning_rate']
  76. iteration = checkpoint_dict['iteration']
  77. print("Loaded checkpoint '{}' from iteration {}" .format(
  78. checkpoint_path, iteration))
  79. return model, optimizer, learning_rate, iteration
  80. def save_checkpoint(model, optimizer, learning_rate, iteration, filepath):
  81. print("Saving model and optimizer state at iteration {} to {}".format(
  82. iteration, filepath))
  83. torch.save({'iteration': iteration,
  84. 'state_dict': model.state_dict(),
  85. 'optimizer': optimizer.state_dict(),
  86. 'learning_rate': learning_rate}, filepath)
  87. def validate(model, criterion, valset, iteration, batch_size, n_gpus,
  88. collate_fn, logger, distributed_run, rank):
  89. """Handles all the validation scoring and printing"""
  90. model.eval()
  91. with torch.no_grad():
  92. val_sampler = DistributedSampler(valset) if distributed_run else None
  93. val_loader = DataLoader(valset, sampler=val_sampler, num_workers=1,
  94. shuffle=False, batch_size=batch_size,
  95. pin_memory=False, collate_fn=collate_fn)
  96. val_loss = 0.0
  97. for i, batch in enumerate(val_loader):
  98. x, y = model.parse_batch(batch)
  99. y_pred = model(x)
  100. loss = criterion(y_pred, y)
  101. if distributed_run:
  102. reduced_val_loss = reduce_tensor(loss.data, n_gpus).item()
  103. else:
  104. reduced_val_loss = loss.item()
  105. val_loss += reduced_val_loss
  106. val_loss = val_loss / (i + 1)
  107. model.train()
  108. if rank == 0:
  109. print("Validation loss {}: {:9f} ".format(iteration, reduced_val_loss))
  110. logger.log_validation(reduced_val_loss, model, y, y_pred, iteration)