|
|
@@ -14,7 +14,6 @@
|
|
|
|
|
|
import os
|
|
|
|
|
|
-import pyprof
|
|
|
import torch
|
|
|
from pytorch_lightning import Trainer, seed_everything
|
|
|
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
|
|
|
@@ -29,6 +28,7 @@ if __name__ == "__main__":
|
|
|
args = get_main_args()
|
|
|
|
|
|
if args.profile:
|
|
|
+ import pyprof
|
|
|
pyprof.init(enable_function_stack=True)
|
|
|
print("Profiling enabled")
|
|
|
|