# Source code for deepquantum.photonic.qmath

"""Common functions"""

import itertools
import warnings
from collections import Counter
from collections.abc import Generator

import matplotlib.pyplot as plt
import torch
from matplotlib import cm
from torch import vmap
from torch.distributions.multivariate_normal import MultivariateNormal

import deepquantum.photonic as dqp

from ..qmath import block_sample, decimal_to_list, is_unitary, list_to_decimal, partial_trace
from .utils import mem_to_chunksize


[docs] def dirac_rep(state: torch.Tensor, den_mat: bool = False, topk: int = 5) -> dict: """Convert the batched Fock state tensors to the dictionary of Dirac representation.""" dirac_dict = {} for i in range(state.shape[0]): # consider batch state_i = state[i] abs_state = abs(state_i).flatten() top_vals, top_indices = torch.topk(abs_state, k=min(len(abs_state), topk)) coords = torch.stack(torch.unravel_index(top_indices, state_i.shape), dim=1) dirac_lst = [] for val, idx_coords in zip(top_vals, coords, strict=True): if val <= 1e-5: continue idx = idx_coords.tolist() use_comma = any(x > 9 for x in idx) coeff = state_i[tuple(idx)].item() if den_mat: idx1 = idx[: len(idx) // 2] idx2 = idx[len(idx) // 2 :] state_b1 = ','.join(map(str, idx1)) if use_comma else ''.join(map(str, idx1)) state_b2 = ','.join(map(str, idx2)) if use_comma else ''.join(map(str, idx2)) state_str = f'{coeff:+6.3f}|{state_b1}><{state_b2}|' else: state_b = ','.join(map(str, idx)) if use_comma else ''.join(map(str, idx)) state_str = f'{coeff:+6.3f}|{state_b}>' dirac_lst.append(state_str) dirac = ' '.join(dirac_lst) dirac_dict[f'state_{i}'] = dirac[1:] if dirac[0] == '+' else dirac return dirac_dict
def sort_dict_fock_basis(state_dict: dict, idx: int = 0) -> dict:
    """Sort the dictionary of Fock basis states in the descending order of probabilities.

    Args:
        state_dict: The dictionary whose values are indexable; the entry ``value[idx]`` is used for sorting.
        idx: The position inside each value whose absolute value is the sort key. Default: 0

    Returns:
        A new dictionary with the same items, ordered by decreasing ``abs(value[idx])``.
    """

    def magnitude(item: tuple) -> float:
        return abs(item[1][idx])

    return dict(sorted(state_dict.items(), key=magnitude, reverse=True))
def sub_matrix(u: torch.Tensor, input_state: torch.Tensor, output_state: torch.Tensor) -> torch.Tensor:
    """Get the submatrix for calculating the transfer amplitude and transfer probabilities.

    Row ``i`` of ``u`` is repeated ``output_state[i]`` times and column ``j`` is repeated
    ``input_state[j]`` times.

    Args:
        u: The unitary matrix.
        input_state: The input state, used as the repeat counts of the columns.
        output_state: The output state, used as the repeat counts of the rows.
    """
    with warnings.catch_warnings():
        warnings.filterwarnings('ignore')  # warnings are silenced only inside this context
        rows_repeated = u.repeat_interleave(output_state, dim=0)
        return rows_repeated.repeat_interleave(input_state, dim=-1)
def permanent(mat: torch.Tensor) -> torch.Tensor:
    """Calculate the permanent.

    Sizes up to 3 are evaluated with explicit formulas; larger matrices are delegated to
    ``permanent_ryser``. An empty 0 x 0 matrix gives 1, any other empty matrix gives 0, and a
    zero-dimensional tensor is returned unchanged.
    """
    dims = mat.shape
    if mat.numel() == 0:
        fill = 1 if dims[0] == dims[1] == 0 else 0
        return torch.tensor(fill, dtype=mat.dtype, device=mat.device)
    if mat.dim() == 0:
        return mat
    n = dims[0]
    if n == 1:
        return mat[0, 0]
    if n == 2:
        return mat[0, 0] * mat[1, 1] + mat[0, 1] * mat[1, 0]
    if n == 3:
        # All six permutations of the columns
        return (
            mat[0, 2] * mat[1, 1] * mat[2, 0]
            + mat[0, 1] * mat[1, 2] * mat[2, 0]
            + mat[0, 2] * mat[1, 0] * mat[2, 1]
            + mat[0, 0] * mat[1, 2] * mat[2, 1]
            + mat[0, 1] * mat[1, 0] * mat[2, 2]
            + mat[0, 0] * mat[1, 1] * mat[2, 2]
        )
    return permanent_ryser(mat)
def create_subset(num_coincidence: int) -> Generator[torch.Tensor, None, None]:
    r"""Create all non-empty subsets of the index set :math:`\{0,1,...,n-1\}`, grouped by size.

    For each size ``k`` from 1 to ``num_coincidence``, one tensor of shape ``(number of subsets, k)``
    is yielded, with the subsets in the order produced by ``itertools.combinations``.
    """
    elements = range(num_coincidence)
    for size in range(1, num_coincidence + 1):
        rows = [list(comb) for comb in itertools.combinations(elements, size)]
        yield torch.tensor(rows).reshape(len(rows), size)
def get_powerset(n: int) -> list:
    r"""Get the powerset of :math:`\{0,1,...,n-1\}`.

    Returns:
        A list with ``n + 1`` entries; entry ``k`` is the list of all subsets of size ``k``,
        each subset being a list of indices.
    """
    return [[list(comb) for comb in itertools.combinations(range(n), size)] for size in range(n + 1)]
def permanent_ryser(mat: torch.Tensor) -> torch.Tensor:
    """Calculate the permanent by Ryser's formula.

    The column subsets of each size are processed as one batch via ``vmap``; the chunk size of the
    batch is obtained from ``mem_to_chunksize``.
    """
    size = mat.size()[0]
    chunk_size = mem_to_chunksize(mat.device, mat.dtype)

    def signed_term(cols: torch.Tensor, mat: torch.Tensor) -> torch.Tensor:
        # Product over rows of the row sums restricted to the chosen columns, signed by the subset size
        row_sums = mat[:, cols].sum(dim=-1)
        return row_sums.prod() * (-1) ** cols.numel()

    batched_term = vmap(signed_term, in_dims=(0, None), chunk_size=chunk_size)
    total = 0
    for subsets in create_subset(size):
        total += batched_term(subsets, mat).sum()
    total *= (-1) ** size
    return total
