|
@@ -129,7 +129,7 @@ def main(args):
|
|
|
train_loader, valid_loader, test_loader = load_dataset(args, config)
|
|
train_loader, valid_loader, test_loader = load_dataset(args, config)
|
|
|
|
|
|
|
|
global_step = 0
|
|
global_step = 0
|
|
|
- perf_meter = PerformanceMeter()
|
|
|
|
|
|
|
+ perf_meter = PerformanceMeter(benchmark_mode=not args.disable_benchmark)
|
|
|
|
|
|
|
|
for epoch in range(args.epochs):
|
|
for epoch in range(args.epochs):
|
|
|
start = time.time()
|
|
start = time.time()
|
|
@@ -209,6 +209,7 @@ def validate(args, config, model, criterion, dataloader, global_step):
|
|
|
model.eval()
|
|
model.eval()
|
|
|
|
|
|
|
|
losses = []
|
|
losses = []
|
|
|
|
|
+ torch.cuda.synchronize()
|
|
|
validation_start = time.time()
|
|
validation_start = time.time()
|
|
|
for batch in dataloader:
|
|
for batch in dataloader:
|
|
|
with torch.no_grad():
|
|
with torch.no_grad():
|
|
@@ -219,6 +220,7 @@ def validate(args, config, model, criterion, dataloader, global_step):
|
|
|
bs = next(t for t in batch.values() if t is not None).shape[0]
|
|
bs = next(t for t in batch.values() if t is not None).shape[0]
|
|
|
losses.append((p_losses, bs))
|
|
losses.append((p_losses, bs))
|
|
|
|
|
|
|
|
|
|
+ torch.cuda.synchronize()
|
|
|
validation_end = time.time()
|
|
validation_end = time.time()
|
|
|
|
|
|
|
|
p_losses = sum([l[0]*l[1] for l in losses])/sum([l[1] for l in losses]) #takes into accunt that the last batch is not full
|
|
p_losses = sum([l[0]*l[1] for l in losses])/sum([l[1] for l in losses]) #takes into accunt that the last batch is not full
|
|
@@ -280,6 +282,7 @@ if __name__ == '__main__':
|
|
|
'disabled'],
|
|
'disabled'],
|
|
|
help='type of CPU affinity')
|
|
help='type of CPU affinity')
|
|
|
parser.add_argument("--ema_decay", type=float, default=0.0, help='Use exponential moving average')
|
|
parser.add_argument("--ema_decay", type=float, default=0.0, help='Use exponential moving average')
|
|
|
|
|
+ parser.add_argument("--disable_benchmark", action='store_true', help='Disable benchmarking mode')
|
|
|
|
|
|
|
|
|
|
|
|
|
ARGS = parser.parse_args()
|
|
ARGS = parser.parse_args()
|