# Source code for deepquantum.distributed

"""Distributed operations"""

from collections import Counter

import torch
import torch.distributed as dist

from .bitmath import all_bits_are_one, flip_bit, flip_bits, get_bit, get_bit_mask, log_base2
from .communication import comm_exchange_arrays, comm_get_world_size
from .qmath import block_sample, evolve_state, measure
from .state import DistributedQubitState


def local_gate(state: torch.Tensor, targets: list[int], matrix: torch.Tensor) -> torch.Tensor:
    """Apply a gate to a state vector locally.

    The 0-th qubit is the rightmost in a ket for the ``target``.
    """
    nqubit = log_base2(len(state))
    # Convert ket-ordered targets (0 = rightmost) into tensor wire indices.
    wires = [nqubit - 1 - t for t in targets]
    evolved = evolve_state(state.reshape([1] + [2] * nqubit), matrix, nqubit, wires, 2)
    state[:] = evolved.reshape(-1)
    return state
def local_many_ctrl_one_targ_gate(
    state: torch.Tensor, controls: list[int], target: int, matrix: torch.Tensor, derivative: bool = False
) -> torch.Tensor:
    """Apply a multi-control single-qubit gate to a state vector locally.

    See https://arxiv.org/abs/2311.01512 Alg.3
    """
    idx = torch.arange(len(state), device=state.device)
    # Keep only basis states whose control qubits are all |1>
    ctrl_ok = torch.ones_like(idx, dtype=torch.bool)
    for c in controls:
        ctrl_ok &= get_bit(idx, c) == 1
    # Basis states with all controls set and target = 0 ...
    idx0 = idx[ctrl_ok & (get_bit(idx, target) == 0)]
    # ... and their partner states with target = 1
    idx1 = flip_bit(idx0, target)
    a0 = state[idx0]
    a1 = state[idx1]
    if derivative:
        # For a derivative, amplitudes outside the controlled subspace vanish
        state.zero_()
    state[idx0] = matrix[0, 0] * a0 + matrix[0, 1] * a1
    state[idx1] = matrix[1, 0] * a0 + matrix[1, 1] * a1
    return state
def local_swap_gate(state: torch.Tensor, target1: int, target2: int) -> torch.Tensor:
    """Apply a SWAP gate to a state vector locally."""
    nqubit = log_base2(len(state))
    # A SWAP of two qubits is a transpose of the corresponding tensor axes
    axis1, axis2 = nqubit - target1 - 1, nqubit - target2 - 1
    swapped = state.reshape([2] * nqubit).transpose(axis1, axis2)
    state[:] = swapped.reshape(-1)
    return state
def dist_one_targ_gate(state: DistributedQubitState, target: int, matrix: torch.Tensor) -> DistributedQubitState:
    """Apply a single-qubit gate to a distributed state vector.

    See https://arxiv.org/abs/2311.01512 Alg.6
    """
    nqubit_local = state.log_num_amps_per_node
    if target < nqubit_local:
        # Target qubit lives in the locally stored amplitudes
        state.amps = local_gate(state.amps, [target], matrix)
        return state
    # Target is rank-encoded: pair with the node whose rank differs in that bit
    rank_target = target - nqubit_local
    pair_rank = flip_bit(state.rank, rank_target)
    comm_exchange_arrays(state.amps, state.buffer, pair_rank)
    bit = get_bit(state.rank, rank_target)
    state.amps = matrix[bit, bit] * state.amps + matrix[bit, 1 - bit] * state.buffer
    return state
def dist_many_ctrl_one_targ_gate(
    state: DistributedQubitState, controls: list[int], target: int, matrix: torch.Tensor, derivative: bool = False
) -> DistributedQubitState:
    """Apply a multi-control single-qubit gate or its derivative to a distributed state vector.

    See https://arxiv.org/abs/2311.01512 Alg.7
    """
    nqubit_local = state.log_num_amps_per_node
    # Split controls into rank-encoded (prefix) and locally stored (suffix) qubits
    prefix_ctrls = [q - nqubit_local for q in controls if q >= nqubit_local]
    suffix_ctrls = [q for q in controls if q < nqubit_local]
    if not all_bits_are_one(state.rank, prefix_ctrls):
        # This node's amplitudes lie outside the controlled subspace
        if derivative:
            state.amps.zero_()
        # NOTE(review): pair rank None — presumably keeps collective calls aligned
        # across nodes; confirm against `comm_exchange_arrays`
        comm_exchange_arrays(state.amps, state.buffer, None)
        return state
    if target < nqubit_local:
        state.amps = local_many_ctrl_one_targ_gate(state.amps, suffix_ctrls, target, matrix, derivative)
        comm_exchange_arrays(state.amps, state.buffer, None)
    elif suffix_ctrls:
        state = dist_ctrl_sub(state, suffix_ctrls, target, matrix, derivative)
    else:
        state = dist_one_targ_gate(state, target, matrix)
    return state