def product_factorial(state: torch.Tensor) -> torch.Tensor:
    """Get the product of the factorial from the Fock state, i.e., :math:`|s_1,s_2,...s_n> -> s_1!s_2!...s_n!`.

    The factorials are evaluated through the natural log-gamma function in double precision on the CPU.
    The result keeps the last dimension (size 1) and is cast back to the device and the floating dtype
    of ``state + 0.0``.
    """
    state = state + 0.0  # promote integer input to a floating dtype
    log_factorials = torch.lgamma(state.cpu().double() + 1)  # ln(s!) = lgamma(s + 1)
    product = log_factorials.sum(-1, keepdim=True).exp()
    return product.to(state.device, state.dtype)
[docs] def fock_combinations(nmode: int, nphoton: int, cutoff: int | None = None, nancilla: int = 0) -> list[list[int]]: """Generate all possible combinations of Fock states for a given number of modes, photons, and cutoff. Args: nmode: The number of modes in the system. nphoton: The total number of photons in the system. cutoff: The Fock space truncation. Default: ``None`` nancilla: The number of ancilla modes (NOT limited by ``cutoff``). Default: ``0`` Returns: A list of all possible Fock states, each represented by a list of occupation numbers for each mode. Examples: >>> fock_combinations(2, 3) [[0, 3], [1, 2], [2, 1], [3, 0]] >>> fock_combinations(3, 2) [[0, 0, 2], [0, 1, 1], [0, 2, 0], [1, 0, 1], [1, 1, 0], [2, 0, 0]] >>> fock_combinations(4, 4, 2) [[1, 1, 1, 1]] """ if cutoff is None: cutoff = nphoton + 1 result = [] def backtrack(state: list[int], length: int, num_sum: int) -> None: """A helper function that uses backtracking to generate all possible Fock states. Args: state: The current Fock state being constructed. length: The remaining number of modes to be filled. num_sum: The remaining number of photons to be distributed. """ if length == 0: if num_sum == 0: result.append(state) return # Determine the effective length for cutoff effective_length = length - nancilla # skip iterations if remaining photons exceed the remaining cutoff if nancilla == 0 and num_sum > (cutoff - 1) * effective_length: return for i in range(min((num_sum + 1), cutoff) if effective_length > 0 else (num_sum + 1)): backtrack(state + [i], length - 1, num_sum - i) backtrack([], nmode, nphoton) return result
def ladder_ops(cutoff: int, dtype=torch.cfloat, device='cpu') -> tuple[torch.Tensor, torch.Tensor]:
    """Get the matrix representation of the annihilation and creation operators.

    Args:
        cutoff: The size of the returned square matrices.
        dtype: The data type of the matrices. Default: ``torch.cfloat``
        device: The device of the matrices. Default: ``'cpu'``

    Returns:
        The annihilation operator, with the square roots of ``1, ..., cutoff - 1`` on the first
        superdiagonal, and its conjugate transpose as the creation operator.
    """
    levels = torch.arange(1, cutoff, dtype=dtype, device=device)
    lowering = torch.diag(torch.pow(levels, 0.5), 1)
    raising = lowering.mH  # share the memory
    return lowering, raising
def shift_func(lst: list, nstep: int) -> list:
    """Shift a list cyclically by a number of steps.

    If ``nstep`` is positive, it shifts to the left. Lists with at most one element are returned as is.
    """
    size = len(lst)
    if size <= 1:
        return lst
    offset = nstep % size
    head, tail = lst[:offset], lst[offset:]
    return tail + head
def xxpp_to_xpxp(matrix: torch.Tensor) -> torch.Tensor:
    """Transform the representation in ``xxpp`` ordering to the representation in ``xpxp`` ordering.

    A tensor whose last two dimensions are square is permuted along both of them; a column vector
    (last dimension 1) is permuted along its rows. For any other last dimension ``None`` is returned.
    """
    nmode = matrix.shape[-2] // 2
    # [0, n, 1, n + 1, ...]: each x index is followed by the matching p index
    order = torch.arange(2 * nmode, device=matrix.device).reshape(2, nmode).T.flatten()
    ncol = matrix.shape[-1]
    if ncol == 2 * nmode:
        return matrix[..., order[:, None], order]
    if ncol == 1:
        return matrix[..., order, :]
    return None
def xpxp_to_xxpp(matrix: torch.Tensor) -> torch.Tensor:
    """Transform the representation in ``xpxp`` ordering to the representation in ``xxpp`` ordering.

    A tensor whose last two dimensions are square is permuted along both of them; a column vector
    (last dimension 1) is permuted along its rows. For any other last dimension ``None`` is returned.
    """
    nmode = matrix.shape[-2] // 2
    # [0, 2, 4, ..., 1, 3, 5, ...]: all even positions first, then all odd positions
    order = torch.arange(2 * nmode, device=matrix.device).reshape(nmode, 2).T.flatten()
    ncol = matrix.shape[-1]
    if ncol == 2 * nmode:
        return matrix[..., order[:, None], order]
    if ncol == 1:
        return matrix[..., order, :]
    return None
def quadrature_to_ladder(tensor: torch.Tensor, symplectic: bool = False) -> torch.Tensor:
    """Transform the representation in ``xxpp`` ordering to the representation in ``aaa^+a^+`` ordering.

    Args:
        tensor: The input tensor in ``xxpp`` ordering.
        symplectic: Whether the transformation is applied for symplectic matrix or Gaussian state.
            Default: ``False`` (which means covariance matrix or displacement vector)

    Returns:
        The complex tensor in ``aaa^+a^+`` ordering. ``None`` is returned implicitly when the last
        dimension is neither ``2 * nmode`` nor 1.
    """
    nmode = tensor.shape[-2] // 2
    tensor = tensor + 0j  # promote to a complex dtype
    eye = torch.eye(nmode, dtype=tensor.dtype, device=tensor.device)
    upper = torch.cat([eye, eye * 1j], dim=-1)
    lower = torch.cat([eye, eye * -1j], dim=-1)
    omega = torch.cat([upper, lower])
    ncol = tensor.shape[-1]
    if ncol == 2 * nmode:
        conjugated = omega @ tensor @ omega.mH
        if symplectic:
            return conjugated / 2  # omega.mH / 2 is the inverse of omega
        return conjugated * dqp.kappa**2 / dqp.hbar
    if ncol == 1:
        return omega @ tensor * dqp.kappa / dqp.hbar**0.5
