|
|
@@ -35,7 +35,7 @@ if __name__ == "__main__":
|
|
|
print("Profiling enabled")
|
|
|
|
|
|
if args.affinity != "disabled":
|
|
|
- affinity = set_affinity(int(os.getenv("LOCAL_RANK", "0")), args.gpus, mode=args.affinity)
|
|
|
+ set_affinity(int(os.getenv("LOCAL_RANK", "0")), args.gpus, mode=args.affinity)
|
|
|
|
|
|
# Limit number of CPU threads
|
|
|
os.environ["OMP_NUM_THREADS"] = "1"
|
|
|
@@ -100,6 +100,7 @@ if __name__ == "__main__":
|
|
|
default_root_dir=args.results,
|
|
|
resume_from_checkpoint=ckpt_path,
|
|
|
accelerator="ddp" if args.gpus > 1 else None,
|
|
|
+ checkpoint_callback=args.save_ckpt,
|
|
|
limit_train_batches=1.0 if args.train_batches == 0 else args.train_batches,
|
|
|
limit_val_batches=1.0 if args.test_batches == 0 else args.test_batches,
|
|
|
limit_test_batches=1.0 if args.test_batches == 0 else args.test_batches,
|