def dist_ctrl_sub(
    state: DistributedQubitState, controls: list[int], target: int, matrix: torch.Tensor, derivative: bool = False
) -> DistributedQubitState:
    """A subroutine of `dist_many_ctrl_one_targ_gate`.

    See https://arxiv.org/abs/2311.01512 Alg.8
    """
    rank_target = target - state.log_num_amps_per_node
    pair_rank = flip_bit(state.rank, rank_target)
    # Select the local amplitudes whose (local) control qubits are all |1>
    idx = torch.arange(state.num_amps_per_node, device=state.amps.device)
    keep = torch.ones_like(idx, dtype=torch.bool)
    for c in controls:
        keep &= get_bit(idx, c) == 1
    idx = idx[keep]
    # Exchange only the controlled slice with the paired node
    send = state.amps[idx].contiguous()
    recv = state.buffer[: len(send)]
    comm_exchange_arrays(send, recv, pair_rank)
    if derivative:
        state.amps.zero_()
    bit = get_bit(state.rank, rank_target)
    state.amps[idx] = matrix[bit, bit] * send + matrix[bit, 1 - bit] * recv
    return state
def dist_swap_gate(state: DistributedQubitState, qb1: int, qb2: int):
    """Apply a SWAP gate to a distributed state vector.

    See https://arxiv.org/abs/2311.01512 Alg.9
    """
    if qb1 > qb2:
        qb1, qb2 = qb2, qb1
    nqubit_local = state.log_num_amps_per_node
    if qb2 < nqubit_local:
        # Case 1: both qubits local — a pure in-node permutation
        state.amps = local_swap_gate(state.amps, qb1, qb2)
        # comm_exchange_arrays(state.amps, state.buffer, None)
    elif qb1 >= nqubit_local:
        # Case 2: both qubits rank-encoded — swap whole amplitude arrays
        # between nodes whose two rank bits differ
        qb1_rank = qb1 - nqubit_local
        qb2_rank = qb2 - nqubit_local
        if get_bit(state.rank, qb1_rank) != get_bit(state.rank, qb2_rank):
            pair_rank = flip_bits(state.rank, [qb1_rank, qb2_rank])
            comm_exchange_arrays(state.amps, state.buffer, pair_rank)
            state.amps = state.buffer
    else:
        # Case 3: one local, one rank-encoded — exchange only the half whose
        # local bit is the complement of this node's rank bit
        qb2_rank = qb2 - nqubit_local
        bit = 1 - get_bit(state.rank, qb2_rank)
        pair_rank = flip_bit(state.rank, qb2_rank)
        idx = torch.arange(state.num_amps_per_node, device=state.amps.device)
        idx = idx[get_bit(idx, qb1) == bit]
        send = state.amps[idx].contiguous()
        recv = state.buffer[: len(send)]
        comm_exchange_arrays(send, recv, pair_rank)
        state.amps[idx] = recv
    return state
def get_local_targets(targets: list[int], nqubit_local: int) -> list[int]:
    """Map global target qubits to available local indices for distributed gates."""
    mask = get_bit_mask(targets)

    def next_free(start: int) -> int:
        # Smallest index >= start that is not itself a target
        while get_bit(mask, start):
            start += 1
        return start

    free = next_free(0)
    targets_new = []
    for t in targets:
        if t < nqubit_local:
            # Already local: keep as-is
            targets_new.append(t)
        else:
            # Rank-encoded: substitute the next free local index
            targets_new.append(free)
            free = next_free(free + 1)
    return targets_new
def dist_many_targ_gate(
    state: DistributedQubitState, targets: list[int], matrix: torch.Tensor
) -> DistributedQubitState:
    """Apply a multi-qubit gate to a distributed state vector.

    See https://arxiv.org/abs/2311.01512 Alg.10
    """
    nqubit_local = state.log_num_amps_per_node
    assert len(targets) <= nqubit_local
    if max(targets) < nqubit_local:
        # Every target is locally stored; no communication needed
        state.amps = local_gate(state.amps, targets, matrix)
        # comm_exchange_arrays(state.amps, state.buffer, None)
        return state
    # Swap rank-encoded targets into free local positions, apply, then undo.
    # The substitute indices are never targets themselves, so the swaps are
    # pairwise disjoint and their order does not matter.
    targets_new = get_local_targets(targets, nqubit_local)
    moved = [(new, old) for new, old in zip(targets_new, targets) if new != old]
    for new, old in moved:
        dist_swap_gate(state, new, old)
    state.amps = local_gate(state.amps, targets_new, matrix)
    for new, old in moved:
        dist_swap_gate(state, new, old)
    return state
