Bladeren bron

[TXL/PyT] Update for PyT Transformer-XL:
* WAR for issues with logging in distributed setting
* improved CPU-GPU affinity
* minor optimizations for the model

Szymon Migacz 5 jaren geleden
bovenliggende
commit
09a17a0f33

+ 37 - 4
PyTorch/LanguageModeling/Transformer-XL/pytorch/eval.py

@@ -20,11 +20,16 @@ import os
 import pickle
 import sys
 import time
+import warnings
 
 import dllogger
 import numpy as np
 import torch
 import yaml
+try:
+    import pyprof
+except ModuleNotFoundError:
+    warnings.warn('PyProf is unavailable')
 
 import data_utils
 import utils
@@ -72,6 +77,15 @@ def parse_args():
     parser.add_argument('--split', type=str, default='all',
                         choices=['all', 'valid', 'test'],
                         help='which split to evaluate')
+    parser.add_argument('--affinity', type=str,
+                        default='single_unique',
+                        choices=['socket', 'single', 'single_unique',
+                                 'socket_unique_interleaved',
+                                 'socket_unique_continuous',
+                                 'disabled'],
+                        help='type of CPU affinity')
+    parser.add_argument('--profile', action='store_true',
+                        help='Enable profiling with DLProf')
     parser.add_argument('--type', type=str, default='pytorch',
                         choices=['pytorch', 'torchscript'],
                         help='type of runtime to use')
@@ -134,7 +148,12 @@ def parse_args():
     if args.manual:
         args.batch_size = 1
 
-    assert args.ext_len >= 0, 'extended context length must be non-negative'
+    if args.same_length and args.tgt_len > args.mem_len:
+        warnings.warn('--same_length is intended to be used with large '
+                      'mem_len relative to tgt_len')
+
+    if args.ext_len < 0:
+        raise RuntimeError('Extended context length must be non-negative')
     return args
 
 
@@ -182,7 +201,6 @@ def evaluate(eval_iter, model, meters, log_interval, max_size=None, repeat=1):
                 loss = loss.float().mean()
                 log_loss += loss.item()
                 if warm:
-                    # assert all([m.size(0) == model.mem_len for m in mems])
                     total_loss += seq_len * loss.item()
                     total_len += seq_len
 
@@ -251,7 +269,14 @@ def compile_model(model, device, args):
 
 def main():
     args = parse_args()
-    utils.gpu_affinity.set_affinity(args.local_rank)
+    if args.affinity != 'disabled':
+        nproc_per_node = torch.cuda.device_count()
+        affinity = utils.gpu_affinity.set_affinity(
+            args.local_rank,
+            nproc_per_node,
+            args.affinity
+        )
+        print(f'{args.local_rank}: thread affinity: {affinity}')
 
     if args.type == 'pytorch':
         from mem_transformer import MemTransformerLM
@@ -286,6 +311,12 @@ def main():
                                   )
     utils.exp_utils.setup_dllogger(enabled=True, filename=dllog_file)
 
+    if args.profile:
+        try:
+            pyprof.init(enable_function_stack=True)
+        except NameError:
+            warnings.warn('Called pyprof.init() but pyprof is not available')
+
     logging.info(args)
     dllogger.log(step='PARAMETER', data=vars(args))
 
@@ -423,7 +454,9 @@ def main():
     meters['eval_throughput'] = AverageMeter(warmup=warmup, keep=args.save_data)
     meters['eval_latency'] = AverageMeter(warmup=warmup, keep=args.save_data)
 
-    loss = evaluate(iter, model, meters, args.log_interval, args.max_size, args.repeat)
+    with torch.autograd.profiler.emit_nvtx(enabled=args.profile):
+        loss = evaluate(iter, model, meters, args.log_interval, args.max_size,
+                        args.repeat)
     perplexity = math.exp(loss)
     log_str = format_log(loss, args.split, args)
 

+ 21 - 32
PyTorch/LanguageModeling/Transformer-XL/pytorch/inference/mem_transformer_jit.py

@@ -124,7 +124,7 @@ class MultiHeadAttn(nn.Module):
         # [bsz x n_head x qlen x klen]
         attn_score = torch.einsum('ibnd,jbnd->bnij', (head_q, head_k))
         attn_score.mul_(self.scale)