def ladder_to_quadrature(tensor: torch.Tensor, symplectic: bool = False) -> torch.Tensor:
    """Transform the representation in ``aaa^+a^+`` ordering to the representation in ``xxpp`` ordering.

    Args:
        tensor: The input tensor in ``aaa^+a^+`` ordering.
        symplectic: Whether the transformation is applied for symplectic matrix or Gaussian state.
            Default: ``False`` (which means covariance matrix or displacement vector)

    Returns:
        The real part of the transformed tensor in ``xxpp`` ordering. ``None`` is returned implicitly
        when the last dimension is neither ``2 * nmode`` nor 1.
    """
    nmode = tensor.shape[-2] // 2
    tensor = tensor + 0j  # promote to a complex dtype
    eye = torch.eye(nmode, dtype=tensor.dtype, device=tensor.device)
    upper = torch.cat([eye, eye], dim=-1)
    lower = torch.cat([eye * -1j, eye * 1j], dim=-1)
    omega = torch.cat([upper, lower])
    ncol = tensor.shape[-1]
    if ncol == 2 * nmode:
        conjugated = (omega @ tensor @ omega.mH).real
        if symplectic:
            return conjugated / 2  # omega.mH / 2 is the inverse of omega
        return conjugated * dqp.hbar / (4 * dqp.kappa**2)
    if ncol == 1:
        return (omega @ tensor).real * dqp.hbar**0.5 / (2 * dqp.kappa)
