| """ |
| Helpers for distributed training. |
| """ |
|
|
| import datetime |
| import io |
| import os |
| import socket |
|
|
| import blobfile as bf |
| from pdb import set_trace as st |
| |
| import torch as th |
| import torch.distributed as dist |
|
|
| |
| |
| GPUS_PER_NODE = 8 |
| SETUP_RETRY_COUNT = 3 |
|
|
|
|
def get_rank():
    """
    Rank of the current process, or 0 when torch.distributed is
    unavailable or no process group has been initialized.
    """
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank()
    return 0
|
|
|
|
def synchronize():
    """
    Barrier across all ranks; a no-op when torch.distributed is
    unavailable, uninitialized, or running single-process.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return
    if dist.get_world_size() > 1:
        dist.barrier()
|
|
|
|
def get_world_size():
    """
    Number of processes in the group, or 1 when torch.distributed is
    unavailable or no process group has been initialized.
    """
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    return 1
|
|
|
|
def setup_dist(args):
    """
    Set up a distributed process group from environment variables.

    Args:
        args: namespace with a ``local_rank`` attribute giving this
            process's GPU index on its node.

    No-op if a process group is already initialized. Expects the usual
    ``env://`` launcher variables (MASTER_ADDR, MASTER_PORT, RANK,
    WORLD_SIZE) to be set.
    """
    if dist.is_initialized():
        return

    if th.cuda.is_available():
        # Pin this process to its GPU *before* NCCL init so collectives
        # do not all land on device 0.
        th.cuda.set_device(args.local_rank)
        backend = "nccl"
    else:
        # Fall back to gloo so CPU-only runs still work.
        backend = "gloo"

    # Generous timeout (15 h) so slow preprocessing/checkpoint loading on
    # one rank does not abort the whole job.
    dist.init_process_group(
        backend=backend,
        init_method="env://",
        timeout=datetime.timedelta(seconds=54000),
    )
    print(f"{args.local_rank=} init complete")

    th.cuda.empty_cache()
|
|
def cleanup():
    """
    Tear down the distributed process group, if one exists.

    Guarded so it is safe to call unconditionally in exit paths:
    ``destroy_process_group()`` raises when no group was ever
    initialized.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()
|
|
def dev():
    """
    Pick the device to use for torch.distributed work.

    Multi-process CUDA runs map each rank onto one of the node's GPUs;
    a single-process CUDA run uses the default CUDA device; otherwise
    fall back to CPU.
    """
    if not th.cuda.is_available():
        return th.device("cpu")
    if get_world_size() <= 1:
        return th.device("cuda")
    return th.device(f"cuda:{get_rank() % GPUS_PER_NODE}")
|
|
|
|
| |
def load_state_dict(path, **kwargs):
    """
    Load a PyTorch checkpoint from ``path``.

    Extra keyword arguments (e.g. ``map_location``) are forwarded to
    ``th.load`` unchanged.
    """
    return th.load(path, **kwargs)
|
|
|
|
def sync_params(params):
    """
    Broadcast a sequence of tensors in place from rank 0 to all ranks.

    Broadcast failures are reported and skipped (best-effort) rather
    than aborting the whole synchronization.
    """
    for p in params:
        with th.no_grad():
            try:
                dist.broadcast(p, 0)
            except Exception as e:
                # Bug fix: the old handler did `print(k, e)` with `k`
                # undefined, so the handler itself raised NameError and
                # masked the real broadcast error.
                print(f"sync_params: broadcast failed for tensor of shape {tuple(p.shape)}: {e}")
| |
|
|
|
|
| def _find_free_port(): |
| try: |
| s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) |
| s.bind(("", 0)) |
| s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) |
| return s.getsockname()[1] |
| finally: |
| s.close() |
|
|
|
|
# Module-level state for multi-process statistics collection; configured
# by init_multiprocessing() below, consumed by code outside this chunk.
_num_moments = 3  # moments tracked per statistic (presumably count/sum/sum-of-squares — confirm at use site)
_reduce_dtype = th.float32  # dtype presumably used when reducing stats across processes — verify downstream
_counter_dtype = th.float64  # dtype presumably used for accumulation counters — verify downstream
_rank = 0  # rank of this process; overwritten by init_multiprocessing()
_sync_device = None  # device for inter-process communication; None disables multi-process collection
_sync_called = False  # once True, init_multiprocessing() refuses to reconfigure (see its assert)
_counters = dict()  # per-statistic counters (populated elsewhere)
_cumulative = dict()  # cumulative statistics (populated elsewhere)
|
|
def init_multiprocessing(rank, sync_device):
    r"""Configure `utils.torch_utils.training_stats` for collecting
    statistics across multiple processes.

    Must be invoked after ``torch.distributed.init_process_group()`` and
    before ``Collector.update()``. Skip the call entirely when
    multi-process collection is not needed.

    Args:
        rank: Rank of the current process.
        sync_device: PyTorch device used for inter-process
            communication, or None to disable multi-process collection.
            Typically ``torch.device('cuda', rank)``.
    """
    global _rank, _sync_device
    # Reconfiguring after a sync has already happened would corrupt the
    # collected statistics.
    assert not _sync_called
    _rank, _sync_device = rank, sync_device