|
|
@@ -116,18 +116,23 @@ if __name__ == '__main__':
|
|
|
torch.set_float32_matmul_precision('high')
|
|
|
|
|
|
test_dataloader = datamodule.test_dataloader() if not args.benchmark else datamodule.train_dataloader()
|
|
|
- evaluate(model,
|
|
|
- test_dataloader,
|
|
|
- callbacks,
|
|
|
- args)
|
|
|
+ if not args.benchmark:
|
|
|
+ evaluate(model,
|
|
|
+ test_dataloader,
|
|
|
+ callbacks,
|
|
|
+ args)
|
|
|
|
|
|
- for callback in callbacks:
|
|
|
- callback.on_validation_end()
|
|
|
+ for callback in callbacks:
|
|
|
+ callback.on_validation_end()
|
|
|
|
|
|
- if args.benchmark:
|
|
|
+ else:
|
|
|
world_size = dist.get_world_size() if dist.is_initialized() else 1
|
|
|
- callbacks = [PerformanceCallback(logger, args.batch_size * world_size, warmup_epochs=1, mode='inference')]
|
|
|
- for _ in range(6):
|
|
|
+ callbacks = [PerformanceCallback(
|
|
|
+ logger, args.batch_size * world_size,
|
|
|
+ warmup_epochs=1 if args.epochs > 1 else 0,
|
|
|
+ mode='inference'
|
|
|
+ )]
|
|
|
+ for _ in range(args.epochs):
|
|
|
evaluate(model,
|
|
|
test_dataloader,
|
|
|
callbacks,
|