-        if attn_mask is not None and attn_mask.any():
+        if attn_mask is not None:
             if attn_mask.dim() == 2:
                 attn_score.masked_fill_(attn_mask[None, None, :, :], -float('inf'))
             elif attn_mask.dim() == 3:
@@ -279,7 +279,7 @@ class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn):
         attn_score.mul_(self.scale)
 
         # compute attention probability
-        if attn_mask is not None and attn_mask.any():
+        if attn_mask is not None:
             if attn_mask.dim() == 2:
                 attn_score.masked_fill_(attn_mask[None, None, :, :], -float('inf'))
             elif attn_mask.dim() == 3:
@@ -370,7 +370,7 @@ class RelLearnableMultiHeadAttn(RelMultiHeadAttn):
         attn_score.mul_(self.scale)
 
         # compute attention probability
-        if attn_mask is not None and attn_mask.any():
+        if attn_mask is not None:
             if attn_mask.dim() == 2:
                 attn_score.masked_fill_(attn_mask[None, None, :, :], -float('inf'))
             elif attn_mask.dim() == 3:
@@ -521,8 +521,7 @@ class AdaptiveEmbedding(nn.Module):
             emb_flat = torch.zeros([inp_flat.size(0), self.d_proj],
                                    dtype=self.dtype, device=torch.device('cuda'))
 
-            i = 0
-            for emb_layer in self.emb_layers:
+            for i, emb_layer in enumerate(self.emb_layers):
                 l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
 
                 mask_i = (inp_flat >= l_idx) & (inp_flat < r_idx)
@@ -534,7 +533,6 @@ class AdaptiveEmbedding(nn.Module):
                     emb_i = F.linear(emb_i, self.emb_projs[i])
 
                     emb_flat.index_copy_(0, indices_i, emb_i)
-                i += 1
 
             embed = emb_flat.view(inp.size(0), inp.size(1), self.d_proj)
 
@@ -627,23 +625,15 @@ class MemTransformerLM(nn.Module):
         # default attention
         if self.attn_type == 0:
             self.pos_emb = PositionalEmbedding(self.d_model)
-            self.r_w_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
-            self.r_r_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
-
-    def reset_length(self, tgt_len, ext_len, mem_len):
-        self.tgt_len = tgt_len
-        self.mem_len = mem_len
-        self.ext_len = ext_len
+            self.r_w_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head).zero_())
+            self.r_r_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head).zero_())
 
     def init_mems(self):
-        mems = []
-        for i in range(self.n_layer+1):
-            empty = torch.empty(0, dtype=self.dtype, device=torch.device('cuda'))
-            mems.append(empty)
+        mems = torch.empty(self.n_layer, 0, dtype=self.dtype, device=torch.device('cuda'))
 
         return mems
 
