train.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289
# Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. import argparse
  15. import time
  16. import os
  17. import pickle
  18. import json
  19. import torch
  20. import torch.nn as nn
  21. import torch.nn.functional as F
  22. import torch.distributed as dist
  23. from torch.utils.data import DataLoader, DistributedSampler, RandomSampler
  24. from apex import amp
  25. from apex.optimizers import FusedAdam
  26. #from torch.nn.parallel import DistributedDataParallel as DDP
  27. from apex.parallel import DistributedDataParallel as DDP
  28. import numpy as np
  29. import dllogger
  30. from modeling import TemporalFusionTransformer
  31. from configuration import CONFIGS
  32. from data_utils import TFTBinaryDataset, sample_data
  33. from log_helper import setup_logger
  34. from criterions import QuantileLoss
  35. from inference import predict
  36. from utils import PerformanceMeter
  37. import gpu_affinity
  38. from ema import ModelEma
  39. def load_dataset(args, config):
  40. train_split = TFTBinaryDataset(os.path.join(args.data_path, 'train.bin'), config)
  41. train_split = sample_data(train_split, args.sample_data[0])
  42. if args.distributed_world_size > 1:
  43. data_sampler = DistributedSampler(train_split, args.distributed_world_size, args.distributed_rank, seed=args.seed + args.distributed_rank, drop_last=True)
  44. else:
  45. data_sampler = RandomSampler(train_split)
  46. train_loader = DataLoader(train_split, batch_size=args.batch_size, num_workers=4, sampler=data_sampler, pin_memory=True)
  47. valid_split = TFTBinaryDataset(os.path.join(args.data_path, 'valid.bin'), config)
  48. valid_split = sample_data(valid_split, args.sample_data[1])
  49. if args.distributed_world_size > 1:
  50. data_sampler = DistributedSampler(valid_split, args.distributed_world_size, args.distributed_rank, shuffle=False, drop_last=False)
  51. else:
  52. data_sampler = None
  53. valid_loader = DataLoader(valid_split, batch_size=args.batch_size, sampler=data_sampler, num_workers=4, pin_memory=True)
  54. test_split = TFTBinaryDataset(os.path.join(args.data_path, 'test.bin'), config)
  55. if args.distributed_world_size > 1:
  56. data_sampler = DistributedSampler(test_split, args.distributed_world_size, args.distributed_rank, shuffle=False, drop_last=False)
  57. else:
  58. data_sampler = None
  59. test_loader = DataLoader(test_split, batch_size=args.batch_size, sampler=data_sampler, num_workers=4, pin_memory=True)
  60. print_once(f'Train split length: {len(train_split)}')
  61. print_once(f'Valid split length: {len(valid_split)}')
  62. print_once(f'Test split length: {len(test_split)}')
  63. return train_loader, valid_loader, test_loader
  64. def print_once(*args, **kwargs):
  65. if not dist.is_initialized() or dist.get_rank() == 0:
  66. print(*args, **kwargs)
  67. def main(args):
  68. ### INIT DISTRIBUTED
  69. args.distributed_world_size = int(os.environ.get('WORLD_SIZE', 1))
  70. args.local_rank = int(os.environ.get('LOCAL_RANK', 0))
  71. if args.distributed_world_size > 1:
  72. dist.init_process_group(backend='nccl', init_method='env://')
  73. print_once(f'Distributed training with {args.distributed_world_size} GPUs')
  74. args.distributed_rank = dist.get_rank()
  75. torch.cuda.set_device(args.local_rank)
  76. torch.cuda.synchronize()
  77. # Enable CuDNN autotuner
  78. nproc_per_node = torch.cuda.device_count()
  79. if args.affinity != 'disabled':
  80. affinity = gpu_affinity.set_affinity(
  81. args.local_rank,
  82. nproc_per_node,
  83. args.affinity
  84. )
  85. print(f'{args.local_rank}: thread affinity: {affinity}')
  86. torch.backends.cudnn.benchmark = True
  87. if args.seed:
  88. np.random.seed(args.seed)
  89. torch.manual_seed(args.seed)
  90. torch.cuda.manual_seed(args.seed)
  91. setup_logger(args)
  92. config = CONFIGS[args.dataset]()
  93. if args.overwrite_config:
  94. config.__dict__.update(json.loads(args.overwrite_config))
  95. dllogger.log(step='HPARAMS', data={**vars(args), **vars(config)}, verbosity=1)
  96. model = TemporalFusionTransformer(config).cuda()
  97. if args.ema_decay:
  98. model_ema = ModelEma(model, decay=args.ema_decay)
  99. print_once('Model params: {}'.format(sum(p.numel() for p in model.parameters())))
  100. criterion = QuantileLoss(config).cuda()
  101. optimizer = FusedAdam(model.parameters(), lr=args.lr)
  102. if args.use_amp:
  103. model, optimizer = amp.initialize(model, optimizer, opt_level="O2", loss_scale="dynamic")
  104. if args.distributed_world_size > 1:
  105. #model = DDP(model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True)
  106. model = DDP(model)
  107. train_loader, valid_loader, test_loader = load_dataset(args, config)
  108. global_step = 0
  109. perf_meter = PerformanceMeter(benchmark_mode=not args.disable_benchmark)
  110. for epoch in range(args.epochs):
  111. start = time.time()
  112. dllogger.log(step=global_step, data={'epoch': epoch}, verbosity=1)
  113. model.train()
  114. for local_step, batch in enumerate(train_loader):
  115. perf_meter.reset_current_lap()
  116. batch = {key: tensor.cuda() if tensor.numel() else None for key, tensor in batch.items()}
  117. predictions = model(batch)
  118. targets = batch['target'][:,config.encoder_length:,:]
  119. p_losses = criterion(predictions, targets)
  120. loss = p_losses.sum()
  121. if args.use_amp:
  122. with amp.scale_loss(loss, optimizer) as scaled_loss:
  123. scaled_loss.backward()
  124. else:
  125. loss.backward()
  126. if not args.grad_accumulation or (global_step+1) % args.grad_accumulation == 0:
  127. if args.clip_grad:
  128. torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)
  129. optimizer.step()
  130. optimizer.zero_grad()
  131. if args.ema_decay:
  132. model_ema.update(model)
  133. if args.distributed_world_size > 1:
  134. dist.all_reduce(p_losses)
  135. p_losses /= args.distributed_world_size
  136. loss = p_losses.sum()
  137. torch.cuda.synchronize()
  138. ips = perf_meter.update(args.batch_size * args.distributed_world_size,
  139. exclude_from_total=local_step in [0, len(train_loader)-1])
  140. log_dict = {'P10':p_losses[0].item(), 'P50':p_losses[1].item(), 'P90':p_losses[2].item(), 'loss': loss.item(), 'items/s':ips}
  141. dllogger.log(step=global_step, data=log_dict, verbosity=1)
  142. global_step += 1
  143. validate(args, config, model_ema if args.ema_decay else model, criterion, valid_loader, global_step)
  144. if validate.early_stop_c >= args.early_stopping:
  145. print_once('Early stopping')
  146. break
  147. ### TEST PHASE ###
  148. state_dict = torch.load(os.path.join(args.results, 'checkpoint.pt'), map_location='cpu')
  149. if isinstance(model, DDP):
  150. model.module.load_state_dict(state_dict['model'])
  151. else:
  152. model.load_state_dict(state_dict['model'])
  153. model.cuda().eval()
  154. tgt_scalers = pickle.load(open(os.path.join(args.data_path, 'tgt_scalers.bin'), 'rb'))
  155. cat_encodings = pickle.load(open(os.path.join(args.data_path,'cat_encodings.bin'), 'rb'))
  156. unscaled_predictions, unscaled_targets, _, _ = predict(args, config, model, test_loader, tgt_scalers, cat_encodings)
  157. losses = QuantileLoss(config)(unscaled_predictions, unscaled_targets)
  158. normalizer = unscaled_targets.abs().mean()
  159. quantiles = 2 * losses / normalizer
  160. if args.distributed_world_size > 1:
  161. quantiles = quantiles.cuda()
  162. dist.all_reduce(quantiles)
  163. quantiles /= args.distributed_world_size
  164. quantiles = {'test_p10': quantiles[0].item(), 'test_p50': quantiles[1].item(), 'test_p90': quantiles[2].item(), 'sum':sum(quantiles).item()}
  165. finish_log = {**quantiles, 'average_ips':perf_meter.avg, 'convergence_step':validate.conv_step}
  166. dllogger.log(step=(), data=finish_log, verbosity=1)
  167. def validate(args, config, model, criterion, dataloader, global_step):
  168. if not hasattr(validate, 'best_valid_loss'):
  169. validate.best_valid_loss = float('inf')
  170. if not hasattr(validate, 'early_stop_c'):
  171. validate.early_stop_c = 0
  172. model.eval()
  173. losses = []
  174. torch.cuda.synchronize()
  175. validation_start = time.time()
  176. for batch in dataloader:
  177. with torch.no_grad():
  178. batch = {key: tensor.cuda() if tensor.numel() else None for key, tensor in batch.items()}
  179. predictions = model(batch)
  180. targets = batch['target'][:,config.encoder_length:,:]
  181. p_losses = criterion(predictions, targets)
  182. bs = next(t for t in batch.values() if t is not None).shape[0]
  183. losses.append((p_losses, bs))
  184. torch.cuda.synchronize()
  185. validation_end = time.time()
  186. p_losses = sum([l[0]*l[1] for l in losses])/sum([l[1] for l in losses]) #takes into accunt that the last batch is not full
  187. if args.distributed_world_size > 1:
  188. dist.all_reduce(p_losses)
  189. p_losses = p_losses/args.distributed_world_size
  190. ips = len(dataloader.dataset) / (validation_end - validation_start)
  191. log_dict = {'P10':p_losses[0].item(), 'P50':p_losses[1].item(), 'P90':p_losses[2].item(), 'loss': p_losses.sum().item(), 'items/s':ips}
  192. if log_dict['loss'] < validate.best_valid_loss:
  193. validate.best_valid_loss = log_dict['loss']
  194. validate.early_stop_c = 0
  195. validate.conv_step = global_step
  196. if not dist.is_initialized() or dist.get_rank() == 0:
  197. state_dict = model.module.state_dict() if isinstance(model, (DDP, ModelEma)) else model.state_dict()
  198. ckpt = {'args':args, 'config':config, 'model':state_dict}
  199. torch.save(ckpt, os.path.join(args.results, 'checkpoint.pt'))
  200. if args.distributed_world_size > 1:
  201. dist.barrier()
  202. else:
  203. validate.early_stop_c += 1
  204. log_dict = {'val_'+k:v for k,v in log_dict.items()}
  205. dllogger.log(step=global_step, data=log_dict, verbosity=1)
  206. if __name__ == '__main__':
  207. parser = argparse.ArgumentParser()
  208. parser.add_argument('--data_path', type=str, required=True,
  209. help='Path to the dataset')
  210. parser.add_argument('--dataset', type=str, required=True, choices=CONFIGS.keys(),
  211. help='Dataset name')
  212. parser.add_argument('--epochs', type=int, default=25,
  213. help='Default number of training epochs')
  214. parser.add_argument('--sample_data', type=lambda x: int(float(x)), nargs=2, default=[-1, -1],
  215. help="""Subsample the dataset. Specify number of training and valid examples.
  216. Values can be provided in scientific notation. Floats will be truncated.""")
  217. parser.add_argument('--batch_size', type=int, default=64)
  218. parser.add_argument('--lr', type=float, default=1e-3)
  219. parser.add_argument('--seed', type=int, default=1)
  220. parser.add_argument('--use_amp', action='store_true', help='Enable automatic mixed precision')
  221. parser.add_argument('--clip_grad', type=float, default=0.0)
  222. parser.add_argument('--grad_accumulation', type=int, default=0)
  223. parser.add_argument('--early_stopping', type=int, default=1000,
  224. help='Stop training if validation loss does not improve for more than this number of epochs.')
  225. parser.add_argument('--results', type=str, default='/results',
  226. help='Directory in which results are stored')
  227. parser.add_argument('--log_file', type=str, default='dllogger.json',
  228. help='Name of dllogger output file')
  229. parser.add_argument('--overwrite_config', type=str, default='',
  230. help='JSON string used to overload config')
  231. parser.add_argument('--affinity', type=str,
  232. default='socket_unique_interleaved',
  233. choices=['socket', 'single', 'single_unique',
  234. 'socket_unique_interleaved',
  235. 'socket_unique_continuous',
  236. 'disabled'],
  237. help='type of CPU affinity')
  238. parser.add_argument("--ema_decay", type=float, default=0.0, help='Use exponential moving average')
  239. parser.add_argument("--disable_benchmark", action='store_true', help='Disable benchmarking mode')
  240. ARGS = parser.parse_args()
  241. main(ARGS)