communication#

Communication utilities

Functions

cleanup_distributed()

Clean up the distributed environment.

comm_exchange_arrays(send_data, recv_data, ...)

Exchange tensor data with a peer rank using collective communication.

comm_get_rank()

Get the rank of the current process.

comm_get_world_size()

Get the total number of processes.

setup_distributed([backend, port])

Initialize torch.distributed.

cleanup_distributed() → None[source]#

Clean up the distributed environment.

comm_exchange_arrays(send_data: Tensor, recv_data: Tensor, pair_rank: int | None) → None[source]#

Exchange tensor data with a peer rank using collective communication.

This performs point-to-point communication via dist.all_to_all_single. Because every rank in the group must enter the collective, ranks with no data to transfer can still participate in the call by setting pair_rank to None.

Parameters:
  • send_data (Tensor) – Data to be sent to the pair_rank. If pair_rank is None, this can be an empty tensor with the correct dtype and device.

  • recv_data (Tensor) – Pre-allocated buffer to store received data. Must match send_data in shape and dtype if pair_rank is active. If pair_rank is None, this can be an empty tensor.

  • pair_rank (int | None) – The target rank for exchange, or None to remain quiescent during the collective call.

comm_get_rank() → int[source]#

Get the rank of the current process.

comm_get_world_size() → int[source]#

Get the total number of processes.
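Wrappers like these are commonly thin shims over dist.get_rank and dist.get_world_size that also tolerate single-process runs where no process group exists. The fallback behavior below is an assumption about this library's intent, not confirmed by the source:

```python
import torch.distributed as dist


def comm_get_rank() -> int:
    # Hypothetical sketch: fall back to rank 0 when the process
    # group has not been initialized (single-process runs).
    return dist.get_rank() if dist.is_initialized() else 0


def comm_get_world_size() -> int:
    # Same fallback: a single uninitialized process acts as a
    # world of size 1.
    return dist.get_world_size() if dist.is_initialized() else 1
```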

setup_distributed(backend: str = 'nccl', port: str = '29500') → tuple[int, int, int][source]#

Initialize torch.distributed.