Source code for deepquantum.communication

"""Communication utilities"""

import os

import torch
import torch.distributed as dist


def setup_distributed(backend: str = 'nccl', port: str = '29500') -> tuple[int, int, int]:
    """Initialize ``torch.distributed``."""
    try:
        # These should be set by the launch script (e.g., torchrun)
        rank = int(os.environ['RANK'])
        world_size = int(os.environ['WORLD_SIZE'])
        local_rank = int(os.environ['LOCAL_RANK'])  # GPU id on the current node
    except KeyError:
        print('RANK, WORLD_SIZE, and LOCAL_RANK env vars must be set; falling back to single-process defaults.')
        # Fallback for single-process testing (optional)
        rank = 0
        world_size = 1
        local_rank = 0
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = port
    if backend == 'nccl':
        print(f'Initializing distributed setup: Rank {rank}/{world_size}, Local Rank (GPU): {local_rank}')
    elif backend == 'gloo':
        print(f'Initializing distributed setup: Rank {rank}/{world_size}, Local Rank (CPU): {local_rank}')
    # Initialize the process group
    dist.init_process_group(backend, world_size=world_size, rank=rank)
    if backend == 'nccl':
        # Pin the current process to a specific GPU
        torch.cuda.set_device(local_rank)
        print(f'Rank {rank} initialized, using GPU {local_rank}.')
    elif backend == 'gloo':
        print(f'Rank {rank} initialized.')
    return rank, world_size, local_rank


def cleanup_distributed() -> None:
    """Clean up the distributed environment."""
    dist.destroy_process_group()
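
A minimal lifecycle sketch (hypothetical usage, not part of this module), assuming the script is launched with torchrun so that RANK, WORLD_SIZE, and LOCAL_RANK are set:

# run with: torchrun --nproc_per_node=2 demo.py
from deepquantum.communication import setup_distributed, cleanup_distributed

rank, world_size, local_rank = setup_distributed(backend='gloo')  # 'gloo' keeps the demo CPU-only
# ... distributed work goes here ...
cleanup_distributed()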


def comm_get_rank() -> int:
    """Get the rank of the current process."""
    if not dist.is_initialized():
        return 0
    return dist.get_rank()


def comm_get_world_size() -> int:
    """Get the total number of processes."""
    if not dist.is_initialized():
        return 1
    return dist.get_world_size()
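
Because both helpers fall back to sensible defaults when the process group does not exist, they enable rank-guarded logging in code that may or may not run distributed. A sketch (not from the source):

from deepquantum.communication import comm_get_rank, comm_get_world_size

# Safe even before dist.init_process_group: falls back to rank 0 of 1
if comm_get_rank() == 0:
    print(f'Running with {comm_get_world_size()} process(es).')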


def comm_exchange_arrays(send_data: torch.Tensor, recv_data: torch.Tensor, pair_rank: int | None) -> None:
    """Exchange tensor data with a peer rank using collective communication.

    This performs point-to-point communication via ``dist.all_to_all_single``.
    Ranks without a partner still participate in the collective call, but
    without transferring data, by passing ``pair_rank=None``.

    Args:
        send_data: Data to be sent to ``pair_rank``. If ``pair_rank`` is ``None``,
            this can be an empty tensor with the correct dtype and device.
        recv_data: Pre-allocated buffer to store the received data. Must match
            ``send_data`` in shape and dtype if ``pair_rank`` is active.
            If ``pair_rank`` is ``None``, this can be an empty tensor.
        pair_rank: The target rank for the exchange, or ``None`` to remain
            quiescent during the collective call.
    """
    world_size = comm_get_world_size()
    rank = comm_get_rank()
    if not dist.is_initialized() or world_size <= 1:
        # Single process: a self-exchange degenerates to a local copy
        if pair_rank is not None and rank == pair_rank and send_data.numel() > 0 and recv_data.numel() > 0:
            recv_data.copy_(send_data)
        return
    is_valid = (pair_rank is not None) and (0 <= pair_rank < world_size)
    # Each rank exchanges data only along the split for its partner; all other splits are empty
    io_sizes = [0] * world_size
    if is_valid:
        assert send_data.shape == recv_data.shape, 'Send/Recv shape must match for active P2P'
        assert send_data.dtype == recv_data.dtype, 'Send/Recv dtype must match for active P2P'
        io_sizes[pair_rank] = len(send_data)  # split size along dim 0
    else:
        send_data = send_data.new_empty(0)
        recv_data = recv_data.new_empty(0)
    dist.all_to_all_single(output=recv_data, input=send_data, output_split_sizes=io_sizes,
                           input_split_sizes=io_sizes)
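
A hypothetical two-rank demo of the exchange (assuming the chosen backend supports ``all_to_all_single``; the script and variable names below are not from the source):

# run with: torchrun --nproc_per_node=2 exchange_demo.py
import torch
from deepquantum.communication import (
    cleanup_distributed,
    comm_exchange_arrays,
    setup_distributed,
)

rank, world_size, _ = setup_distributed(backend='gloo')
if world_size >= 2 and rank < 2:
    peer = 1 - rank  # ranks 0 and 1 pair up
    send = torch.full((4,), float(rank))
    recv = torch.empty(4)
else:
    peer = None  # remaining ranks stay quiescent but still join the collective
    send = torch.empty(0)
    recv = torch.empty(0)
comm_exchange_arrays(send, recv, peer)
if peer is not None:
    print(f'Rank {rank} received {recv.tolist()}')  # rank 0 prints 1.0s, rank 1 prints 0.0s
cleanup_distributed()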