# train_net.py
  1. # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
  2. # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
  3. r"""
  4. Basic training script for PyTorch
  5. """
  6. # Set up custom environment before nearly anything else is imported
  7. # NOTE: this should be the first import (no not reorder)
  8. from maskrcnn_benchmark.utils.env import setup_environment # noqa F401 isort:skip
  9. import argparse
  10. import os
  11. import logging
  12. import functools
  13. import torch
  14. from maskrcnn_benchmark.config import cfg
  15. from maskrcnn_benchmark.data import make_data_loader
  16. from maskrcnn_benchmark.solver import make_lr_scheduler
  17. from maskrcnn_benchmark.solver import make_optimizer
  18. from maskrcnn_benchmark.engine.inference import inference
  19. from maskrcnn_benchmark.engine.trainer import do_train
  20. from maskrcnn_benchmark.modeling.detector import build_detection_model
  21. from maskrcnn_benchmark.utils.checkpoint import DetectronCheckpointer
  22. from maskrcnn_benchmark.utils.collect_env import collect_env_info
  23. from maskrcnn_benchmark.utils.comm import synchronize, get_rank, is_main_process
  24. from maskrcnn_benchmark.utils.imports import import_file
  25. from maskrcnn_benchmark.utils.logger import setup_logger
  26. from maskrcnn_benchmark.utils.miscellaneous import mkdir
  27. from maskrcnn_benchmark.engine.tester import test
  28. from maskrcnn_benchmark.utils.logger import format_step
  29. #from dllogger import Logger, StdOutBackend, JSONStreamBackend, Verbosity
  30. #import dllogger as DLLogger
  31. import dllogger
  32. import torch.utils.tensorboard as tbx
  33. from maskrcnn_benchmark.utils.logger import format_step
  34. # See if we can use apex.DistributedDataParallel instead of the torch default,
  35. # and enable mixed-precision via apex.amp
  36. try:
  37. from apex.parallel import DistributedDataParallel as DDP
  38. use_apex_ddp = True
  39. except ImportError:
  40. print('Use APEX for better performance')
  41. use_apex_ddp = False
  42. def test_and_exchange_map(tester, model, distributed):
  43. results = tester(model=model, distributed=distributed)
  44. # main process only
  45. if is_main_process():
  46. # Note: one indirection due to possibility of multiple test datasets, we only care about the first
  47. # tester returns (parsed results, raw results). In our case, don't care about the latter
  48. map_results, raw_results = results[0]
  49. bbox_map = map_results.results["bbox"]['AP']
  50. segm_map = map_results.results["segm"]['AP']
  51. else:
  52. bbox_map = 0.
  53. segm_map = 0.
  54. if distributed:
  55. map_tensor = torch.tensor([bbox_map, segm_map], dtype=torch.float32, device=torch.device("cuda"))
  56. torch.distributed.broadcast(map_tensor, 0)
  57. bbox_map = map_tensor[0].item()
  58. segm_map = map_tensor[1].item()
  59. return bbox_map, segm_map
  60. def mlperf_test_early_exit(iteration, iters_per_epoch, tester, model, distributed, min_bbox_map, min_segm_map):
  61. if iteration > 0 and iteration % iters_per_epoch == 0:
  62. epoch = iteration // iters_per_epoch
  63. dllogger.log(step="PARAMETER", data={"eval_start": True})
  64. bbox_map, segm_map = test_and_exchange_map(tester, model, distributed)
  65. # necessary for correctness
  66. model.train()
  67. dllogger.log(step=(iteration, epoch, ), data={"BBOX_mAP": bbox_map, "MASK_mAP": segm_map})
  68. # terminating condition
  69. if bbox_map >= min_bbox_map and segm_map >= min_segm_map:
  70. dllogger.log(step="PARAMETER", data={"target_accuracy_reached": True})
  71. return True
  72. return False
  73. def train(cfg, local_rank, distributed, fp16, dllogger):
  74. model = build_detection_model(cfg)
  75. device = torch.device(cfg.MODEL.DEVICE)
  76. model.to(device)
  77. optimizer = make_optimizer(cfg, model)
  78. scheduler = make_lr_scheduler(cfg, optimizer)
  79. use_amp = False
  80. if fp16:
  81. use_amp = True
  82. else:
  83. use_amp = cfg.DTYPE == "float16"
  84. if distributed:
  85. if cfg.USE_TORCH_DDP or not use_apex_ddp:
  86. model = torch.nn.parallel.DistributedDataParallel(
  87. model, device_ids=[local_rank], output_device=local_rank,
  88. # this should be removed if we update BatchNorm stats
  89. broadcast_buffers=False,
  90. )
  91. else:
  92. model = DDP(model, delay_allreduce=True)
  93. arguments = {}
  94. arguments["iteration"] = 0
  95. output_dir = cfg.OUTPUT_DIR
  96. save_to_disk = get_rank() == 0
  97. checkpointer = DetectronCheckpointer(
  98. cfg, model, optimizer, scheduler, output_dir, save_to_disk
  99. )
  100. extra_checkpoint_data = checkpointer.load(cfg.MODEL.WEIGHT)
  101. arguments.update(extra_checkpoint_data)
  102. data_loader, iters_per_epoch = make_data_loader(
  103. cfg,
  104. is_train=True,
  105. is_distributed=distributed,
  106. start_iter=arguments["iteration"],
  107. )
  108. checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
  109. # set the callback function to evaluate and potentially
  110. # early exit each epoch
  111. if cfg.PER_EPOCH_EVAL:
  112. per_iter_callback_fn = functools.partial(
  113. mlperf_test_early_exit,
  114. iters_per_epoch=iters_per_epoch,
  115. tester=functools.partial(test, cfg=cfg, dllogger=dllogger),
  116. model=model,
  117. distributed=distributed,
  118. min_bbox_map=cfg.MIN_BBOX_MAP,
  119. min_segm_map=cfg.MIN_MASK_MAP)
  120. else:
  121. per_iter_callback_fn = None
  122. do_train(
  123. model,
  124. data_loader,
  125. optimizer,
  126. scheduler,
  127. checkpointer,
  128. device,
  129. checkpoint_period,
  130. arguments,
  131. use_amp,
  132. cfg,
  133. dllogger,
  134. per_iter_end_callback_fn=per_iter_callback_fn,
  135. nhwc=cfg.NHWC
  136. )
  137. return model, iters_per_epoch
  138. def test_model(cfg, model, distributed, iters_per_epoch, dllogger):
  139. if distributed:
  140. model = model.module
  141. torch.cuda.empty_cache() # TODO check if it helps
  142. iou_types = ("bbox",)
  143. if cfg.MODEL.MASK_ON:
  144. iou_types = iou_types + ("segm",)
  145. output_folders = [None] * len(cfg.DATASETS.TEST)
  146. dataset_names = cfg.DATASETS.TEST
  147. if cfg.OUTPUT_DIR:
  148. for idx, dataset_name in enumerate(dataset_names):
  149. output_folder = os.path.join(cfg.OUTPUT_DIR, "inference", dataset_name)
  150. mkdir(output_folder)
  151. output_folders[idx] = output_folder
  152. data_loaders_val = make_data_loader(cfg, is_train=False, is_distributed=distributed)
  153. results = []
  154. for output_folder, dataset_name, data_loader_val in zip(output_folders, dataset_names, data_loaders_val):
  155. result = inference(
  156. model,
  157. data_loader_val,
  158. dataset_name=dataset_name,
  159. iou_types=iou_types,
  160. box_only=cfg.MODEL.RPN_ONLY,
  161. device=cfg.MODEL.DEVICE,
  162. expected_results=cfg.TEST.EXPECTED_RESULTS,
  163. expected_results_sigma_tol=cfg.TEST.EXPECTED_RESULTS_SIGMA_TOL,
  164. output_folder=output_folder,
  165. dllogger=dllogger,
  166. )
  167. synchronize()
  168. results.append(result)
  169. if is_main_process():
  170. map_results, raw_results = results[0]
  171. bbox_map = map_results.results["bbox"]['AP']
  172. segm_map = map_results.results["segm"]['AP']
  173. dllogger.log(step=(cfg.SOLVER.MAX_ITER, cfg.SOLVER.MAX_ITER / iters_per_epoch,), data={"BBOX_mAP": bbox_map, "MASK_mAP": segm_map})
  174. dllogger.log(step=tuple(), data={"BBOX_mAP": bbox_map, "MASK_mAP": segm_map})
  175. def main():
  176. parser = argparse.ArgumentParser(description="PyTorch Object Detection Training")
  177. parser.add_argument(
  178. "--config-file",
  179. default="",
  180. metavar="FILE",
  181. help="path to config file",
  182. type=str,
  183. )
  184. parser.add_argument("--local_rank", type=int, default=os.getenv('LOCAL_RANK', 0))
  185. parser.add_argument("--max_steps", type=int, default=0, help="Override number of training steps in the config")
  186. parser.add_argument("--skip-test", dest="skip_test", help="Do not test the final model",
  187. action="store_true",)
  188. parser.add_argument("--fp16", help="Mixed precision training", action="store_true")
  189. parser.add_argument("--amp", help="Mixed precision training", action="store_true")
  190. parser.add_argument('--skip_checkpoint', default=False, action='store_true', help="Whether to save checkpoints")
  191. parser.add_argument("--json-summary", help="Out file for DLLogger", default="dllogger.out",
  192. type=str,
  193. )
  194. parser.add_argument(
  195. "opts",
  196. help="Modify config options using the command-line",
  197. default=None,
  198. nargs=argparse.REMAINDER,
  199. )
  200. args = parser.parse_args()
  201. args.fp16 = args.fp16 or args.amp
  202. num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
  203. args.distributed = num_gpus > 1
  204. if args.distributed:
  205. torch.cuda.set_device(args.local_rank)
  206. torch.distributed.init_process_group(
  207. backend="nccl", init_method="env://"
  208. )
  209. synchronize()
  210. cfg.merge_from_file(args.config_file)
  211. cfg.merge_from_list(args.opts)
  212. # Redundant option - Override config parameter with command line input
  213. if args.max_steps > 0:
  214. cfg.SOLVER.MAX_ITER = args.max_steps
  215. if args.skip_checkpoint:
  216. cfg.SAVE_CHECKPOINT = False
  217. cfg.freeze()
  218. output_dir = cfg.OUTPUT_DIR
  219. if output_dir:
  220. mkdir(output_dir)
  221. logger = setup_logger("maskrcnn_benchmark", output_dir, get_rank())
  222. if is_main_process():
  223. dllogger.init(backends=[dllogger.JSONStreamBackend(verbosity=dllogger.Verbosity.VERBOSE,
  224. filename=args.json_summary),
  225. dllogger.StdOutBackend(verbosity=dllogger.Verbosity.VERBOSE, step_format=format_step)])
  226. else:
  227. dllogger.init(backends=[])
  228. dllogger.metadata("BBOX_mAP", {"unit": None})
  229. dllogger.metadata("MASK_mAP", {"unit": None})
  230. dllogger.metadata("e2e_train_time", {"unit": "s"})
  231. dllogger.metadata("train_perf_fps", {"unit": "images/s"})
  232. dllogger.log(step="PARAMETER", data={"gpu_count":num_gpus})
  233. # dllogger.log(step="PARAMETER", data={"environment_info": collect_env_info()})
  234. dllogger.log(step="PARAMETER", data={"config_file": args.config_file})
  235. with open(args.config_file, "r") as cf:
  236. config_str = "\n" + cf.read()
  237. dllogger.log(step="PARAMETER", data={"config":cfg})
  238. if args.fp16:
  239. fp16 = True
  240. else:
  241. fp16 = False
  242. model, iters_per_epoch = train(cfg, args.local_rank, args.distributed, fp16, dllogger)
  243. if not args.skip_test:
  244. if not cfg.PER_EPOCH_EVAL:
  245. test_model(cfg, model, args.distributed, iters_per_epoch, dllogger)
  246. if __name__ == "__main__":
  247. main()
  248. dllogger.log(step=tuple(), data={})
  249. dllogger.flush()