Преглед изворни кода

[resnet/mxnet] Apply horovod patch for hvd init

Michal Marcinkiewicz пре 2 година
родитељ
комит
810bcf375e
2 измењених фајлова са 6 додато и 6 уклоњено
  1. 1 1
      MxNet/Classification/RN50v1.5/dali.py
  2. 5 5
      MxNet/Classification/RN50v1.5/fit.py

+ 1 - 1
MxNet/Classification/RN50v1.5/dali.py

@@ -31,7 +31,7 @@ def add_dali_args(parser):
     group.add_argument('--dali-validation-threads', type=int, default=10, help="number of threads" +\
                        "per GPU for DALI for validation")
     group.add_argument('--dali-prefetch-queue', type=int, default=5, help="DALI prefetch queue depth")
-    group.add_argument('--dali-nvjpeg-memory-padding', type=int, default=256, help="Memory padding value for nvJPEG (in MB)")
+    group.add_argument('--dali-nvjpeg-memory-padding', type=int, default=64, help="Memory padding value for nvJPEG (in MB)")
     group.add_argument('--dali-fuse-decoder', type=int, default=1, help="0 or 1 whether to fuse decoder or not")
 
     group.add_argument('--dali-nvjpeg-width-hint', type=int, default=5980, help="Width hint value for nvJPEG (in pixels)")

+ 5 - 5
MxNet/Classification/RN50v1.5/fit.py

@@ -483,11 +483,6 @@ def fit(args, model, data_loader):
     # select gpu for horovod process
     if 'horovod' in args.kv_store:
         args.gpus = [args.gpus[hvd.local_rank()]]
-        ctx = mx.gpu(hvd.local_rank())
-
-        tensor1 = mx.nd.zeros(shape=(1,), dtype='float32', ctx=ctx)
-        tensor2 = mx.nd.zeros(shape=(1,), dtype='float32', ctx=ctx)
-        tensor1, tensor2 = hvd.grouped_allreduce([tensor1,tensor2])
 
     if args.amp:
         amp.init()
@@ -579,6 +574,11 @@ def fit(args, model, data_loader):
         params = model.collect_params()
         if params is not None:
             hvd.broadcast_parameters(params, root_rank=0)
+        ctx = mx.gpu(hvd.local_rank())
+        tensor1 = mx.nd.zeros(shape=(1,), dtype='float32', ctx=ctx)
+        tensor2 = mx.nd.zeros(shape=(1,), dtype='float32', ctx=ctx)
+        tensor1, tensor2 = hvd.grouped_allreduce([tensor1,tensor2])
+        
     global_metrics = CompositeMeter()
     if args.mode in ['train_val', 'train']:
         global_metrics.register_metric('train.loss', MinMeter())