def _photon_number_mean_var_gaussian(cov: torch.Tensor, mean: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Get the expectation value and variance of the photon number for single-mode Gaussian states.

    ``cov`` is flattened to a batch of 2 x 2 matrices and ``mean`` to a batch of 2 x 1 vectors.
    """
    coef = dqp.kappa**2 / dqp.hbar
    cov = cov.reshape(-1, 2, 2)
    mean = mean.reshape(-1, 2, 1)
    batched_trace = vmap(torch.trace)
    mean_sq = (mean.mT @ mean).squeeze()
    exp = coef * (batched_trace(cov) + mean_sq) - 1 / 2
    cov_sq_trace = batched_trace(cov @ cov)
    mean_cov_mean = (mean.mT @ cov.to(mean.dtype) @ mean).squeeze()
    var = coef**2 * (cov_sq_trace + 2 * mean_cov_mean) * 2 - 1 / 4
    return exp, var


def _photon_number_mean_var_bosonic(
    cov: torch.Tensor, mean: torch.Tensor, weight: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Get the expectation value and variance of the photon number for single-mode Bosonic states.

    The Gaussian results of all components are combined with ``weight`` along the last dimension.
    The imaginary parts of the results are asserted to vanish (``atol=1e-6``) and only the real
    parts are returned.
    """
    batch_shape = cov.shape[:2]
    cov = cov.reshape(*batch_shape, 2, 2).reshape(-1, 2, 2)
    mean = mean.reshape(*mean.shape[:2], 2, 1).reshape(-1, 2, 1)
    exp_gaussian, var_gaussian = _photon_number_mean_var_gaussian(cov, mean)
    exp_gaussian = exp_gaussian.reshape(batch_shape)
    var_gaussian = var_gaussian.reshape(batch_shape)
    exp = (weight * exp_gaussian).sum(-1)
    # Weighted variances plus the spread of the component means around the total mean
    second_moment = (weight * exp_gaussian**2).sum(-1)
    var = (weight * var_gaussian).sum(-1) + second_moment - exp**2
    zeros = cov.new_zeros(1)
    assert torch.allclose(exp.imag, zeros, atol=1e-6)
    assert torch.allclose(var.imag, zeros, atol=1e-6)
    return exp.real, var.real
def photon_number_mean_var_cv(
    cov: torch.Tensor, mean: torch.Tensor, weight: torch.Tensor | None = None
) -> tuple[torch.Tensor, torch.Tensor]:
    """Get the expectation value and variance of the photon number for single-mode Gaussian (Bosonic) states.

    Args:
        cov: The covariance matrices.
        mean: The displacement vectors.
        weight: The weights of the Bosonic state. Default: ``None`` (which means Gaussian state)
    """
    if weight is not None:
        return _photon_number_mean_var_bosonic(cov, mean, weight)
    return _photon_number_mean_var_gaussian(cov, mean)
def photon_number_mean_var_fock(
    state: torch.Tensor, nmode: int, cutoff: int, wires: list[int], den_mat: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
    """Get the expectation value and variance of the photon number for Fock state tensors.

    Args:
        state: The Fock state tensor or density matrix. A state tensor without the batch dimension
            (``state.ndim == nmode``) gets a batch dimension of size 1.
        nmode: The number of modes.
        cutoff: The Fock space truncation.
        wires: The modes whose photon statistics are computed.
        den_mat: Whether ``state`` is a density matrix. Default: ``False``

    Returns:
        The expectation values and the variances, each stacked over ``wires`` along the first dimension,
        i.e., of shape (len(wires), batch).
    """
    if den_mat:
        rho = state.reshape(-1, cutoff**nmode, cutoff**nmode)
        prob = torch.diagonal(rho, dim1=1, dim2=2).reshape([-1] + [cutoff] * nmode).real
    else:
        if state.ndim == nmode:
            state = state.unsqueeze(0)
        prob = abs(state) ** 2
    num_op = torch.arange(cutoff, device=state.device)
    num_exp_list = []
    var_list = []
    for i in wires:
        other_dims = [j + 1 for j in range(nmode) if j != i]
        # For a single mode there is nothing to marginalize. An empty ``dim`` list must not be passed
        # to ``torch.sum`` because it reduces over every dimension instead of none.
        p_i = torch.sum(prob, dim=other_dims) if other_dims else prob
        num_exp = (num_op * p_i).sum(dim=-1)  # (batch,)
        num2_exp = ((num_op**2) * p_i).sum(dim=-1)  # (batch,)
        var = num2_exp - num_exp**2
        num_exp_list.append(num_exp)
        var_list.append(var)
    return torch.stack(num_exp_list), torch.stack(var_list)
def quadrature_mean_fock(
    state: torch.Tensor, nmode: int, cutoff: int, wires: list[int], den_mat: bool = False
) -> torch.Tensor:
    """Get the expectation value of the quadrature x for Fock state tensors.

    Args:
        state: The Fock state tensor or density matrix.
        nmode: The number of modes.
        cutoff: The Fock space truncation.
        wires: The modes whose quadrature mean is computed.
        den_mat: Whether ``state`` is a density matrix. Default: ``False``

    Returns:
        The expectation values stacked over ``wires`` along the first dimension.
    """
    coef = 2 * dqp.kappa**2 / dqp.hbar
    levels = torch.arange(1, cutoff, device=state.device, dtype=state.real.dtype)
    factor = torch.sqrt(levels / 2)
    means = []
    if den_mat:
        rho = state.reshape(-1, cutoff**nmode, cutoff**nmode)
        for wire in wires:
            others = [m for m in range(nmode) if m != wire]
            # (batch, cutoff, cutoff)
            rho_wire = partial_trace(rho, nmode, others, cutoff).reshape(-1, cutoff, cutoff)
            upper = rho_wire.diagonal(offset=1, dim1=1, dim2=2)  # rho_{n, n+1}
            contrib = factor * 2 * upper.real  # only the real part contributes
            means.append(contrib.sum(dim=1))
    else:
        if state.ndim == nmode:
            state = state.unsqueeze(0)
        factor = factor.view([1, -1] + [1] * (nmode - 1))
        reduce_dims = tuple(range(1, nmode + 1))
        for wire in wires:
            # Bring the measured mode right after the batch dimension
            rest = list(range(1, nmode + 1))
            rest.remove(wire + 1)
            moved = state.permute([0, wire + 1] + rest)
            lower_amp = moved[:, :-1, ...]  # n
            upper_amp = moved[:, 1:, ...]  # n+1
            contrib = factor * 2 * (lower_amp.conj() * upper_amp).real
            means.append(contrib.sum(dim=reduce_dims))
    return coef ** (-0.5) * torch.stack(means)
def takagi(a: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Takagi decomposition for a symmetric complex matrix.

    The decomposition is obtained from the eigendecomposition of a real symmetric matrix of twice the
    size that is built from the real and imaginary parts of ``a``.

    See https://math.stackexchange.com/questions/2026110/

    Args:
        a: The symmetric matrix. It can be real or complex.

    Returns:
        The matrix ``v`` assembled from the eigenvectors and the upper half ``diag`` of the eigenvalues.
        NOTE(review): ``None`` is returned implicitly when no candidate passes ``is_unitary`` in the
        degenerate branch - confirm whether callers expect that.
    """
    size = a.size()[0]
    # Real block matrix [[-Re(a), Im(a)], [Im(a), Re(a)]] of shape (2 * size, 2 * size)
    a_2 = torch.block_diag(-a.real, a.real)
    if torch.is_complex(a):
        a_2[:size, size:] = a.imag
        a_2[size:, :size] = a.imag
    s, u = torch.linalg.eigh(a_2)
    diag = s[size:]  # s already sorted
    # Upper half of the eigenvectors: bottom rows give the real part, top rows the imaginary part
    v = u[:, size:][size:] + 1j * u[:, size:][:size]
    if is_unitary(v):
        return v, diag
    else:  # consider degeneracy case
        # Indices of the eigenvalues that are numerically zero
        idx_zero = torch.where(abs(s) < 1e-5)[0]
        # NOTE(review): ``max`` raises if ``idx_zero`` is empty - presumably this branch is only
        # reached when zero eigenvalues exist; TODO confirm
        idx_max = max(idx_zero) + 1
        # Per-row squared weight carried by the eigenvectors after the last zero eigenvalue
        temp = abs(u[:size, idx_max:]) ** 2 + abs(u[size:, idx_max:]) ** 2
        sum_rhalf = temp.sum(1)
        # Rows whose weight does not yet add up to 1
        idx_lt_1 = torch.where(abs(sum_rhalf - 1) > 1e-6)[0]
        # Number of zero-eigenvalue eigenvectors needed to complete ``size`` columns
        r = size - (2 * size - idx_max)
        # find the correct combination
        for i in itertools.combinations(idx_zero, r):
            u_temp = u[:, list(i)]
            temp2 = abs(u_temp[idx_lt_1]) ** 2 + abs(u_temp[idx_lt_1 + size]) ** 2
            sum_lhalf = temp2.sum(1)
            sum_total = sum_lhalf + sum_rhalf[idx_lt_1]
            # Accept the combination only if it restores unit weight on the incomplete rows
            if torch.allclose(sum_total, torch.ones(len(idx_lt_1), dtype=sum_total.dtype, device=sum_total.device)):
                u_half = torch.cat([u[:, list(i)], u[:, idx_max:]], dim=1)
                v = u_half[size:] + 1j * u_half[:size]
                if is_unitary(v):
                    return v, diag
def sqrtm_herm(mat: torch.Tensor) -> torch.Tensor:
    """Compute the positive matrix square root of a Hermitian matrix using eigenvalue decomposition.

    The eigenvalues returned by ``torch.linalg.eigh`` are square-rooted and the matrix is
    reassembled from the eigenvectors.
    """
    eigvals, eigvecs = torch.linalg.eigh(mat)
    sqrt_diag = eigvals.sqrt().diag_embed().to(eigvecs.dtype)
    return eigvecs @ sqrt_diag @ eigvecs.mH
def schur_anti_symm_even(mat: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    r"""Schur decomposition for a real antisymmetric and even-dimensional matrix.

    This function decomposes a real antisymmetric matrix :math:`A` into the form :math:`A = O T O^T`,
    where :math:`O` is an orthogonal matrix and :math:`T` is a block-diagonal matrix
    with :math:`2 \times 2` antisymmetric blocks.

    Returns:
        The block-diagonal matrix :math:`T` and the matrix :math:`O`, in this order.
    """
    assert torch.allclose(mat, -mat.mT, rtol=1e-5, atol=1e-5)
    n = len(mat)
    half = n // 2
    # -i * A is Hermitian, so ``eigh`` applies
    lambd, u = torch.linalg.eigh(mat * -1j)
    upper_vals = lambd[half:]
    upper_vecs = u[:, half:]
    even = torch.arange(0, n, 2, device=mat.device)
    odd = torch.arange(1, n, 2, device=mat.device)
    # positive value is above the diagonal and in ascending order
    mat_t = torch.zeros_like(mat)
    mat_t[even, odd] = upper_vals
    mat_t[odd, even] = -upper_vals
    # Real and imaginary parts of each eigenvector fill a pair of adjacent columns
    mat_o = torch.zeros_like(mat)
    mat_o[:, ::2] = upper_vecs.real
    mat_o[:, 1::2] = upper_vecs.imag
    mat_o = mat_o / torch.linalg.vector_norm(mat_o, dim=0, keepdim=True)
    return mat_t, mat_o
def williamson(cov: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Williamson decomposition.

    This function decomposes a real symmetric and even-dimensional positive definite matrix :math:`V`
    into the form :math:`V = S D S^T`, where :math:`S` is a symplectic matrix and :math:`D` is
    a diagonal matrix with the symplectic eigenvalues.

    See https://arxiv.org/pdf/2403.04596 Section VII.

    Returns:
        The diagonal matrix :math:`D` and the matrix :math:`S`, in this order.
    """
    assert torch.allclose(cov, cov.mT, rtol=1e-5, atol=1e-5)
    nmode = cov.shape[-1] // 2
    dim = 2 * nmode
    # Symplectic form [[0, I], [-I, 0]], built by swapping the two row blocks of diag(-I, I)
    ones = cov.new_ones(nmode)
    signed_diag = torch.cat([-ones, ones]).diag_embed()
    omega = signed_diag.reshape(2, nmode, dim).flip(0).reshape(dim, dim)
    eigvals = torch.linalg.eigvalsh(cov)
    assert torch.all(eigvals > 0), 'Matrix must be positive definite.'
    cov_sqrt = sqrtm_herm(cov)
    cov_sqrt_inv = cov_sqrt.inverse()
    psi = cov_sqrt_inv @ omega @ cov_sqrt_inv  # antisymmetric
    mat_t, o_tilde = schur_anti_symm_even(psi)
    # Reorder the adjacent pairs (0, 1), (2, 3), ... into the block ordering
    perm = torch.arange(dim, device=cov.device).reshape(nmode, 2).T.flatten()
    mat_t_xxpp = mat_t[:, perm][perm]
    mat_o = o_tilde[:, perm]
    modes = torch.arange(nmode, device=cov.device)
    phi = mat_t_xxpp[modes, modes + nmode]
    phi2 = torch.cat([phi, phi])
    diag = (1 / phi2).diag_embed()
    mat_s = cov_sqrt @ mat_o @ phi2.sqrt().diag_embed()
    return diag, mat_s
def measure_fock_tensor(
    state: torch.Tensor,
    shots: int = 1024,
    with_prob: bool = False,
    wires: int | list[int] | None = None,
    block_size: int = 2**24,
) -> dict | list[dict]:
    r"""Measure the batched Fock state tensors.

    Args:
        state: The quantum state to measure. It should be a tensor of
            shape :math:`(\text{batch}, \text{cutoff}, ..., \text{cutoff})`.
        shots: The number of times to sample from the quantum state. Default: 1024
        with_prob: A flag that indicates whether to return the probabilities along with
            the number of occurrences. Default: ``False``
        wires: The wires to measure. It can be an integer or a list of integers specifying
            the indices of the wires. Default: ``None`` (which means all wires are measured)
        block_size: The block size for sampling. Default: 2**24

    Returns:
        One dictionary per batch element that maps the sampled ``FockState`` to its count (or to a
        ``(count, probability)`` tuple if ``with_prob`` is set). A single dictionary is returned
        when the batch size is 1.
    """
    from .state import FockState

    dims = state.shape
    batch = dims[0]
    cutoff = dims[-1]
    nmode = len(dims) - 1
    if wires is not None:
        if isinstance(wires, int):
            wires = [wires]
        assert isinstance(wires, list)
        wires = sorted(wires)
        # Measured wires first, the remaining wires afterwards
        unmeasured = list(range(nmode))
        for w in wires:
            unmeasured.remove(w)
        pm_shape = wires + unmeasured
    nwires = len(wires) if wires else nmode
    results_tot = []
    for b in range(batch):
        probs = torch.abs(state[b]) ** 2
        if wires is not None:
            # Marginalize over the wires that are not measured
            probs = probs.permute(pm_shape).reshape([cutoff] * nwires + [-1]).sum(-1)
        probs = probs.reshape(-1)
        # Perform block sampling to reduce memory consumption
        counts = Counter(block_sample(probs, shots, block_size))
        results = {FockState(decimal_to_list(key, cutoff, nwires)): value for key, value in counts.items()}
        if with_prob:
            for fock in results:
                flat_index = list_to_decimal(fock.state, cutoff)
                results[fock] = results[fock], probs[flat_index]
        results_tot.append(results)
    return results_tot[0] if batch == 1 else results_tot
def sample_homodyne_fock(
    state: torch.Tensor,
    wire: int,
    nmode: int,
    cutoff: int,
    shots: int = 1,
    den_mat: bool = False,
    x_range: float = 15,
    nbin: int = 100000,
) -> torch.Tensor:
    """Get the samples of homodyne measurement for batched Fock state tensors on one mode.

    The probability density of the quadrature is evaluated on ``nbin`` equally spaced points in
    ``[-x_range, x_range]`` and the samples are drawn from this discretized distribution.

    Args:
        state: The Fock state tensor or density matrix.
        wire: The measured mode.
        nmode: The number of modes.
        cutoff: The Fock space truncation.
        shots: The number of samples per batch element. Default: 1
        den_mat: Whether ``state`` is a density matrix. Default: ``False``
        x_range: The half width of the sampling interval. Default: 15
        nbin: The number of discretization points. Default: 100000

    Returns:
        The samples of shape (batch, shots, 1).
    """
    coef = 2 * dqp.kappa**2 / dqp.hbar
    dim = cutoff**nmode
    if den_mat:
        state = state.reshape(-1, dim, dim)
    else:
        ket = state.reshape(-1, dim, 1)
        state = ket @ ket.mH
    traced_out = [m for m in range(nmode) if m != wire]
    reduced_dm = partial_trace(state, nmode, traced_out, cutoff)  # (batch, cutoff, cutoff)
    real_dtype = state.real.dtype
    orders = torch.arange(cutoff, dtype=real_dtype, device=state.device).reshape(-1, 1)  # (cutoff, 1)
    # with dimension \sqrt{m\omega\hbar}
    xs = torch.linspace(-x_range, x_range, nbin, dtype=real_dtype, device=state.device)  # (nbin)
    hermite = torch.special.hermite_polynomial_h(coef**0.5 * xs, orders)  # (cutoff, nbin)
    # H_n / \sqrt{2^n * n!}
    factorial = torch.exp(torch.lgamma(orders.cpu().double() + 1)).to(orders.device, orders.dtype)
    hermite = hermite / torch.sqrt(2**orders * factorial)
    # (cutoff, cutoff, nbin)
    hermite_pairs = hermite.reshape(1, cutoff, nbin) * hermite.reshape(cutoff, 1, nbin)
    weighted = reduced_dm.unsqueeze(-1) * hermite_pairs  # (batch, cutoff, cutoff, nbin)
    gaussian = torch.exp(-coef * xs**2)
    probs = (weighted.sum(dim=[-3, -2]) * gaussian).real  # (batch, nbin)
    probs = abs(probs)
    probs[probs < 1e-10] = 0
    # (batch, shots)
    picked = torch.multinomial(probs.reshape(-1, nbin), num_samples=shots, replacement=True)
    return xs[picked].unsqueeze(-1)  # (batch, shots, 1)
def sample_reject_bosonic(
    cov: torch.Tensor, mean: torch.Tensor, weight: torch.Tensor, cov_m: torch.Tensor, shots: int
) -> torch.Tensor:
    """Get the samples of the Bosonic states via rejection sampling.

    See https://arxiv.org/abs/2103.05530 Algorithm 1 in Section VI B

    Args:
        cov: The covariance matrices with 4 dimensions, or 3 dimensions without the batch dimension.
        mean: The complex displacement vectors with 4 dimensions, or 3 dimensions without
            the batch dimension.
        weight: The complex weights with 2 dimensions, or 1 dimension without the batch dimension.
        cov_m: The covariance matrix that is added to ``cov``. It must broadcast against ``cov``.
        shots: The number of samples per batch element.

    Returns:
        The samples of shape (batch, shots, 2 * nmode).
    """
    if cov.ndim == 3:
        cov = cov.unsqueeze(0)
    if mean.ndim == 3:
        mean = mean.unsqueeze(0)
    if weight.ndim == 1:
        weight = weight.unsqueeze(0)
    assert cov.ndim == mean.ndim == 4
    assert weight.ndim == 2
    batch = cov.shape[0]
    rst = [cov.new_empty(0)] * batch
    batches = list(range(batch))  # batch elements that still need samples
    count_shots = [0] * batch
    shots_tmp = shots
    mask = (weight.real > 0) | (abs(weight.imag) > 1e-8) | (abs(mean.imag) > 1e-8).any(-2).squeeze(-1)
    exp_real = torch.exp(mean.imag.mT @ torch.linalg.solve(cov_m + cov, mean.imag) / 2).squeeze(-2, -1)
    c_tilde = mask * abs(weight) * exp_real
    while len(batches) > 0:
        cov_rest = cov[batches]
        mean_rest = mean[batches]
        cov_t = cov_m + cov_rest
        # Pick one component per batch element with probability proportional to ``c_tilde``
        m0 = torch.multinomial(c_tilde[batches], 1).reshape(-1)  # (batch)
        cov_m0 = cov[batches, m0]
        mean_m0 = mean[batches, m0].squeeze(-1).real
        dist_g = MultivariateNormal(mean_rest.squeeze(-1).real, cov_t)  # (batch, ncomb, 2 * nmode)
        r0 = MultivariateNormal(mean_m0, cov_m + cov_m0).sample([shots_tmp])  # (shots, batch, 2 * nmode)
        prob_g = dist_g.log_prob(r0.unsqueeze(-2)).exp()  # (shots, batch, ncomb)
        g_r0 = (c_tilde[batches] * prob_g).sum(-1)  # (shots, batch)
        y0 = torch.rand_like(g_r0) * g_r0
        rm = r0.unsqueeze(-1).unsqueeze(-3)  # (shots, batch, 2 * nmode) -> (shots, batch, 1, 2 * nmode, 1)
        # Only the two trailing singleton dimensions of the matrix product are removed. A bare
        # ``squeeze()`` would also drop the shots, batch or ncomb dimension whenever it has size 1
        # and break the broadcasting against ``prob_g``.
        exp_imag = torch.exp(
            (rm - mean_rest.real).mT @ torch.linalg.solve(cov_t, mean_rest.imag) * 1j
        ).squeeze(-2, -1)  # (shots, batch, ncomb)
        # Eq.(70-71)
        p_r0 = (weight[batches] * exp_real[batches] * prob_g * exp_imag).sum(-1)  # (shots, batch)
        assert torch.allclose(p_r0.imag, p_r0.imag.new_zeros(1))
        idx_shots, idx_batch = torch.where(y0 <= p_r0.real)
        batches_done = []
        for i, idx in enumerate(batches):
            accepted = r0[idx_shots[idx_batch == i], i]
            rst[idx] = torch.cat([rst[idx], accepted])  # (shots, 2 * nmode)
            count_shots[idx] = len(rst[idx])
            if count_shots[idx] >= shots:
                batches_done.append(idx)
                rst[idx] = rst[idx][:shots]
        for idx in batches_done:
            batches.remove(idx)
        # Draw only as many proposals as the least complete batch element still needs
        shots_tmp = shots - min(count_shots)
    return torch.stack(rst)  # (batch, shots, 2 * nmode)
def align_shape(cov: torch.Tensor, mean: torch.Tensor, weight: torch.Tensor) -> list[torch.Tensor]:
    """Align the shape for Bosonic state.

    A size-1 component dimension of ``cov`` and ``mean`` is expanded to the number of components
    given by the last dimension of ``weight``. In the batched case, a size-1 batch dimension of
    ``weight`` is expanded to the batch size of ``cov``. Inputs that match neither the batched
    nor the unbatched layout are returned unchanged.
    """
    ncomb = weight.shape[-1]
    batched = cov.ndim == mean.ndim == 4 and weight.ndim == 2
    unbatched = cov.ndim == mean.ndim == 3 and weight.ndim == 1
    if batched:
        if cov.shape[1] == 1:
            cov = cov.expand(-1, ncomb, -1, -1)
        if mean.shape[1] == 1:
            mean = mean.expand(-1, ncomb, -1, -1)
        if weight.shape[0] == 1:
            weight = weight.expand(cov.shape[0], -1)
    elif unbatched:
        if cov.shape[0] == 1:
            cov = cov.expand(ncomb, -1, -1)
        if mean.shape[0] == 1:
            mean = mean.expand(ncomb, -1, -1)
    return [cov, mean, weight]
def fock_to_wigner(
    state: torch.Tensor,
    wire: int,
    nmode: int,
    cutoff: int,
    den_mat: bool = False,
    xrange: int | list = 10,
    prange: int | list = 10,
    npoints: int | list = 100,
    plot: bool = True,
    k: int = 0,
) -> torch.Tensor:
    """Get the discretized Wigner function of the specified mode from a Fock state using the iterative method.

    See https://qutip.org/docs/4.7/modules/qutip/wigner.html

    Args:
        state: The input Fock state tensor or density matrix.
        wire: The Wigner function for the given wire.
        nmode: The mode number of the Fock state.
        cutoff: The Fock space truncation.
        den_mat: Whether to use density matrix representation. Only valid for Fock state tensor.
            Default: ``False``
        xrange: The range of quadrature x. A number ``r`` means ``[-r, r]``; a list gives
            ``[start, end]`` and is not modified. Default: 10
        prange: The range of quadrature p. A number ``r`` means ``[-r, r]``; a list gives
            ``[start, end]`` and is not modified. Default: 10
        npoints: The number of discretization points for quadratures. A list gives the numbers for
            x and p separately. Default: 100
        plot: Whether to plot the Wigner function. Default: ``True``
        k: The index of the Wigner function within the batch to plot. Default: 0

    Returns:
        The real Wigner function values of shape (batch, npoints_x, npoints_p).
    """
    if den_mat:
        rho = state.reshape(-1, cutoff**nmode, cutoff**nmode)
    else:
        state = state.reshape(-1, cutoff**nmode, 1)
        rho = state @ state.mH
    trace_lst = [i for i in range(nmode) if i != wire]
    reduced_dm = partial_trace(rho, nmode, trace_lst, cutoff)  # (batch, cutoff, cutoff)
    if reduced_dm.ndim == 2:
        reduced_dm = reduced_dm.unsqueeze(0)
    # Copy list arguments so that appending the number of points never mutates the caller's list
    xlist = [-xrange, xrange] if isinstance(xrange, (int, float)) else list(xrange)
    plist = [-prange, prange] if isinstance(prange, (int, float)) else list(prange)
    if isinstance(npoints, int):
        xlist.append(npoints)
        plist.append(npoints)
    else:
        xlist.append(npoints[0])
        plist.append(npoints[1])
    assert len(xlist) == len(plist) == 3
    xvec = torch.linspace(*xlist, dtype=state.real.dtype, device=state.device)
    pvec = torch.linspace(*plist, dtype=state.real.dtype, device=state.device)
    coef = 2 * dqp.kappa**2 / dqp.hbar
    grid_x, grid_p = torch.meshgrid(xvec, pvec, indexing='ij')
    # alpha = (sqrt(2) * kappa / sqrt(hbar)) * (q + i p) / sqrt(2)
    alpha = coef**0.5 * (grid_x + 1.0j * grid_p) / 2**0.5
    w_list = grid_x.new_zeros(cutoff, grid_x.shape[-2], grid_x.shape[-1]) * 1j
    w_00 = coef * torch.exp(-2 * abs(alpha) ** 2) / torch.pi
    w_list[0] = w_00
    w = reduced_dm[:, 0, 0].reshape(-1, 1, 1) * w_list[0]
    # First row: W_{0i}
    for i in range(1, cutoff):
        # For numerical stability, it is recommended to use cutoff < 80
        w_list[i] = 2 * alpha * w_list[i - 1] / rho.new_tensor(i).sqrt()
        w += 2 * (reduced_dm[:, 0, i].reshape(-1, 1, 1) * w_list[i]).real
    # Remaining rows: W_{ij}, i ≥ 1
    # ``w_list`` is updated in place row by row; ``temp`` keeps the value of the previous row that
    # the recurrence still needs after the entry has been overwritten
    for i in range(1, cutoff):
        # Diagonal element W_{ii}
        sqrt_i = i**0.5
        temp = w_list[i].clone()
        w_list[i] = (2 * alpha.conj() * temp - sqrt_i * w_list[i - 1]) / sqrt_i
        w += reduced_dm[:, i, i].reshape(-1, 1, 1) * w_list[i]
        # Off-diagonal elements W_{ij}, j > i
        for j in range(i + 1, cutoff):
            sqrt_j = j**0.5
            temp2 = (2 * alpha * w_list[j - 1] - sqrt_i * temp) / sqrt_j
            temp = w_list[j].clone()
            w_list[j] = temp2
            w += 2 * (reduced_dm[:, i, j].reshape(-1, 1, 1) * w_list[j]).real
    if plot:
        plot_wigner(w.real, xvec, pvec, k)
    return w.real
def cv_to_wigner(
    state: list,
    wire: int,
    xrange: int | list = 10,
    prange: int | list = 10,
    npoints: int | list = 100,
    plot: bool = True,
    k: int = 0,
    normalize: bool = True,
):
    """Get the discretized Wigner function of the specified mode from a CV state.

    Args:
        state: The input ``GaussianState`` or ``BosonicState``, given as ``[cov, mean]`` or
            ``[cov, mean, weight]``.
        wire: The Wigner function for the given wire.
        xrange: The range of quadrature x. A number ``r`` means ``[-r, r]``; a list gives
            ``[start, end]`` and is not modified. Default: 10
        prange: The range of quadrature p. A number ``r`` means ``[-r, r]``; a list gives
            ``[start, end]`` and is not modified. Default: 10
        npoints: The number of discretization points for quadratures. A list gives the numbers for
            x and p separately. Default: 100
        plot: Whether to plot the Wigner function. Default: ``True``
        k: The index of the Wigner function within the batch to plot. Default: 0
        normalize: Whether to normalize the Wigner function. Default: ``True``

    Returns:
        The real Wigner function values of shape (batch, npoints_x, npoints_p).

    Raises:
        ValueError: If ``state`` has neither 2 nor 3 entries.
    """
    cov, mean = state[:2]
    # Copy list arguments so that appending the number of points never mutates the caller's list
    xlist = [-xrange, xrange] if isinstance(xrange, (int, float)) else list(xrange)
    plist = [-prange, prange] if isinstance(prange, (int, float)) else list(prange)
    if isinstance(npoints, int):
        xlist.append(npoints)
        plist.append(npoints)
    else:
        xlist.append(npoints[0])
        plist.append(npoints[1])
    assert len(xlist) == len(plist) == 3
    xvec = torch.linspace(*xlist, dtype=cov.dtype, device=cov.device)
    pvec = torch.linspace(*plist, dtype=cov.dtype, device=cov.device)
    grid_x, grid_y = torch.meshgrid(xvec, pvec, indexing='ij')
    coords = torch.stack([grid_x.reshape(-1), grid_y.reshape(-1)]).mT
    coords2 = coords.unsqueeze(1).unsqueeze(2)  # (npoints, 1, 1, 2)
    coords3 = coords.unsqueeze(-1).unsqueeze(-3)
    if not isinstance(wire, torch.Tensor):
        wire = torch.tensor(wire).reshape(1)
    # Bring ``cov`` and ``mean`` to 4 dimensions
    if cov.ndim == 2:
        cov = cov.unsqueeze(0)
    if mean.ndim == 2:
        mean = mean.unsqueeze(0)
    if cov.ndim == 3:
        cov = cov.unsqueeze(1)
    if mean.ndim == 3:
        mean = mean.unsqueeze(1)
    if len(state) == 2:
        weight = cov.new_ones(1)
    elif len(state) == 3:
        weight = state[-1]
    else:
        raise ValueError('``state`` must be [cov, mean] or [cov, mean, weight].')
    cov, mean, weight = align_shape(cov, mean, weight)
    nmode = cov.shape[-1] // 2
    idx = torch.cat([wire, wire + nmode])  # xxpp order
    cov = cov[..., idx[:, None], idx]
    mean = mean[..., idx, :] + 0j  # for Gaussian state
    gauss_b = MultivariateNormal(mean.squeeze(-1).real, cov)  # mean shape: (batch, ncomb, 2)
    prob_g = gauss_b.log_prob(coords2).exp()  # (npoints, batch, ncomb)
    exp_real = torch.exp(mean.imag.mT @ torch.linalg.solve(cov, mean.imag) / 2).squeeze(-2, -1)  # (batch, ncomb)
    # (batch, npoints, ncomb)
    exp_imag = torch.exp(
        (coords3 - mean.real.unsqueeze(1)).mT @ torch.linalg.solve(cov, mean.imag).unsqueeze(1) * 1j
    ).squeeze(-2, -1)
    wigner_vals = exp_real.unsqueeze(-2) * prob_g.permute(1, 0, 2) * exp_imag * weight.unsqueeze(-2)
    wigner_vals = wigner_vals.sum(dim=2).reshape(-1, len(xvec), len(pvec)).real
    if normalize:  # normalize the wigner function
        dx = xvec[1] - xvec[0]
        dp = pvec[1] - pvec[0]
        total_integral = torch.sum(wigner_vals, dim=[1, 2]) * dx * dp
        wigner_vals = wigner_vals / total_integral.reshape(-1, 1, 1)
    if plot:
        plot_wigner(wigner_vals, xvec, pvec, k)
    return wigner_vals
def plot_wigner(wigner: torch.Tensor, xvec: torch.Tensor, pvec: torch.Tensor, k: int = 0):
    """Plot a 2D contour and a 3D surface of a discretized Wigner function W(x, p).

    Args:
        wigner: Discretized Wigner values with shape (batch, nx, np).
        xvec: 1D grid for quadrature x.
        pvec: 1D grid for quadrature p.
        k: The index of the Wigner function within the batch to plot. Default: 0
    """
    mesh_x, mesh_p = torch.meshgrid(xvec, pvec, indexing='ij')
    x, y, z = mesh_x.cpu(), mesh_p.cpu(), wigner[k].cpu()
    fig = plt.figure(figsize=(16, 8))
    # Left panel: filled contour plot with a colorbar
    contour_ax = fig.add_subplot(1, 2, 1)
    contour_ax.set_xlabel('Quadrature x')
    contour_ax.set_ylabel('Quadrature p')
    filled = contour_ax.contourf(x, y, z, 60, cmap=cm.RdBu)
    fig.colorbar(filled, ax=contour_ax, shrink=0.5)
    # Right panel: 3D surface
    surface_ax = fig.add_subplot(1, 2, 2, projection='3d')
    surface_ax.plot_surface(x, y, z, cmap=cm.RdBu, alpha=0.8)
    surface_ax.set_xlabel('Quadrature x')
    surface_ax.set_ylabel('Quadrature p')
    surface_ax.set_zlabel('W(x, p)')
    plt.tight_layout()
    plt.show()