Fix case with one training shard only
@@ -471,7 +471,9 @@ def main():
overflow_buf = None
if args.allreduce_post_accumulation:
overflow_buf = torch.cuda.IntTensor([0])
-
+
+ if len(files) == 1:
+ f_start_id = -1
for f_id in range(f_start_id + 1 , len(files)):