def measure_dist(
    state: DistributedQubitState,
    shots: int = 1024,
    with_prob: bool = False,
    wires: int | list[int] | None = None,
    block_size: int = 2**24,
) -> dict:
    """Measure a distributed state vector.

    Args:
        state: The distributed state to measure.
        shots: Number of measurement samples.
        with_prob: If True, each result maps to ``(count, probability)``.
        wires: Wire(s) to measure; ``None`` measures all qubits.
        block_size: Block size forwarded to ``block_sample``.

    Returns:
        On rank 0, a dict mapping bitstrings to counts (or ``(count, prob)``);
        an empty dict on all other ranks.
    """
    if state.world_size == 1:
        return measure(state.amps, shots, with_prob, wires, False, block_size)
    nqubit_local = state.log_num_amps_per_node
    nqubit_global = state.log_num_nodes
    if isinstance(wires, int):
        wires = [wires]
    # Fix: `wires=None` previously left `targets` undefined on the distributed
    # path (NameError); measuring all qubits is the documented default.
    if wires is None:
        wires = list(range(state.nqubit))
    num_bits = len(wires)
    targets = [state.nqubit - wire - 1 for wire in wires]
    pm_shape = list(range(nqubit_local))
    # Assume nqubit_global < nqubit_local
    if num_bits <= nqubit_local:
        # All targets can be moved to local qubits
        if max(targets) >= nqubit_local:
            targets_new = get_local_targets(targets, nqubit_local)
            for i in range(num_bits):
                if targets_new[i] != targets[i]:
                    dist_swap_gate(state, targets[i], targets_new[i])
            wires_local = sorted([nqubit_local - target - 1 for target in targets_new])
        else:
            wires_local = sorted([nqubit_local - target - 1 for target in targets])
        # Fix: compute probabilities AFTER the swaps above — the permute below
        # addresses post-swap wire positions, so pre-swap probs were stale.
        probs = (torch.abs(state.amps) ** 2).reshape([2] * nqubit_local)
        # Marginalize out the non-measured local wires ...
        for w in wires_local:
            pm_shape.remove(w)
        pm_shape = wires_local + pm_shape
        probs = probs.permute(pm_shape).reshape([2] * num_bits + [-1]).sum(-1).reshape(-1)
        # ... and the rank-encoded wires across nodes
        dist.all_reduce(probs, dist.ReduceOp.SUM)
        if state.rank == 0:
            samples = Counter(block_sample(probs, shots, block_size))
            results = {bin(key)[2:].zfill(num_bits): value for key, value in samples.items()}
            if with_prob:
                for k in results:
                    index = int(k, 2)
                    results[k] = results[k], probs[index].item()
            return results
        return {}
    else:
        # All targets are sorted, then the largest are moved to global qubits
        targets_sort = sorted(targets, reverse=True)
        wires_local = []
        for i, target in enumerate(targets_sort):
            if i < nqubit_global:
                target_new = state.nqubit - i - 1
                if target_new != target:
                    dist_swap_gate(state, target, target_new)
            else:
                wires_local.append(nqubit_local - target - 1)
        # Fix: compute probabilities AFTER the swaps above (see first branch).
        probs = (torch.abs(state.amps) ** 2).reshape([2] * nqubit_local)
        for w in wires_local:
            pm_shape.remove(w)
        pm_shape = wires_local + pm_shape
        probs = probs.permute(pm_shape).reshape([2] * len(wires_local) + [-1]).sum(-1).reshape(-1)
        # Sample which node (block) each shot falls into, identically on all ranks
        probs_rank = probs.new_empty(state.world_size)
        dist.all_gather_into_tensor(probs_rank, probs.sum().unsqueeze(0))
        blocks = torch.multinomial(probs_rank, shots, replacement=True)
        dist.broadcast(blocks, src=0)
        block_dict = Counter(blocks.cpu().numpy())
        # High (rank-encoded) measured bits form the key prefix for this node
        key_offset = state.rank << (num_bits - nqubit_global)
        if state.rank in block_dict:
            samples = Counter(block_sample(probs, block_dict[state.rank], block_size))
            results = {key + key_offset: value for key, value in samples.items()}
        else:
            results = {}
        if with_prob:
            for k in results:
                index = k - key_offset
                results[k] = results[k], probs[index].item()
        results_lst = [None] * state.world_size
        dist.all_gather_object(results_lst, results)
        if state.rank == 0:
            results = {bin(key)[2:].zfill(num_bits): value for r in results_lst for key, value in r.items()}
            return results
        return {}
def inner_product_dist(bra: DistributedQubitState, ket: DistributedQubitState) -> torch.Tensor:
    """Get the inner product of two distributed state vectors."""
    # Each node contributes the partial sum over its local amplitudes
    value = bra.amps.conj() @ ket.amps
    if comm_get_world_size() > 1:
        dist.all_reduce(value, dist.ReduceOp.SUM)
    return value