|
@@ -7,7 +7,7 @@ import time
|
|
|
import torch
|
|
import torch
|
|
|
import torch.distributed as dist
|
|
import torch.distributed as dist
|
|
|
|
|
|
|
|
-from maskrcnn_benchmark.utils.comm import get_world_size
|
|
|
|
|
|
|
+from maskrcnn_benchmark.utils.comm import get_world_size, synchronized_timestamp
|
|
|
from maskrcnn_benchmark.utils.metric_logger import MetricLogger
|
|
from maskrcnn_benchmark.utils.metric_logger import MetricLogger
|
|
|
|
|
|
|
|
def reduce_loss_dict(loss_dict):
|
|
def reduce_loss_dict(loss_dict):
|
|
@@ -90,8 +90,8 @@ def do_train(
|
|
|
prefetcher = Prefetcher(data_loader, device)
|
|
prefetcher = Prefetcher(data_loader, device)
|
|
|
start_iter = arguments["iteration"]
|
|
start_iter = arguments["iteration"]
|
|
|
model.train()
|
|
model.train()
|
|
|
- start_training_time = time.time()
|
|
|
|
|
- end = time.time()
|
|
|
|
|
|
|
+ start_training_time = synchronized_timestamp()
|
|
|
|
|
+ end = start_training_time
|
|
|
if use_amp:
|
|
if use_amp:
|
|
|
scaler = torch.cuda.amp.GradScaler(init_scale=8192.0)
|
|
scaler = torch.cuda.amp.GradScaler(init_scale=8192.0)
|
|
|
for iteration, (images, targets) in enumerate(prefetcher, start_iter):
|
|
for iteration, (images, targets) in enumerate(prefetcher, start_iter):
|
|
@@ -169,7 +169,7 @@ def do_train(
|
|
|
if early_exit:
|
|
if early_exit:
|
|
|
break
|
|
break
|
|
|
|
|
|
|
|
- total_training_time = time.time() - start_training_time
|
|
|
|
|
|
|
+ total_training_time = synchronized_timestamp() - start_training_time
|
|
|
total_time_str = str(datetime.timedelta(seconds=total_training_time))
|
|
total_time_str = str(datetime.timedelta(seconds=total_training_time))
|
|
|
dllogger.log(step=tuple(), data={"e2e_train_time": total_training_time,
|
|
dllogger.log(step=tuple(), data={"e2e_train_time": total_training_time,
|
|
|
"train_perf_fps": max_iter * cfg.SOLVER.IMS_PER_BATCH / total_training_time})
|
|
"train_perf_fps": max_iter * cfg.SOLVER.IMS_PER_BATCH / total_training_time})
|