-    def _update_mems(self, hids: List[torch.Tensor], mems: List[torch.Tensor],
+    def _update_mems(self, hids: List[torch.Tensor], mems: torch.Tensor,
                      qlen: int, mlen: int):
         assert len(hids) == len(mems), 'len(hids) != len(mems)'
 
@@ -652,16 +642,18 @@ class MemTransformerLM(nn.Module):
         # will be used as the extended context. Hence, we only cache
         # the tokens from `mlen + qlen - self.ext_len - self.mem_len`
         # to `mlen + qlen - self.ext_len`.
-        new_mems = []
-        end_idx = mlen + max(0, qlen - 0 - self.ext_len)
+        stacked = torch.stack(hids)
+        end_idx = mlen + max(0, qlen - self.ext_len)
         beg_idx = max(0, end_idx - self.mem_len)
-        for i in range(len(hids)):
-            cat = torch.cat([mems[i], hids[i]], dim=0)
-            new_mems.append(cat[beg_idx:end_idx].detach())
+        if mems.numel():
+            cat = torch.cat([mems, stacked], dim=1)
+        else:
+            cat = stacked
+        new_mems = cat[:, beg_idx:end_idx].detach()
 
         return new_mems
 
-    def _forward(self, dec_inp, mems: List[torch.Tensor]):
+    def _forward(self, dec_inp, mems: torch.Tensor):
         qlen, bsz = dec_inp.size()
 
         word_emb = self.word_emb(dec_inp)
@@ -671,7 +663,7 @@ class MemTransformerLM(nn.Module):
         all_ones = torch.ones((qlen, klen), device=torch.device('cuda'),
                               dtype=self.dtype)
         if self.same_length:
-            mask_len = klen - self.mem_len
+            mask_len = klen - self.mem_len - 1
             if mask_len > 0:
                 mask_shift_len = qlen - mask_len
             else:
@@ -681,7 +673,6 @@ class MemTransformerLM(nn.Module):
         else:
             dec_attn_mask = torch.triu(all_ones, diagonal=1+mlen).to(torch.bool)
 
-        hids = []
         pos_seq = torch.arange(klen-1, -1, -1.0, device=word_emb.device,
                                dtype=word_emb.dtype)
         if self.clamp_len > 0:
@@ -691,22 +682,20 @@ class MemTransformerLM(nn.Module):
         core_out = self.drop(word_emb)
         pos_emb = self.drop(pos_emb)
 
-        hids.append(core_out)
-        i = 0
-        for layer in self.layers:
+        hids = []
+        for i, layer in enumerate(self.layers):
+            hids.append(core_out)
             mems_i = None if mems is None else mems[i]
             core_out = layer(core_out, pos_emb, self.r_w_bias,
                              self.r_r_bias, dec_attn_mask=dec_attn_mask,
                              mems=mems_i)
-            hids.append(core_out)
-            i += 1
         core_out = self.drop(core_out)
 
         new_mems = self._update_mems(hids, mems, qlen, mlen)
 
         return core_out, new_mems
 
-    def forward(self, data, target, mems: Optional[List[torch.Tensor]]):
+    def forward(self, data, target, mems: Optional[torch.Tensor]):
         # nn.DataParallel does not allow size(0) tensors to be broadcasted.
         # So, have to initialize size(0) mems inside the model forward.
         # Moreover, have to return new_mems to allow nn.DataParallel to piece

+ 34 - 22
PyTorch/LanguageModeling/Transformer-XL/pytorch/mem_transformer.py

@@ -21,6 +21,11 @@ from utils.log_uniform_sampler import sample_logits
 from utils.proj_adaptive_softmax import ProjectedAdaptiveLogSoftmax
 
 
+@torch.jit.script
+def add_and_scale(tensor1, tensor2, alpha: float):
+    return alpha * (tensor1 + tensor2)
+
+
 class PositionalEmbedding(nn.Module):
     def __init__(self, demb):
         super(PositionalEmbedding, self).__init__()
@@ -122,7 +127,7 @@ class MultiHeadAttn(nn.Module):
         # [bsz x n_head x qlen x klen]
         attn_score = torch.einsum('ibnd,jbnd->bnij', (head_q, head_k))
         attn_score.mul_(self.scale)
-        if attn_mask is not None and attn_mask.any().item():
+        if attn_mask is not None:
             if attn_mask.dim() == 2:
                 attn_score.masked_fill_(attn_mask[None, None, :, :], -float('inf'))
             elif attn_mask.dim() == 3:
@@ -266,11 +271,10 @@ class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn):
         BD = self._rel_shift(BD)
 
         # [bsz x n_head x qlen x klen]
-        attn_score = AC + BD
-        attn_score.mul_(self.scale)
+        attn_score = add_and_scale(AC, BD, self.scale)
 
         # compute attention probability
-        if attn_mask is not None and attn_mask.any().item():
+        if attn_mask is not None:
             if attn_mask.dim() == 2:
                 attn_score.masked_fill_(attn_mask[None, None, :, :], -float('inf'))
             elif attn_mask.dim() == 3:
@@ -354,11 +358,10 @@ class RelLearnableMultiHeadAttn(RelMultiHeadAttn):
         BD = self._rel_shift(B_ + D_)
 
         # [bsz x qlen x klen x n_head]
-        attn_score = AC + BD
-        attn_score.mul_(self.scale)
+        attn_score = add_and_scale(AC, BD, self.scale)
 
         # compute attention probability
-        if attn_mask is not None and attn_mask.any().item():
+        if attn_mask is not None:
             if attn_mask.dim() == 2:
                 attn_score.masked_fill_(attn_mask[None, None, :, :], -float('inf'))
             elif attn_mask.dim() == 3:
@@ -627,6 +630,12 @@ class MemTransformerLM(nn.Module):
                     self.n_layer, self.max_klen, self.d_model).zero_())
 
     def reset_length(self, tgt_len, ext_len, mem_len):
+        if tgt_len < 1:
+            raise RuntimeError(f'tgt_len should be >= 1, but got {tgt_len}')
+        if ext_len < 0:
+            raise RuntimeError(f'ext_len should be >= 0, but got {ext_len}')
+        if mem_len < 0:
+            raise RuntimeError(f'mem_len should be >= 0, but got {mem_len}')
         self.tgt_len = tgt_len
         self.mem_len = mem_len
         self.ext_len = ext_len
@@ -634,7 +643,7 @@ class MemTransformerLM(nn.Module):
     def init_mems(self):
         if self.mem_len > 0:
             param = next(self.parameters())
-            mems = torch.empty(self.n_layer + 1, 0, dtype=param.dtype,
+            mems = torch.empty(self.n_layer, 0, dtype=param.dtype,
                                device=param.device)
             return mems
         else:
@@ -654,14 +663,21 @@ class MemTransformerLM(nn.Module):
         # the tokens from `mlen + qlen - self.ext_len - self.mem_len`
         # to `mlen + qlen - self.ext_len`.
         with torch.no_grad():
-            end_idx = mlen + max(0, qlen - 0 - self.ext_len)
-            beg_idx = max(0, end_idx - self.mem_len)
             stacked = torch.stack(hids)
-            if mems.numel():
-                cat = torch.cat([mems, stacked], dim=1)
+            if (
+                self.mem_len == self.tgt_len
+                and self.ext_len == 0
+                and stacked.size(1) == self.mem_len
+            ):
+                new_mems = stacked.detach()
             else:
-                cat = stacked
-            new_mems = cat[:, beg_idx:end_idx].detach()
+                end_idx = mlen + max(0, qlen - self.ext_len)
+                beg_idx = max(0, end_idx - self.mem_len)
+                if mems.numel():
+                    cat = torch.cat([mems, stacked], dim=1)
+                else:
+                    cat = stacked
+                new_mems = cat[:, beg_idx:end_idx].detach()
 
         return new_mems
 
@@ -697,18 +713,17 @@ class MemTransformerLM(nn.Module):
             core_out = self.drop(word_emb)
             pos_emb = self.drop(pos_emb)
 
-            hids.append(core_out.detach())
             for i, layer in enumerate(self.layers):
+                hids.append(core_out.detach())
                 mems_i = None if mems is None else mems[i]
                 core_out = layer(core_out, pos_emb, self.r_w_bias,
                                  self.r_r_bias, dec_attn_mask=dec_attn_mask,
                                  mems=mems_i)
-                hids.append(core_out.detach())
         # learnable
         elif self.attn_type == 1:
             core_out = self.drop(word_emb)
-            hids.append(core_out.detach())
             for i, layer in enumerate(self.layers):
+                hids.append(core_out.detach())
                 if self.clamp_len > 0:
                     r_emb = self.r_emb[i][-self.clamp_len:]
                     r_bias = self.r_bias[i][-self.clamp_len:]
@@ -718,7 +733,6 @@ class MemTransformerLM(nn.Module):
                 mems_i = None if mems is None else mems[i]
                 core_out = layer(core_out, r_emb, self.r_w_bias[i],
                                  r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i)
-                hids.append(core_out.detach())
         # absolute
         elif self.attn_type == 2:
             pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.device,
@@ -729,19 +743,18 @@ class MemTransformerLM(nn.Module):
 
             core_out = self.drop(word_emb + pos_emb[-qlen:])
 
-            hids.append(core_out.detach())
             for i, layer in enumerate(self.layers):
+                hids.append(core_out.detach())
                 mems_i = None if mems is None else mems[i]
                 if mems_i is not None and len(mems_i) and i == 0:
                     mems_i += pos_emb[:mlen]
                 core_out = layer(core_out, dec_attn_mask=dec_attn_mask,
                                  mems=mems_i)
-                hids.append(core_out.detach())
         elif self.attn_type == 3:
             core_out = self.drop(word_emb)
 
-            hids.append(core_out.detach())
             for i, layer in enumerate(self.layers):
+                hids.append(core_out.detach())
                 mems_i = None if mems is None else mems[i]
                 if mems_i is not None and len(mems_i) and mlen > 0:
                     cur_emb = self.r_emb[i][:-qlen]
@@ -756,7 +769,6 @@ class MemTransformerLM(nn.Module):
 
                 core_out = layer(core_out, dec_attn_mask=dec_attn_mask,
                                  mems=mems_i)
-                hids.append(core_out.detach())
 
         core_out = self.drop(core_out)
 

+ 1 - 1
PyTorch/LanguageModeling/Transformer-XL/pytorch/scripts/inference_benchmark.sh

@@ -33,7 +33,7 @@ for (( i = 0; i < ${#TYPES[@]}; i++ )); do
          DIR="LM-TFM/inference/${GPU}_${BATCH_SIZES[j]}_${MATHS_FULL[k]}_${TYPES[i]}"
          mkdir -p "${DIR}"
 
-         taskset -c 0 bash run_wt103_"${MODEL}".sh eval 1 \
+         bash run_wt103_"${MODEL}".sh eval 1 \
             --work_dir "${DIR}" \
             --model "${CHECKPOINT}" \
             --type "${TYPES[i]}" \

+ 68 - 26
PyTorch/LanguageModeling/Transformer-XL/pytorch/train.py

@@ -35,6 +35,10 @@ try:
     from apex import amp
 except ModuleNotFoundError:
     warnings.warn('APEX AMP is unavailable')
+try:
+    import pyprof
+except ModuleNotFoundError:
+    warnings.warn('PyProf is unavailable')
 
 from torch.nn.parallel import DistributedDataParallel
 
@@ -111,6 +115,15 @@ def parse_args():
                          help='Optimization level for apex amp')
     general.add_argument('--amp', choices=['apex', 'pytorch'], default='apex',
                          help='Implementation of automatic mixed precision')
+    general.add_argument('--affinity', type=str,
+                         default='socket_unique_interleaved',
+                         choices=['socket', 'single', 'single_unique',
+                                  'socket_unique_interleaved',
+                                  'socket_unique_continuous',
+                                  'disabled'],
+                         help='type of CPU affinity')
+    general.add_argument('--profile', action='store_true',
+                         help='Enable profiling with DLProf')
 
     dataset = parser.add_argument_group('dataset setup')
     dataset.add_argument('--data', type=str, default='../data/wikitext-103',
@@ -256,6 +269,19 @@ def parse_args():
     if args.ext_len < 0:
         raise RuntimeError('Extended context length must be non-negative')
 
+    if args.mem_len == 0:
+        if args.eval_tgt_len > args.ext_len + args.tgt_len:
+            raise RuntimeError('eval_tgt_len should be <= tgt_len + ext_len; '
+                               f'eval_tgt_len: {args.eval_tgt_len}, '
+                               f'tgt_len: {args.tgt_len}, '
+                               f'ext_len: {args.ext_len}')
+    else:
+        if args.eval_tgt_len > args.mem_len + args.tgt_len:
+            raise RuntimeError('eval_tgt_len should be <= tgt_len + mem_len; '
+                               f'eval_tgt_len: {args.eval_tgt_len}, '
+                               f'tgt_len: {args.tgt_len}, '
+                               f'mem_len: {args.mem_len}')
+
     if args.batch_size % args.batch_chunk != 0:
         raise RuntimeError('Batch size needs to be divisible by batch chunk')
 
@@ -421,7 +447,7 @@ def evaluate(eval_iter, model, args):
             loss, mems = model(data, target, mems)
             loss = loss.float().mean()
             if warm:
-                assert (mems is None) or mems.size(1) == model.mem_len
+                # assert (mems is None) or mems.size(1) == model.mem_len
                 total_loss += seq_len * loss.item()
                 total_len += seq_len
 
@@ -659,7 +685,14 @@ def train(tr_iter, va_iter, model, para_model, model_config, optimizer,
 
 def main():
     args = parse_args()
-    utils.gpu_affinity.set_affinity(args.local_rank)
+    if args.affinity != 'disabled':
+        nproc_per_node = torch.cuda.device_count()
+        affinity = utils.gpu_affinity.set_affinity(
+            args.local_rank,
+            nproc_per_node,
+            args.affinity
+        )
+        print(f'{args.local_rank}: thread affinity: {affinity}')
 
     # Initialize device and distributed backend
     torch.cuda.set_device(args.local_rank)
@@ -703,6 +736,12 @@ def main():
         logging.info(f'--local_batch_size was set, adjusting global batch size'
                      f' to {args.batch_size} (local_batch_size * world_size)')
 
+    if args.profile:
+        try:
+            pyprof.init(enable_function_stack=True)
+        except NameError:
+            warnings.warn('Called pyprof.init() but pyprof is not available')
+
     logging.info(args)
     dllogger.log(step='PARAMETER', data=vars(args))
 
@@ -956,28 +995,30 @@ def main():
     # Loop over epochs.
     # At any point you can hit Ctrl + C to break out of training early.
     start_time = time.time()
-    with TimeoutHandler() as timeout_handler:
-        try:
-            for epoch in itertools.count(start=start_epoch):
-                if args.roll:
-                    tr_iter.roll(seed=args.seed + epoch)
-                train_step, best_val_loss = train(
-                    tr_iter, va_iter, model, para_model, model_config,
-                    optimizer, optimizer_sparse, scheduler, scheduler_sparse,
-                    scaler, vocab, epoch, last_batch, last_iter, train_step,
-                    best_val_loss, meters, timeout_handler, device, args
-                    )
-
-                last_batch = 0
-                last_iter = 0
-
-                if train_step == args.max_step:
-                    logging.info('-' * 100)
-                    logging.info('End of training')
-                    break
-        except KeyboardInterrupt:
-            logging.info('-' * 100)
-            logging.info('Exiting from training early')
+    with torch.autograd.profiler.emit_nvtx(enabled=args.profile):
+        with TimeoutHandler() as timeout_handler:
+            try:
+                for epoch in itertools.count(start=start_epoch):
+                    if args.roll:
+                        tr_iter.roll(seed=args.seed + epoch)
+                    train_step, best_val_loss = train(
+                        tr_iter, va_iter, model, para_model, model_config,
+                        optimizer, optimizer_sparse, scheduler,
+                        scheduler_sparse, scaler, vocab, epoch, last_batch,
+                        last_iter, train_step, best_val_loss, meters,
+                        timeout_handler, device, args
+                        )
+
+                    last_batch = 0
+                    last_iter = 0
+
+                    if train_step == args.max_step:
+                        logging.info('-' * 100)
+                        logging.info('End of training')
+                        break
+            except KeyboardInterrupt:
+                logging.info('-' * 100)
+                logging.info('Exiting from training early')
     elapsed = time.time() - start_time
 
     ###########################################################################
@@ -992,8 +1033,9 @@ def main():
 
         # Run on test data.
         test_start_time = time.time()
-        test_loss = evaluate(te_iter, model, args)
-        test_loss = utils.distributed.all_reduce_item(test_loss, 'mean')
+        with torch.autograd.profiler.emit_nvtx(enabled=args.profile):
+            test_loss = evaluate(te_iter, model, args)
+            test_loss = utils.distributed.all_reduce_item(test_loss, 'mean')
         test_elapsed = time.time() - test_start_time
 
         logging.info('=' * 100)

+ 4 - 0
PyTorch/LanguageModeling/Transformer-XL/pytorch/utils/exp_utils.py

@@ -155,6 +155,10 @@ def setup_logging(log_all_ranks=True, filename=os.devnull, filemode='w'):
         if rank != 0:
             filename = os.devnull
 
+    for handler in logging.root.handlers[:]:
+        logging.root.removeHandler(handler)
+        handler.close()
+
     logging.basicConfig(level=logging.DEBUG,
                         format=logging_format,
                         datefmt="%Y-%m-%d %H:%M:%S",

+ 101 - 8
PyTorch/LanguageModeling/Transformer-XL/pytorch/utils/gpu_affinity.py

@@ -1,5 +1,8 @@
+import collections
 import math
 import os
+import pathlib
+import re
 
 import pynvml
 
@@ -35,15 +38,105 @@ class device:
         affinity_list = [int(x) for x in affinity_string]
         affinity_list.reverse()  # so core 0 is in 0th element of list
 
-        return [i for i, e in enumerate(affinity_list) if e != 0]
+        ret = [i for i, e in enumerate(affinity_list) if e != 0]
+        return ret
 
 
-def set_affinity(gpu_id=None):
-    if gpu_id is None:
-        gpu_id = int(os.getenv('LOCAL_RANK', 0))
-
+def set_socket_affinity(gpu_id):
     dev = device(gpu_id)
-    os.sched_setaffinity(0, dev.getCpuAffinity())
+    affinity = dev.getCpuAffinity()
+    os.sched_setaffinity(0, affinity)
+
 
-    # list of ints representing the logical cores this process is now affinitied with
-    return os.sched_getaffinity(0)
+def set_single_affinity(gpu_id):
+    dev = device(gpu_id)
+    affinity = dev.getCpuAffinity()
+    os.sched_setaffinity(0, affinity[:1])
+
+
+def set_single_unique_affinity(gpu_id, nproc_per_node):
+    devices = [device(i) for i in range(nproc_per_node)]
+    socket_affinities = [dev.getCpuAffinity() for dev in devices]
+
+    siblings_list = get_thread_siblings_list()
+    siblings_dict = dict(siblings_list)
+
+    # remove siblings
+    for idx, socket_affinity in enumerate(socket_affinities):
+        socket_affinities[idx] = list(set(socket_affinity) - set(siblings_dict.values()))
+
+    affinities = []
+    assigned = []
+
+    for socket_affinity in socket_affinities:
+        for core in socket_affinity:
+            if core not in assigned:
+                affinities.append([core])
+                assigned.append(core)
+                break
+    os.sched_setaffinity(0, affinities[gpu_id])
+
+
+def set_socket_unique_affinity(gpu_id, nproc_per_node, mode):
+    device_ids = [device(i) for i in range(nproc_per_node)]
+    socket_affinities = [dev.getCpuAffinity() for dev in device_ids]
+
+    siblings_list = get_thread_siblings_list()
+    siblings_dict = dict(siblings_list)
+
+    # remove siblings
+    for idx, socket_affinity in enumerate(socket_affinities):
+        socket_affinities[idx] = list(set(socket_affinity) - set(siblings_dict.values()))
+
+    socket_affinities_to_device_ids = collections.defaultdict(list)
+
+    for idx, socket_affinity in enumerate(socket_affinities):
+        socket_affinities_to_device_ids[tuple(socket_affinity)].append(idx)
+
+    for socket_affinity, device_ids in socket_affinities_to_device_ids.items():
+        devices_per_group = len(device_ids)
+        cores_per_device = len(socket_affinity) // devices_per_group
+        for group_id, device_id in enumerate(device_ids):
+            if device_id == gpu_id:
+                if mode == 'interleaved':
+                    affinity = list(socket_affinity[group_id::devices_per_group])
+                elif mode == 'continuous':
+                    affinity = list(socket_affinity[group_id*cores_per_device:(group_id+1)*cores_per_device])
+                else:
+                    raise RuntimeError('Unknown set_socket_unique_affinity mode')
+
+                # reintroduce siblings
+                affinity += [siblings_dict[aff] for aff in affinity if aff in siblings_dict]
+                os.sched_setaffinity(0, affinity)
+
+
+def get_thread_siblings_list():
+    path = '/sys/devices/system/cpu/cpu*/topology/thread_siblings_list'
+    thread_siblings_list = []
+    pattern = re.compile(r'(\d+)\D(\d+)')
+    for fname in pathlib.Path(path[0]).glob(path[1:]):
+        with open(fname) as f:
+            content = f.read().strip()
+            res = pattern.findall(content)
+            if res:
+                pair = tuple(map(int, res[0]))
+                thread_siblings_list.append(pair)
+    return thread_siblings_list
+
+
+def set_affinity(gpu_id, nproc_per_node, mode='socket'):
+    if mode == 'socket':
+        set_socket_affinity(gpu_id)
+    elif mode == 'single':
+        set_single_affinity(gpu_id)
+    elif mode == 'single_unique':
+        set_single_unique_affinity(gpu_id, nproc_per_node)
+    elif mode == 'socket_unique_interleaved':
+        set_socket_unique_affinity(gpu_id, nproc_per_node, 'interleaved')
+    elif mode == 'socket_unique_continuous':
+        set_socket_unique_affinity(gpu_id, nproc_per_node, 'continuous')
+    else:
+        raise RuntimeError('Unknown affinity mode')
+
+    affinity = os.sched_getaffinity(0)
+    return affinity