Source code for deepquantum.photonic.circuit

"""Photonic quantum circuit"""

import itertools
import warnings
from collections import Counter, defaultdict
from copy import copy, deepcopy
from typing import Any

import numpy as np
import torch
from torch import nn, vmap
from torch.distributions.multivariate_normal import MultivariateNormal

import deepquantum.photonic as dqp

from ..qmath import block_sample, get_prob_mps, inner_product_mps, is_positive_definite, sample_sc_mcmc
from ..state import MatrixProductState
from .channel import PhotonLoss
from .decompose import UnitaryDecomposer
from .distributed import measure_dist
from .draw import DrawCircuit
from .gate import (
    Barrier,
    BeamSplitter,
    BeamSplitterPhi,
    BeamSplitterSingle,
    BeamSplitterTheta,
    ControlledX,
    ControlledZ,
    CrossKerr,
    CubicPhase,
    DelayBS,
    DelayMZI,
    Displacement,
    DisplacementMomentum,
    DisplacementPosition,
    Kerr,
    MZI,
    PhaseShift,
    QuadraticPhase,
    Squeezing,
    Squeezing2,
    UAnyGate,
)
from .hafnian_ import hafnian
from .measurement import Generaldyne, Homodyne
from .operation import Channel, Delay, Gate, Operation
from .qmath import (
    align_shape,
    fock_combinations,
    measure_fock_tensor,
    permanent,
    photon_number_mean_var_cv,
    photon_number_mean_var_fock,
    product_factorial,
    quadrature_mean_fock,
    quadrature_to_ladder,
    sample_homodyne_fock,
    sample_reject_bosonic,
    shift_func,
    sort_dict_fock_basis,
    sub_matrix,
    williamson,
)
from .state import (
    BosonicState,
    CatState,
    DistributedFockState,
    FockState,
    GKPState,
    GaussianState,
    combine_bosonic_states,
)
from .torontonian_ import torontonian


[docs] class QumodeCircuit(Operation): r"""Photonic quantum circuit. Args: nmode: The number of modes in the circuit. init_state: The initial state of the circuit. It can be a vacuum state with ``'vac'`` or ``'zeros'``. For Fock backend, it can be a Fock basis state, e.g., ``[1,0,0]``, or a Fock state tensor, e.g., ``[(1/2**0.5, [1,0]), (1/2**0.5, [0,1])]``. Alternatively, it can be a tensor representation. For Gaussian backend, it can be arbitrary Gaussian states with ``[cov, mean]``. For Bosonic backend, it can be arbitrary linear combinations of Gaussian states with ``[cov, mean, weight]``, or a list of local Bosonic states. Use ``xxpp`` convention and :math:`\hbar=2` by default. cutoff: The Fock space truncation. Default: ``None`` backend: Use ``'fock'`` for Fock backend, ``'gaussian'`` for Gaussian backend or ``'bosonic'`` for Bosonic backend. Default: ``'fock'`` basis: Whether to use the representation of Fock basis state for the initial state. Default: ``True`` den_mat: Whether to use density matrix representation. Only valid for Fock state tensor. Default: ``False`` detector: For Gaussian backend, use ``'pnrd'`` for the photon-number-resolving detector or ``'threshold'`` for the threshold detector. Default: ``'pnrd'`` name: The name of the circuit. Default: ``None`` mps: Whether to use matrix product state representation. Default: ``False`` chi: The bond dimension for matrix product state representation. Default: ``None`` noise: Whether to introduce Gaussian noise. Default: ``False`` mu: The mean of Gaussian noise. Default: 0 sigma: The standard deviation of Gaussian noise. Default: 0.1 """ def __init__( self, nmode: int, init_state: Any, cutoff: int | None = None, backend: str = 'fock', basis: bool = True, den_mat: bool = False, detector: str = 'pnrd', name: str | None = None, mps: bool = False, chi: int | None = None, noise: bool = False, mu: float = 0, sigma: float = 0.1, ) -> None: super().__init__( name=name, nmode=nmode, wires=list(range(nmode)), cutoff=cutoff, den_mat=den_mat, noise=noise, mu=mu, sigma=sigma, ) self.backend = backend.lower() self.basis = basis self.detector = detector.lower() self.mps = mps self.chi = chi self.set_init_state(init_state) self.operators = nn.Sequential() self.encoders = [] self.measurements = nn.ModuleList() self.wires_homodyne = [] self.state = None self.state_measured = None self.ndata = 0 self.depth = np.array([0] * nmode) self._lossy = False self._nloss = 0 # Fock basis state self._is_batch_expanded = False # whether to expand batched states (add a mode) to align photon numbers self._init_state_forward = None # prepared initial state in forward() (init_state + nloss + batch expansion) self._out_fock_basis = None # the output Fock basis states self._reset_fock_basis = True # whether to recompute the output Fock basis states in forward() self._init_state_sample = None # for _sample_mcmc_fock() # Bosonic self._bosonic_states = None # ModuleList of initial Bosonic states # TDM self._with_delay = False self._nmode_tdm = self.nmode self._ntau_dict = defaultdict(list) # {wire: [tau1, tau2, ...]} self._unroll_dict = None # {wire_space: [wires_delay_n, ..., wires_delay_1, wire_space_concurrent]} self._operators_tdm = None self._measurements_tdm = None
[docs] def set_init_state(self, init_state: Any) -> None: """Set the initial state of the circuit.""" if isinstance(init_state, (FockState, GaussianState, BosonicState, MatrixProductState)): if isinstance(init_state, MatrixProductState): assert self.nmode == init_state.nsite assert self.backend == 'fock' and not self.basis, ( 'Only support MPS for Fock backend with Fock state tensor.' ) self.mps = True self.chi = init_state.chi self.cutoff = init_state.qudit # if self.cutoff is changed, the operators should be reset else: assert self.nmode == init_state.nmode self.mps = False self.cutoff = init_state.cutoff # if self.cutoff is changed, the operators should be reset if isinstance(init_state, FockState): self.backend = 'fock' self.basis = init_state.basis elif isinstance(init_state, GaussianState): self.backend = 'gaussian' elif isinstance(init_state, BosonicState): self.backend = 'bosonic' self.init_state = init_state else: if self.mps: assert self.backend == 'fock' and not self.basis, ( 'Only support MPS for Fock backend with Fock state tensor.' ) assert self.cutoff is not None, 'Please set the cutoff.' self.init_state = MatrixProductState( nsite=self.nmode, state=init_state, chi=self.chi, qudit=self.cutoff, normalize=False ) else: if self.backend == 'fock': self.init_state = FockState( state=init_state, nmode=self.nmode, cutoff=self.cutoff, basis=self.basis, den_mat=self.den_mat ) elif self.backend == 'gaussian': self.init_state = GaussianState(state=init_state, nmode=self.nmode, cutoff=self.cutoff) elif self.backend == 'bosonic': if isinstance(init_state, list) and all(isinstance(s, BosonicState) for s in init_state): self.init_state = combine_bosonic_states(states=init_state, cutoff=self.cutoff) if self.init_state.nmode < self.nmode: nmode = self.nmode - self.init_state.nmode vac = BosonicState(state='vac', nmode=nmode, cutoff=self.cutoff) self.init_state.tensor_product(vac) assert self.init_state.nmode == self.nmode else: self.init_state = BosonicState(state=init_state, nmode=self.nmode, cutoff=self.cutoff) self.cutoff = self.init_state.cutoff
def __add__(self, rhs: 'QumodeCircuit') -> 'QumodeCircuit': """Addition of the ``QumodeCircuit``. The initial state is the same as the first ``QumodeCircuit``. """ assert self.nmode == rhs.nmode cir = QumodeCircuit( nmode=self.nmode, init_state=self.init_state, cutoff=self.cutoff, backend=self.backend, basis=self.basis, den_mat=self.den_mat, detector=self.detector, name=self.name, mps=self.mps, chi=self.chi, noise=self.noise, mu=self.mu, sigma=self.sigma, ) cir.operators = self.operators + rhs.operators cir.encoders = self.encoders + rhs.encoders cir.measurements = rhs.measurements cir.wires_homodyne = rhs.wires_homodyne cir.npara = self.npara + rhs.npara cir.ndata = self.ndata + rhs.ndata cir.depth = self.depth + rhs.depth cir._lossy = self._lossy or rhs._lossy cir._nloss = self._nloss + rhs._nloss cir._bosonic_states = self._bosonic_states cir._with_delay = self._with_delay or rhs._with_delay cir._nmode_tdm = self._nmode_tdm + rhs._nmode_tdm - self.nmode cir._ntau_dict = defaultdict(list) for key, value in self._ntau_dict.items(): cir._ntau_dict[key].extend(value) for key, value in rhs._ntau_dict.items(): cir._ntau_dict[key].extend(value) return cir
[docs] def forward( self, data: torch.Tensor | None = None, state: Any = None, is_prob: bool | None = None, detector: str | None = None, sort: bool = True, stepwise: bool = False, ) -> torch.Tensor | dict | list[torch.Tensor]: """Perform a forward pass of the photonic quantum circuit and return the final-state-related result. Args: data: The input data for the ``encoders``. Default: ``None`` state: The initial state for the photonic quantum circuit. Default: ``None`` is_prob: For Fock backend, whether to return probabilities or amplitudes. For Gaussian (Bosonic) backend, whether to return probabilities or the final Gaussian (Bosonic) state. For Fock backend with ``basis=True``, set ``None`` to return the unitary matrix. Default: ``None`` detector: For Gaussian backend, use ``'pnrd'`` for the photon-number-resolving detector or ``'threshold'`` for the threshold detector. Default: ``None`` sort: Whether to sort dictionary of Fock basis states in the descending order of probabilities. Default: ``True`` stepwise: Whether to use the forward function of each operator for Gaussian backend. Default: ``False`` Returns: The result of the photonic quantum circuit after applying the ``operators``. """ if self.backend == 'fock': return self._forward_fock(data, state, is_prob, sort) elif self.backend in ('gaussian', 'bosonic'): return self._forward_cv(data, state, is_prob, detector, stepwise)
def _forward_fock( self, data: torch.Tensor | None = None, state: Any = None, is_prob: bool | None = None, sort: bool = True ) -> torch.Tensor | dict | list[torch.Tensor]: """Perform a forward pass based on the Fock backend. Args: data: The input data for the ``encoders``. Default: ``None`` state: The initial state for the photonic quantum circuit. Default: ``None`` is_prob: Whether to return probabilities or amplitudes. When ``basis=True``, set ``None`` to return the unitary matrix. Default: ``None`` sort: Whether to sort dictionary of Fock basis states in the descending order of probabilities. Default: ``True`` Returns: Unitary matrix, Fock state tensor, a dictionary of probabilities or amplitudes, or a list of tensors for MPS. """ if self.mps: assert not is_prob if state is None: state = self.init_state if isinstance(state, MatrixProductState): assert not self.basis state = state.tensors elif isinstance(state, FockState): state = state.state elif not isinstance(state, torch.Tensor): state = FockState(state, self.nmode, self.cutoff, self.basis, self.den_mat).state if not self.basis and isinstance(state, torch.Tensor) and state.device.type == 'mps': max_mps_dim = 16 mps_dim = 2 * self.nmode + 1 if self.den_mat else self.nmode + 1 if mps_dim > max_mps_dim: warnings.warn( f'Apple Silicon MPS limit ({max_mps_dim} dims) exceeded. Auto-falling back to CPU.', UserWarning, stacklevel=4, ) self.cpu() state = state.cpu() if isinstance(data, torch.Tensor): data = data.cpu() # preprocessing of batched initial states if self.basis: self._is_batch_expanded = False # reset state = self._prepare_init_state(state, reset_fock_basis=self._reset_fock_basis) self._init_state_forward = state if self.ndata == 0: data = None if data is None or data.ndim == 1: if self.basis: assert state.ndim in (1, 2) if state.ndim == 1: self.state = self._forward_helper_basis(data, state, is_prob) elif state.ndim == 2: self.state = vmap(self._forward_helper_basis, in_dims=(None, 0, None))(data, state, is_prob) else: self.state = self._forward_helper_tensor(data, state, is_prob) if not self.mps and self.state.ndim == self.nmode: self.state = self.state.unsqueeze(0) else: assert data.ndim == 2 if self.basis: assert state.ndim in (1, 2) if state.ndim == 1: self.state = vmap(self._forward_helper_basis, in_dims=(0, None, None))(data, state, is_prob) elif state.ndim == 2: if data.shape[0] == 1: self.state = vmap(self._forward_helper_basis, in_dims=(None, 0, None))(data[0], state, is_prob) else: self.state = vmap(self._forward_helper_basis, in_dims=(0, 0, None))(data, state, is_prob) else: if self.mps: assert state[0].ndim in (3, 4) if state[0].ndim == 3: self.state = vmap(self._forward_helper_tensor, in_dims=(0, None, None))(data, state, is_prob) elif state[0].ndim == 4: self.state = vmap(self._forward_helper_tensor, in_dims=(0, 0, None))(data, state, is_prob) else: if state.shape[0] == 1: self.state = vmap(self._forward_helper_tensor, in_dims=(0, None, None))(data, state, is_prob) else: self.state = vmap(self._forward_helper_tensor, in_dims=(0, 0, None))(data, state, is_prob) # for plotting the last data self.encode(data[-1]) if sort and self.basis and is_prob is not None: self.state = sort_dict_fock_basis(self.state) return self.state def _forward_helper_basis( self, data: torch.Tensor | None = None, state: torch.Tensor | None = None, is_prob: bool | None = None ) -> torch.Tensor | dict: """Perform a forward pass for one sample if the input is a Fock basis state.""" self.encode(data) unitary = self.get_unitary() if is_prob is None: return unitary else: if state is None: state = self.init_state.state out_dict = defaultdict(float) final_states = self._out_fock_basis if self._is_batch_expanded: unitary = torch.block_diag(unitary, torch.eye(1, dtype=unitary.dtype, device=unitary.device)) sub_mats = vmap(sub_matrix, in_dims=(None, None, 0))(unitary, state, final_states) per_norms = self._get_permanent_norms(state, final_states).to(unitary.dtype) if is_prob: rst = vmap(self._get_prob_fock_vmap)(sub_mats, per_norms) else: rst = vmap(self._get_amplitude_fock_vmap)(sub_mats, per_norms) for i in range(len(final_states)): final_state = FockState(state=final_states[i], nmode=self.nmode, cutoff=self.cutoff, basis=self.basis) if not is_prob: assert final_state not in out_dict, ( 'Amplitudes of reduced states can not be added, please set "is_prob" to be True.' ) out_dict[final_state] += rst[i] return dict(out_dict) def _forward_helper_tensor( self, data: torch.Tensor | None = None, state: torch.Tensor | list[torch.Tensor] | None = None, is_prob: bool | None = None, ) -> torch.Tensor | list[torch.Tensor]: """Perform a forward pass for one sample if the input is a Fock state tensor.""" self.encode(data) if state is None: state = self.init_state if self.mps: if not isinstance(state, MatrixProductState): state = MatrixProductState( nsite=self.nmode, state=state, chi=self.chi, qudit=self.cutoff, normalize=self.init_state.normalize ) return self.operators(state).tensors else: if isinstance(state, FockState): state = state.state x = self.operators(self.tensor_rep(state)).squeeze(0) if is_prob: if self.den_mat: x = x.reshape(-1, self.cutoff**self.nmode, self.cutoff**self.nmode).diagonal(dim1=-2, dim2=-1) x = abs(x).reshape([-1] + [self.cutoff] * self.nmode).squeeze(0) else: x = abs(x) ** 2 return x def _forward_cv( self, data: torch.Tensor | None = None, state: Any = None, is_prob: bool | None = None, detector: str | None = None, stepwise: bool = False, ) -> list[torch.Tensor] | dict: """Perform a forward pass based on the Gaussian (Bosonic) backend. Args: data: The input data for the ``encoders``. Default: ``None`` state: The initial state for the photonic quantum circuit. Default: ``None`` is_prob: Whether to return probabilities or the final Gaussian (Bosonic) state. Default: ``None`` detector: Use ``'pnrd'`` for the photon-number-resolving detector or ``'threshold'`` for the threshold detector. Only valid when ``is_prob`` is ``True``. Default: ``None`` stepwise: Whether to use the forward function of each operator. Default: ``False`` Returns: The final Gaussian (Bosonic) state or a dictionary of probabilities. """ if state is None: if self.backend == 'bosonic' and self._bosonic_states is not None: state = combine_bosonic_states(states=self._bosonic_states, cutoff=self.cutoff) else: state = self.init_state elif not isinstance(state, (GaussianState, BosonicState)): nmode = self.nmode is_list_of_tensors = isinstance(state, list) and state and isinstance(state[0], torch.Tensor) if is_list_of_tensors and state[0].shape[-1] == 2 * self._nmode_tdm: nmode = self._nmode_tdm if self.backend == 'gaussian': state = GaussianState(state=state, nmode=nmode, cutoff=self.cutoff) elif self.backend == 'bosonic': state = BosonicState(state=state, nmode=nmode, cutoff=self.cutoff) cov, mean = state.cov, state.mean weight = state.weight if self.backend == 'bosonic' else None if self._with_delay: self._prepare_unroll_dict() cov, mean = self._unroll_init_state([cov, mean]) self._unroll_circuit() if data is None or data.ndim == 1: cov, mean = self._forward_helper_gaussian(data, [cov, mean], stepwise) if cov.ndim < state.cov.ndim: cov = cov.unsqueeze(0) if mean.ndim < state.mean.ndim: mean = mean.unsqueeze(0) else: assert data.ndim == 2 if cov.shape[0] == 1: cov, mean = vmap(self._forward_helper_gaussian, in_dims=(0, None, None))(data, [cov, mean], stepwise) else: cov, mean = vmap(self._forward_helper_gaussian, in_dims=(0, 0, None))(data, [cov, mean], stepwise) self.encode(data[-1]) if is_prob: self.state = [cov, mean] # for checking purity self.state = self._forward_cv_prob(cov, mean, weight, detector) else: if self._with_delay: cov, mean = self._shift_state([cov, mean]) if self.backend == 'gaussian': self.state = [cov, mean] elif self.backend == 'bosonic': self.state = [cov, mean, weight] return self.state def _forward_helper_gaussian( self, data: torch.Tensor | None = None, state: list[torch.Tensor] | None = None, stepwise: bool = False ) -> list[torch.Tensor]: """Perform a forward pass for one sample if the input is a Gaussian state.""" if self._lossy: stepwise = True self.encode(data) operators = self._operators_tdm if self._with_delay else self.operators cov, mean = (self.init_state.cov, self.init_state.mean) if state is None else state if self.backend == 'bosonic' and cov.ndim == 3: cov = cov.unsqueeze(0) if stepwise: cov, mean = operators([cov, mean]) else: sp_mat = self.get_symplectic() cov = sp_mat @ cov @ sp_mat.mT mean = self.get_displacement(mean) return [cov.squeeze(0), mean.squeeze(0)] def _forward_cv_prob( self, cov: torch.Tensor, mean: torch.Tensor, weight: torch.Tensor | None = None, detector: str | None = None ) -> dict: """Get the probabilities of all possible final states for Gaussian (Bosonic) backend by different detectors. Args: cov: The covariance matrices of the Gaussian states. mean: The displacement vectors of the Gaussian states. weight: The weights of the Gaussian states. Default: ``None`` detector: Use ``'pnrd'`` for the photon-number-resolving detector or ``'threshold'`` for the threshold detector. Default: ``None`` """ assert weight is None, 'Currently Fock probability is not supported in Bosonic backend' shape_cov = cov.shape shape_mean = mean.shape if shape_cov[1] == 1: cov = cov.expand(-1, shape_mean[1], -1, -1) if shape_mean[1] == 1: mean = mean.expand(-1, shape_cov[1], -1, -1) cov = cov.reshape(-1, *shape_cov[-2:]) mean = mean.reshape(-1, *shape_mean[-2:]) purity = GaussianState([cov, mean]).is_pure batch_forward = vmap(self._forward_gaussian_prob_helper, in_dims=(0, 0, None, None, None, None)) if detector is None: detector = self.detector else: detector = detector.lower() self.detector = detector basis = self._get_odd_even_fock_basis(detector=detector) if detector == 'pnrd': idx_loop = torch.all(mean == 0, dim=1) idx_loop = idx_loop.squeeze(1) cov_0 = cov[idx_loop] mean_0 = mean[idx_loop] cov_1 = cov[~idx_loop] mean_1 = mean[~idx_loop] final_states = torch.cat([torch.cat(basis[1]), torch.cat(basis[0])]) probs = [] if len(cov_0) > 0: loop = False probs_0 = batch_forward(cov_0, mean_0, basis, detector, purity, loop) probs.append(probs_0) if len(cov_1) > 0: loop = True probs_1 = batch_forward(cov_1, mean_1, basis, detector, purity, loop) probs.append(probs_1) probs = torch.cat(probs) # reorder the result here if len(cov_0) * len(cov_1) > 0: idx0 = torch.where(~idx_loop == 0)[0] idx1 = torch.where(~idx_loop == 1)[0] probs = probs[torch.argsort(torch.cat([idx0, idx1]))] elif detector == 'threshold': final_states = torch.cat(basis) loop = True probs = batch_forward(cov, mean, basis, detector, purity, loop) keys = list(map(FockState, final_states.tolist())) # TODO: Fock probabilities for Bosonic state with weights # if weight is not None: # probs = probs.reshape(weight.shape[0], weight.shape[1], -1) # (batch, ncomb, nfock) # probs = (probs * weight.unsqueeze(-1)).sum(1).real return dict(zip(keys, probs.mT, strict=True)) def _forward_gaussian_prob_helper(self, cov, mean, basis, detector, purity, loop): prob_lst = [] if detector == 'pnrd': odd_basis = basis[0] even_basis = basis[1] for state in even_basis: prob_even = self._get_probs_gaussian_helper(state, cov, mean, detector, purity, loop) prob_lst.append(prob_even) if loop or not purity: for state in odd_basis: prob_odd = self._get_probs_gaussian_helper(state, cov, mean, detector, purity, loop) prob_lst.append(prob_odd) probs = torch.cat(prob_lst) else: probs = torch.cat(prob_lst) probs = torch.cat([probs, torch.zeros(len(torch.cat(odd_basis)), device=probs.device)]) elif detector == 'threshold': for state in basis: prob = self._get_probs_gaussian_helper(state, cov, mean, detector, purity, loop) prob_lst.append(prob) probs = torch.cat(prob_lst) return probs
[docs] def set_fock_basis(self, state: Any = None, reset_in_forward: bool = False) -> None: """Set the output Fock basis states. By default it will generate all Fock basis states according to the inital state. Args: state: The output Fock basis states. Default: ``None`` reset_in_forward: Whether to recompute the output Fock basis states in the forward pass. Default: ``False`` """ assert self.basis assert self.init_state.state.ndim == 1, 'Manual setting for batched initial states is not allowed.' if state is None: if not reset_in_forward: # avoid double calculation of the output Fock basis states self._prepare_init_state(self.init_state.state, reset_fock_basis=True) else: state = FockState(state).state if state.ndim == 1: state = state.unsqueeze(0) assert torch.all(state.sum(dim=-1) == state[0].sum(dim=-1)), ( 'The number of photons must be the same and equal to initial states.' ) assert state.shape[-1] == self.nmode + self._nloss, ( 'Please fill in the right number of modes (including all ancilla modes in lossy case.)' ) self._out_fock_basis = state self._reset_fock_basis = reset_in_forward
[docs] def get_fock_basis(self) -> torch.Tensor: """Get the output Fock basis states according to the current settings.""" if self._out_fock_basis is None: self._prepare_init_state(self.init_state.state, reset_fock_basis=True) return self._out_fock_basis
def _get_all_fock_basis(self, init_state: torch.Tensor) -> torch.Tensor: """Calculate all possible Fock basis states according to the initial state.""" nphoton = torch.max(torch.sum(init_state, dim=-1)) nmode = len(init_state) nancilla = nmode - self._nmode_tdm if self._with_delay else nmode - self.nmode states = torch.tensor( fock_combinations(nmode, nphoton, self.cutoff, nancilla=nancilla), dtype=torch.long, device=init_state.device, ) return states def _get_odd_even_fock_basis(self, detector: str | None = None) -> tuple[list, list] | list: """Split the Fock basis into the odd and even photon number parts.""" if detector is None: detector = self.detector nmode = self._nmode_tdm if self._with_delay else self.nmode if detector == 'pnrd': max_photon = nmode * (self.cutoff - 1) odd_lst = [] even_lst = [] for i in range(0, max_photon + 1): state_tmp = torch.tensor([i] + [0] * (nmode - 1)) temp_basis = self._get_all_fock_basis(state_tmp) if i % 2 == 0: even_lst.append(temp_basis) else: odd_lst.append(temp_basis) return odd_lst, even_lst elif detector == 'threshold': final_states = torch.tensor(list(itertools.product(range(2), repeat=nmode))) keys = torch.sum(final_states, dim=1) dic_temp = defaultdict(list) for state, s in zip(final_states, keys, strict=True): dic_temp[s.item()].append(state) state_lst = [torch.stack(i) for i in list(dic_temp.values())] return state_lst def _prepare_init_state(self, state: torch.Tensor, reset_fock_basis: bool = False) -> torch.Tensor: """Check and expand the Fock state if necessary.""" if state.ndim == 1: if self._lossy: state = torch.cat([state, state.new_zeros(self._nloss)], dim=-1) if reset_fock_basis: self._out_fock_basis = self._get_all_fock_basis(state) elif state.ndim == 2: if self._lossy: state = torch.cat([state, state.new_zeros(state.shape[0], self._nloss)], dim=-1) nphotons = torch.sum(state, dim=-1, keepdim=True) max_photon = torch.max(nphotons).item() # expand the Fock state if the photon number is not conserved if any(nphoton < max_photon for nphoton in nphotons): state = torch.cat([state, max_photon - nphotons], dim=-1) self._is_batch_expanded = True if reset_fock_basis: self._out_fock_basis = self._get_all_fock_basis(state[0]) return state def _prepare_unroll_dict(self) -> dict[int, list]: """Create a dictionary that maps spatial modes to concurrent modes.""" if self._unroll_dict is None: self._unroll_dict = defaultdict(list) wires = list(range(self._nmode_tdm)) start = 0 for i in range(self.nmode): for ntau in reversed(self._ntau_dict[i]): self._unroll_dict[i].append(wires[start : start + ntau]) # modes in delay line start += ntau self._unroll_dict[i].append(wires[start]) # spatial mode start += 1 return self._unroll_dict def _unroll_init_state(self, state: list[torch.Tensor]) -> list[torch.Tensor]: """Unroll the initial state from spatial modes to concurrent modes.""" idx = torch.tensor([value[-1] for value in self._unroll_dict.values()]) idx = torch.cat([idx, idx + self._nmode_tdm]) cov, mean = state size = cov.size() size_tdm = 2 * self._nmode_tdm if size[-1] == size_tdm: return state else: cov_tdm = cov.new_ones(size[:-2].numel() * size_tdm).reshape(*size[:-2], size_tdm).diag_embed() mean_tdm = mean.new_zeros(*size[:-2], size_tdm, 1) cov_tdm[..., idx[:, None], idx] = cov mean_tdm[..., idx, :] = mean return [cov_tdm, mean_tdm] def _unroll_circuit(self) -> None: """Unroll the circuit from spatial modes to concurrent modes.""" nmode = self._nmode_tdm if self._operators_tdm is None: self._operators_tdm = nn.Sequential() ndelay = np.array([0] * self.nmode) # counter of delay loops for each mode for op in self.operators: if isinstance(op, Delay): wire = op.wires[0] ndelay[wire] += 1 idx_delay = -ndelay[wire] - 1 wires = [self._unroll_dict[wire][idx_delay][0], self._unroll_dict[wire][-1]] op.gates[0].nmode = nmode op.gates[0].wires = wires self._operators_tdm.append(op.gates[0]) if len(op.gates) > 1: for gate in op.gates[1:]: gate.nmode = nmode gate.wires = wires[0:1] if isinstance(gate, PhotonLoss): self._lossy = True self._nloss += 1 self._operators_tdm.append(gate) else: op_tdm = copy(op) op_tdm.nmode = nmode op_tdm.wires = [self._unroll_dict[wire][-1] for wire in op.wires] self._operators_tdm.append(op_tdm) if self._measurements_tdm is None: self._measurements_tdm = nn.ModuleList() for op_m in self.measurements: op_m_tdm = copy(op_m) op_m_tdm.nmode = nmode op_m_tdm.wires = [self._unroll_dict[wire][-1] for wire in op_m.wires] self._measurements_tdm.append(op_m_tdm)
[docs] def global_circuit(self, nstep: int, use_deepcopy: bool = False) -> 'QumodeCircuit': """Get the global circuit given the number of time steps. Note: The initial state of the global circuit is always the vacuum state. """ self._prepare_unroll_dict() nmode = self._nmode_tdm + (nstep - 1) * self.nmode cir = QumodeCircuit( nmode, init_state='vac', cutoff=self.cutoff, backend=self.backend, basis=self.basis, den_mat=self.den_mat, detector=self.detector, name=self.name, mps=self.mps, chi=self.chi, noise=self.noise, mu=self.mu, sigma=self.sigma, ) for i in range(nstep): ndelay = np.array([0] * self.nmode) # counter of delay loops for each mode for op in self.operators: encode = op in self.encoders is_deep = use_deepcopy or encode if isinstance(op, Delay): wire = op.wires[0] ndelay[wire] += 1 idx_delay = -ndelay[wire] - 1 wire1 = self._unroll_dict[wire][idx_delay][i % op.ntau] wire2 = self._unroll_dict[wire][-1] if i == 0 else self._nmode_tdm + self.nmode * (i - 1) + wire op_tdm = deepcopy(op.gates[0]) if is_deep else copy(op.gates[0]) op_tdm.nmode = nmode op_tdm.wires = [wire1, wire2] cir.add(op_tdm, encode=encode) if len(op.gates) > 1: for gate in op.gates[1:]: op_gate = deepcopy(gate) if is_deep else copy(gate) op_gate.nmode = nmode op_gate.wires = [wire1] if isinstance(gate, PhotonLoss): cir._lossy = True cir._nloss += 1 cir.add(op_gate, encode=encode) else: op_tdm = deepcopy(op) if is_deep else copy(op) op_tdm.nmode = nmode if i == 0: op_tdm.wires = [self._unroll_dict[wire][-1] for wire in op.wires] else: op_tdm.wires = [self._nmode_tdm + self.nmode * (i - 1) + wire for wire in op.wires] if isinstance(op, PhotonLoss): cir._lossy = True cir._nloss += 1 cir.add(op_tdm, encode=encode) for op_m in self.measurements: op_m_tdm = copy(op_m) op_m_tdm.nmode = nmode if i == 0: op_m_tdm.wires = [self._unroll_dict[wire][-1] for wire in op_m.wires] else: op_m_tdm.wires = [self._nmode_tdm + self.nmode * (i - 1) + wire for wire in op_m.wires] cir.add(op_m_tdm) cir.barrier() return cir
def _shift_state(self, state: list[torch.Tensor], nstep: int = 1, reverse: bool = False) -> list[torch.Tensor]: """Shift the state according to ``nstep``, which is equivalent to shifting the TDM circuit.""" cov, mean = state idx_shift = [] for wire in self._unroll_dict: for idx in self._unroll_dict[wire]: if isinstance(idx, int): idx_shift.append(idx) elif isinstance(idx, list): if reverse: idx_shift.extend(shift_func(idx, -nstep)) else: idx_shift.extend(shift_func(idx, nstep)) idx_shift = torch.tensor(idx_shift) idx_shift = torch.cat([idx_shift, idx_shift + self._nmode_tdm]) cov = cov[..., idx_shift[:, None], idx_shift] mean = mean[..., idx_shift, :] return [cov, mean]
[docs] def encode(self, data: torch.Tensor | None) -> None: """Encode the input data into the photonic quantum circuit parameters. This method iterates over the ``encoders`` of the circuit and initializes their parameters with the input data. Args: data: The input data for the ``encoders``, must be a 1D tensor. """ if data is None: return assert len(data) >= self.ndata count = 0 for op in self.encoders: count_up = count + op.npara op.init_para(data[count:count_up]) count = count_up
[docs] def get_unitary(self) -> torch.Tensor: """Get the unitary matrix of the photonic quantum circuit.""" u = None operators = self._operators_tdm if self._with_delay else self.operators nloss = 0 for op in operators: if isinstance(op, Barrier): continue if isinstance(op, PhotonLoss): nloss += 1 op.gate.wires = [op.wires[0], op.nmode + nloss - 1] op.gate.nmode = op.nmode + nloss if u is None: u = op.gate.get_unitary() continue else: u = torch.block_diag(u, torch.eye(1, dtype=u.dtype, device=u.device)) idx_r = torch.tensor(op.gate.wires, device=u.device) idx_c = torch.arange(op.gate.nmode, device=u.device) u_local = op.gate.update_matrix() else: if u is None: u = op.get_unitary() continue else: idx_r = torch.tensor(op.wires, device=u.device) idx_c = torch.arange(op.nmode + nloss, device=u.device) u_local = op.update_matrix() assert u_local.shape[-2] == u_local.shape[-1] == len(op.wires), ( 'The matrix may not act on creation operators.' ) u_update = u[idx_r[:, None], idx_c] new_val = u_local @ u_update u = u.index_put([idx_r[:, None], idx_c], new_val) if u is None: return torch.eye(self.nmode, dtype=torch.cfloat) else: return u
[docs] def get_symplectic(self) -> torch.Tensor: """Get the symplectic matrix of the photonic quantum circuit.""" s = None if self._with_delay: operators = self._operators_tdm nmode = self._nmode_tdm else: operators = self.operators nmode = self.nmode for op in operators: if isinstance(op, Barrier): continue s = op.get_symplectic() if s is None else op.get_symplectic() @ s if s is None: return torch.eye(2 * nmode, dtype=torch.float) return s
[docs] def get_displacement(self, init_mean: Any) -> torch.Tensor: """Get the final mean value of the Gaussian state in ``xxpp`` order.""" if not isinstance(init_mean, torch.Tensor): init_mean = torch.tensor(init_mean) if self._with_delay: operators = self._operators_tdm nmode = self._nmode_tdm else: operators = self.operators nmode = self.nmode mean = init_mean if self.backend == 'gaussian': mean = mean.reshape(-1, 2 * nmode, 1) elif self.backend == 'bosonic': if mean.ndim == 2: mean = mean.unsqueeze(0).unsqueeze(-1) elif mean.ndim == 3: if mean.shape[-1] == 1: mean = mean.unsqueeze(0) elif mean.shape[-1] == 2 * nmode: mean = mean.unsqueeze(-1) assert mean.ndim == 4 for op in operators: if isinstance(op, Barrier): continue mean = op.get_symplectic().to(mean.dtype) @ mean + op.get_displacement() return mean
def _get_permanent_norms(self, init_state: torch.Tensor, final_state: torch.Tensor) -> torch.Tensor: """Get the normalization factors for permanent.""" return torch.sqrt(product_factorial(init_state) * product_factorial(final_state))
[docs] def get_amplitude( self, final_state: Any, init_state: Any = None, unitary: torch.Tensor | None = None ) -> torch.Tensor: """Get the transfer amplitude between the final state and the initial state. Note: When states are expanded due to photon loss or batched initial states, the amplitudes of the reduced states can not be added, please try ``get_prob`` instead. Args: final_state: The final Fock basis state. init_state: The initial Fock basis state. Default: ``None`` unitary: The unitary matrix. Default: ``None`` """ assert self.backend == 'fock' if not isinstance(final_state, torch.Tensor): final_state = torch.tensor(final_state, dtype=torch.long) if init_state is None: init_state = self.init_state elif not isinstance(init_state, FockState): init_state = FockState(state=init_state, nmode=self.nmode, cutoff=self.cutoff, basis=self.basis) assert init_state.basis, 'The initial state must be a Fock basis state' assert max(final_state) < self.cutoff, 'The number of photons in the final state must be less than cutoff' if unitary is None: unitary = self.get_unitary() else: assert unitary.ndim == 2, 'The unitary matrix must be 2D' state = init_state.state.to(unitary.device) final_state = final_state.to(unitary.device) if state.ndim == 1: sub_mat = sub_matrix(unitary, state, final_state) per = permanent(sub_mat) amp = per / self._get_permanent_norms(state, final_state).to(per.dtype) else: idx_nonzero = torch.where(torch.sum(state, dim=-1) == torch.sum(final_state))[0] amp = torch.zeros(state.shape[0], dtype=unitary.dtype, device=unitary.device) if idx_nonzero.numel() != 0: sub_mats = vmap(sub_matrix, in_dims=(None, 0, None))(unitary, state[idx_nonzero], final_state) per_norms = self._get_permanent_norms(state[idx_nonzero], final_state).to(unitary.dtype) rst = vmap(self._get_amplitude_fock_vmap)(sub_mats, per_norms).flatten() amp[idx_nonzero] = rst return amp
def _get_amplitude_fock_vmap(self, sub_mat: torch.Tensor, per_norm: torch.Tensor) -> torch.Tensor: """Get the transfer amplitude.""" per = permanent(sub_mat) amp = per / per_norm return amp.reshape(-1)
[docs] def get_prob(self, final_state: Any, refer_state: Any = None, unitary: torch.Tensor | None = None) -> torch.Tensor: """Get the probability of the final state related to the reference state. Args: final_state: The final Fock basis state. refer_state: The initial Fock basis state or the final Gaussian state. Default: ``None`` unitary: The unitary matrix. Default: ``None`` """ if not isinstance(final_state, torch.Tensor): final_state = torch.tensor(final_state, dtype=torch.long) assert max(final_state) < self.cutoff, 'The number of photons in the final state must be less than cutoff' if self.backend == 'fock': if refer_state is None: refer_state = self._prepare_init_state(self.init_state.state) if unitary is None: unitary = self.get_unitary() else: assert unitary.ndim == 2, 'The unitary matrix must be 2D' if self._is_batch_expanded: identity = torch.eye(1, dtype=unitary.dtype, device=unitary.device) unitary = torch.block_diag(unitary, identity) nmode = final_state.shape[-1] if refer_state.shape[-1] == nmode: return self._get_prob_fock(final_state, refer_state, unitary) else: wires = list(range(nmode)) nphoton_final = torch.sum(final_state, dim=-1) max_photon = torch.sum(refer_state, dim=-1).max().item() nmode_expand = refer_state.shape[-1] - nmode expand_state = torch.tensor( fock_combinations(nmode_expand, max_photon - nphoton_final), dtype=torch.long, device=final_state.device, ) final_state = final_state.reshape(-1, nmode).expand(expand_state.shape[0], -1) final_states = torch.cat([final_state, expand_state], dim=-1) if refer_state.ndim == 1: rst = self._measure_fock_unitary_helper(refer_state, unitary, wires, final_states) else: rst = vmap(self._measure_fock_unitary_helper, in_dims=(0, None, None, None))( refer_state, unitary, wires, final_states ) rst = list(rst.values())[0] return rst elif self.backend == 'gaussian': nmode = self._nmode_tdm if self._with_delay else self.nmode if refer_state is None: refer_state = GaussianState(self.state, nmode=nmode, cutoff=self.cutoff) return self._get_prob_gaussian(final_state, refer_state)
def _get_prob_fock( self, final_state: Any, init_state: Any = None, unitary: torch.Tensor | None = None ) -> torch.Tensor: """Get the transfer probability between the final state and the initial state for the Fock backend. Args: final_state: The final Fock basis state. init_state: The initial Fock basis state. Default: ``None`` unitary: The unitary matrix. Default: ``None`` """ if init_state is None: # when mcmc nmode = self.nmode + self._nloss + self._is_batch_expanded init_state = FockState(state=self._init_state_sample, nmode=nmode, cutoff=self.cutoff, basis=self.basis) if unitary is None: # when mcmc unitary = self._unitary amplitude = self.get_amplitude(final_state, init_state, unitary) prob = torch.abs(amplitude) ** 2 return prob def _get_prob_fock_vmap(self, sub_mat: torch.Tensor, per_norm: torch.Tensor) -> torch.Tensor: """Get the transfer probability.""" amplitude = self._get_amplitude_fock_vmap(sub_mat, per_norm) prob = torch.abs(amplitude) ** 2 return prob def _get_prob_gaussian(self, final_state: Any, state: Any = None) -> torch.Tensor: """Get the batched probabilities of the final state for Gaussian backend.""" if not isinstance(final_state, torch.Tensor): final_state = torch.tensor(final_state, dtype=torch.long) if state is None: cov = self._cov mean = self._mean else: if not isinstance(state, GaussianState): state = GaussianState(state=state, cutoff=self.cutoff) cov = state.cov mean = state.mean if cov.ndim == 2: cov = cov.unsqueeze(0) if mean.ndim == 2: mean = mean.unsqueeze(0) assert cov.ndim == mean.ndim == 3 batch = cov.shape[0] probs = [] for i in range(batch): prob = self._get_probs_gaussian_helper(final_state, cov=cov[i], mean=mean[i], detector=self.detector)[0] probs.append(prob) return torch.stack(probs).squeeze() def _get_probs_gaussian_helper( self, final_states: torch.Tensor, cov: torch.Tensor, mean: torch.Tensor, detector: str = 'pnrd', purity: bool | None = None, loop: bool | None = None, ) -> torch.Tensor: """Get the probabilities of the final states for Gaussian backend.""" if loop is None: loop = ~torch.all(mean == 0) if final_states.ndim == 1: final_states = final_states.unsqueeze(0) assert final_states.ndim == 2 nmode = final_states.shape[-1] final_states = final_states.to(cov.device) identity = cov.new_ones(2 * nmode).diag_embed() cov_ladder = quadrature_to_ladder(cov) mean_ladder = quadrature_to_ladder(mean) q = cov_ladder + identity / 2 det_q = q.det() x_mat = identity.reshape(2, nmode, 2 * nmode).flip(0).reshape(2 * nmode, 2 * nmode) + 0j o_mat = identity - q.inverse() a_mat = x_mat @ o_mat gamma = mean_ladder.mH @ q.inverse() if detector == 'pnrd': matrix = a_mat elif detector == 'threshold': matrix = o_mat if purity is None: purity = GaussianState([cov, mean]).is_pure p_vac = torch.exp(-0.5 * mean_ladder.mH @ q.inverse() @ mean_ladder) / det_q.sqrt() batch_get_prob = vmap(self._get_prob_gaussian_base, in_dims=(0, None, None, None, None, None, None)) probs = batch_get_prob(final_states, matrix, gamma, p_vac, detector, purity, loop) return probs def _get_prob_gaussian_base( self, final_state: torch.Tensor, matrix: torch.Tensor, gamma: torch.Tensor, p_vac: torch.Tensor, detector: str = 'pnrd', purity: bool = True, loop: bool = False, ) -> torch.Tensor: """Get the probability of the final state for Gaussian backend.""" gamma = gamma.squeeze() nmode = len(final_state) with warnings.catch_warnings(): warnings.filterwarnings('ignore') # local warning gamma_n1 = torch.repeat_interleave(gamma[:nmode], final_state) gamma_n2 = torch.repeat_interleave(gamma[nmode:], final_state) sub_gamma = torch.cat([gamma_n1, gamma_n2]) if detector == 'pnrd': if purity: sub_mat = sub_matrix(matrix[:nmode, :nmode], final_state, final_state) half_len = len(sub_gamma) // 2 sub_gamma = sub_gamma[:half_len] else: final_state_double = torch.cat([final_state, final_state]) sub_mat = sub_matrix(matrix, final_state_double, final_state_double) if len(sub_gamma) == 1: sub_mat = sub_gamma else: sub_mat[torch.arange(len(sub_gamma)), torch.arange(len(sub_gamma))] = sub_gamma haf = abs(hafnian(sub_mat, loop=loop)) ** 2 if purity else hafnian(sub_mat, loop=loop) prob = p_vac * haf / product_factorial(final_state).to(haf.device, haf.dtype) elif detector == 'threshold': final_state_double = torch.cat([final_state, final_state]) sub_mat = sub_matrix(matrix, final_state_double, final_state_double) prob = p_vac * torontonian(sub_mat, sub_gamma) return abs(prob.real).squeeze() def _get_prob_mps(self, final_state: Any, wires: int | list[int] | None = None) -> torch.Tensor: """Get the probability of the given bit string for MPS. Args: final_state: The final Fock basis state. wires: The wires to measure. It can be an integer or a list of integers specifying the indices of the wires. """ if isinstance(final_state, FockState): final_state = final_state.state.tolist() wires = list(range(self.nmode)) if wires is None else self._convert_indices(wires) assert len(final_state) == len(wires) state = copy(self.state) if self.state[0].ndim == 3: state = [site.unsqueeze(0) for site in state] for i, wire in enumerate(wires): state[wire] = state[wire][..., [final_state[i]], :] return inner_product_mps(state, state).real
[docs] def measure( self, shots: int = 1024, with_prob: bool = False, wires: int | list[int] | None = None, detector: str | None = None, mcmc: bool = False, ) -> dict | list[dict] | None: """Measure the final state. Args: shots: The number of times to sample from the quantum state. Default: 1024 with_prob: A flag that indicates whether to return the probabilities along with the number of occurrences. Default: ``False`` wires: The wires to measure. It can be an integer or a list of integers specifying the indices of the wires. Default: ``None`` (which means all wires are measured) detector: For Gaussian backend, use ``'pnrd'`` for the photon-number-resolving detector or ``'threshold'`` for the threshold detector. Default: ``None`` mcmc: Whether to use MCMC sampling method. Default: ``False`` See https://arxiv.org/pdf/2108.01622 for MCMC. """ assert self.backend in ('fock', 'gaussian'), 'Currently Fock measurement is not supported in Bosonic backend' if self.state is None: return if wires is None: wires = self.wires wires = sorted(self._convert_indices(wires)) if self.backend == 'fock': results = self._measure_fock(shots, with_prob, wires, mcmc) elif self.backend == 'gaussian': detector = self.detector if detector is None else detector.lower() results = self._measure_gaussian(shots, with_prob, wires, detector, mcmc) if len(results) == 1: results = results[0] return results
def _prob_dict_to_measure_result(self, prob_dict: dict, shots: int, with_prob: bool) -> dict: """Get the measurement result from the dictionary of probabilities.""" keys = list(prob_dict.keys()) probs = torch.cat(list(prob_dict.values())) samples = Counter(block_sample(probs, shots)) results = {keys[u]: v for u, v in samples.items()} if with_prob: results = {key: (value, prob_dict[key]) for key, value in results.items()} return results def _measure_fock(self, shots: int, with_prob: bool, wires: list[int], mcmc: bool) -> list[dict]: """Measure the final state for Fock backend.""" if isinstance(self.state, torch.Tensor): if self.basis: return self._measure_fock_unitary(shots, with_prob, wires, mcmc) else: assert not mcmc, "Final states have been calculated, we don't need mcmc!" return self._measure_fock_tensor(shots, with_prob, wires) elif isinstance(self.state, dict): assert not mcmc, "Final states have been calculated, we don't need mcmc!" return self._measure_dict(shots, with_prob, wires) elif isinstance(self.state, list): assert not mcmc, "Final states have been calculated, we don't need mcmc!" return self._measure_mps(shots, with_prob, wires) else: raise ValueError('Check your forward function or input!') def _measure_fock_unitary(self, shots: int, with_prob: bool, wires: list[int], mcmc: bool) -> list[dict]: """Measure the final state according to the unitary matrix for Fock backend.""" if self.state.ndim == 2: self.state = self.state.unsqueeze(0) batch = self.state.shape[0] init_state = self._init_state_forward if init_state.ndim == 1: init_state = init_state.unsqueeze(0) batch_init = init_state.shape[0] unitary = self.state if self._is_batch_expanded: identity = torch.eye(1, dtype=self.state.dtype, device=self.state.device) unitary = vmap(torch.block_diag, in_dims=(0, None))(self.state, identity) all_results = [] if mcmc: for i in range(batch): if batch_init == 1: samples_i = self._sample_mcmc_fock( shots=shots, init_state=init_state[0], unitary=unitary[i], num_chain=5 ) else: samples_i = self._sample_mcmc_fock( shots=shots, init_state=init_state[i], unitary=unitary[i], num_chain=5 ) results = defaultdict(list) if with_prob: for k in samples_i: prob = self._get_prob_fock(k) samples_i[k] = samples_i[k], prob for key in samples_i: state_b = [key[wire] for wire in wires] state_b = FockState(state=state_b) results[state_b].append(samples_i[key]) if with_prob: results = { key: (sum(count for count, _ in value), sum(prob for _, prob in value)) for key, value in results.items() } else: results = {key: sum(value) for key, value in results.items()} all_results.append(results) else: if batch_init == 1: prob_dict_batch = vmap(self._measure_fock_unitary_helper, in_dims=(None, 0, None))( init_state[0], unitary, wires ) else: prob_dict_batch = vmap(self._measure_fock_unitary_helper, in_dims=(0, 0, None))( init_state, unitary, wires ) for i in range(batch): prob_dict = {key: value[i] for key, value in prob_dict_batch.items()} results = self._prob_dict_to_measure_result(prob_dict, shots, with_prob) all_results.append(results) return all_results def _measure_fock_unitary_helper( self, init_state: torch.Tensor, unitary: torch.Tensor, wires: int | list[int] | None = None, final_states: torch.Tensor | None = None, ) -> dict: """VMAP helper for measuring the final state according to the unitary matrix for Fock backend. Returns: A dictionary of probabilities for final states. """ if final_states is None: final_states = self._out_fock_basis sub_mats = vmap(sub_matrix, in_dims=(None, None, 0))(unitary, init_state, final_states) per_norms = self._get_permanent_norms(init_state, final_states).to(unitary.dtype) rst = vmap(self._get_prob_fock_vmap)(sub_mats, per_norms) state_dict = {} prob_dict = defaultdict(list) for i in range(len(final_states)): final_state = FockState(state=final_states[i]) state_dict[final_state] = rst[i] for key in state_dict: state_b = key.state[wires] state_b = FockState(state=state_b) prob_dict[state_b].append(state_dict[key]) prob_dict = {key: sum(value) for key, value in prob_dict.items()} return prob_dict def _measure_dict(self, shots: int, with_prob: bool, wires: list[int]) -> list[dict]: """Measure the final state according to the dictionary of amplitudes or probabilities.""" if self._with_delay: wires = [self._unroll_dict[wire][-1] for wire in wires] all_results = [] batch = len(self.state[list(self.state.keys())[0]]) is_complex = any(v.dtype.is_complex for v in self.state.values()) is_prob = not (self.backend == 'fock' and is_complex) for i in range(batch): prob_dict = defaultdict(list) for key in self.state: if wires == self.wires: state_b = key else: state_b = key.state[wires] state_b = FockState(state=state_b) if is_prob: prob_dict[state_b].append(self.state[key][i].reshape(-1)) else: prob_dict[state_b].append(abs(self.state[key][i].reshape(-1)) ** 2) prob_dict = {key: sum(value) for key, value in prob_dict.items()} results = self._prob_dict_to_measure_result(prob_dict, shots, with_prob) all_results.append(results) return all_results def _measure_fock_tensor(self, shots: int, with_prob: bool, wires: list[int]) -> list[dict]: """Measure the final state according to Fock state tensor for Fock backend.""" all_results = [] if self.state.is_complex(): if self.den_mat: state_tensor = self.state.reshape(-1, self.cutoff**self.nmode, self.cutoff**self.nmode) state_tensor = abs(state_tensor.diagonal(dim1=-2, dim2=-1)).reshape([-1] + [self.cutoff] * self.nmode) else: state_tensor = self.tensor_rep(abs(self.state) ** 2) else: state_tensor = self.state.reshape([-1] + [self.cutoff] * self.nmode) batch = state_tensor.shape[0] combi = list(itertools.product(range(self.cutoff), repeat=len(wires))) for i in range(batch): prob_dict = {} probs = state_tensor[i] if wires == self.wires: ptrace_probs = probs else: sum_idx = list(range(self.nmode)) for idx in wires: sum_idx.remove(idx) ptrace_probs = probs.sum(dim=sum_idx) for p_state in combi: p_state_b = FockState(list(p_state)) prob_dict[p_state_b] = ptrace_probs[p_state].reshape(-1) results = self._prob_dict_to_measure_result(prob_dict, shots, with_prob) all_results.append(results) return all_results def _measure_mps(self, shots: int, with_prob: bool, wires: list[int]) -> list[dict]: """Measure the final state according to MPS.""" all_results = [] samples = [] for _ in range(shots): samples.append(self._generate_chain_sample(wires, self.detector)) for j in range(samples[0].shape[0]): samples_j = [tuple(sample[j].tolist()) for sample in samples] samples_j = dict(Counter(samples_j)) keys = list(map(FockState, samples_j.keys())) results = dict(zip(keys, samples_j.values(), strict=True)) if with_prob: for k in results: prob = self._get_prob_mps(k, wires)[j] results[k] = results[k], prob all_results.append(results) return all_results def _sample_mcmc_fock(self, shots: int, init_state: torch.Tensor, unitary: torch.Tensor, num_chain: int): """Sample the output states for Fock backend via SC-MCMC method.""" self._init_state_sample = init_state self._unitary = unitary self._out_fock_basis = self._get_all_fock_basis(init_state) merged_samples = sample_sc_mcmc( prob_func=self._get_prob_fock, proposal_sampler=self._proposal_sampler, shots=shots, num_chain=num_chain ) return merged_samples def _measure_gaussian(self, shots: int, with_prob: bool, wires: list[int], detector: str, mcmc: bool) -> list[dict]: """Measure the final state for Gaussian backend.""" if isinstance(self.state, list): return self._measure_gaussian_state(shots, with_prob, wires, detector, mcmc) elif isinstance(self.state, dict): assert not mcmc, "Final states have been calculated, we don't need mcmc!" print('Automatically using the default detector!') return self._measure_dict(shots, with_prob, wires) else: raise ValueError('Check your forward function or input!') def _measure_gaussian_state( self, shots: int, with_prob: bool, wires: list[int], detector: str, mcmc: bool ) -> list[dict]: """Measure the final state according to Gaussian state for Gaussian backend. See https://arxiv.org/pdf/2108.01622 """ assert not self._with_delay, 'Currently Fock measurement is not supported with delay loops' cov, mean = self.state batch = cov.shape[0] all_results = [] all_samples = [] if mcmc: print('Using MCMC method to sample the final states!') for i in range(batch): samples_i = self._sample_mcmc_gaussian( shots=shots, cov=cov[i], mean=mean[i], detector=detector, num_chain=5 ) all_samples.append(samples_i) else: # chain-rule method with small number of shots print('Using chain-rule method to sample the final states!') samples = [] for _ in range(shots): sample = self._generate_chain_sample(wires, detector) samples.append(sample) samples = torch.stack(samples).permute(1, 0, 2) # (batch, shots, wires) for i in range(batch): sample_lst = samples[i].tolist() sample_tup = [tuple(s) for s in sample_lst] samples_i = defaultdict(int, Counter(sample_tup)) all_samples.append(samples_i) for i, samples_i in enumerate(all_samples): # post-process samples results = defaultdict(list) if with_prob: for k in samples_i: if mcmc: prob = self._get_prob_gaussian(k, [cov[i], mean[i]]) else: wires_ = torch.tensor(wires, device=cov.device) idx = torch.cat([wires_, wires_ + self.nmode]) prob = self._get_prob_gaussian(k, [cov[i][idx[:, None], idx], mean[i][idx, :]]) samples_i[k] = samples_i[k], prob for key in samples_i: state_b = [key[wire] for wire in wires] if mcmc else list(key) state_b = FockState(state=state_b) results[state_b].append(samples_i[key]) if with_prob: results = { key: (sum(count for count, _ in value), sum(prob for _, prob in value)) for key, value in results.items() } else: results = {key: sum(value) for key, value in results.items()} all_results.append(results) return all_results def _sample_mcmc_gaussian(self, shots: int, cov: torch.Tensor, mean: torch.Tensor, detector: str, num_chain: int): """Sample the output states for Gaussian backend via SC-MCMC method.""" self._cov = cov self._mean = mean self.detector = detector if detector == 'threshold' and not torch.allclose(mean, torch.zeros_like(mean)): # For the displaced state, aggregate PNRD detector samples to derive threshold detector results self.detector = 'pnrd' merged_samples_pnrd = sample_sc_mcmc( prob_func=self._get_prob_gaussian, proposal_sampler=self._proposal_sampler, shots=shots, num_chain=num_chain, ) merged_samples = defaultdict(int) for key in list(merged_samples_pnrd.keys()): key_threshold = (torch.tensor(key) != 0).int() key_threshold = tuple(key_threshold.tolist()) merged_samples[key_threshold] += merged_samples_pnrd[key] self.detector = 'threshold' else: merged_samples = sample_sc_mcmc( prob_func=self._get_prob_gaussian, proposal_sampler=self._proposal_sampler, shots=shots, num_chain=num_chain, ) return merged_samples def _proposal_sampler(self): """The proposal sampler for MCMC sampling.""" if self.backend == 'fock': assert self.basis, 'Currently NOT supported.' sample = self._out_fock_basis[torch.randint(0, len(self._out_fock_basis), (1,))[0]] elif self.backend == 'gaussian': sample = self._generate_rand_sample(self.detector) return tuple(sample.tolist()) def _generate_rand_sample(self, detector: str = 'pnrd'): """Generate a random sample according to uniform proposal distribution.""" nmode = self._nmode_tdm if self._with_delay else self.nmode if detector == 'threshold': sample = torch.randint(0, 2, [nmode]) elif detector == 'pnrd': sample = torch.randint(0, self.cutoff, [nmode]) return sample def _generate_chain_sample(self, wires: list[int], detector: str) -> torch.Tensor: """Generate batched random samples via chain rule. Args: wires: The wires to measure. It can be a list of integers specifying the indices of the wires. detector: For Gaussian backend, use ``'pnrd'`` for the photon-number-resolving detector or ``'threshold'`` for the threshold detector. Returns: Tensor of shape (batch, nwire). """ sample = [] if self.backend == 'fock': assert self.mps mps = copy(self.state) if mps[0].ndim == 3: mps = [site.unsqueeze(0) for site in mps] for i in wires: p = vmap(get_prob_mps)(mps, wire=i) sample_single_wire = torch.multinomial(p, num_samples=1) sample.append(sample_single_wire) index = sample_single_wire.reshape(-1, 1, 1, 1).expand(-1, mps[i].shape[-3], -1, mps[i].shape[-1]) mps[i] = torch.gather(mps[i], dim=2, index=index) sample = torch.stack(sample, dim=-1).squeeze(1) elif self.backend == 'gaussian': # chain rule for GBS sample = self._generate_chain_sample_gaussian(wires, detector) return sample def _generate_chain_sample_gaussian(self, wires: list[int], detector: str) -> torch.Tensor: """Generate batched random samples via chain rule for Gaussian backend. See https://research-information.bris.ac.uk/en/studentTheses/classical-simulations-of-gaussian-boson-sampling Chapter 5 """ def _sample_wire(sample, cov_sub, mean_sub, cutoff, detector): """Sample for a wire.""" states = [torch.tensor(sample + [i], device=cov_sub.device) for i in range(cutoff)] probs = [self._get_probs_gaussian_helper(s, cov_sub, mean_sub, detector) for s in states] sample_wire = torch.multinomial(torch.cat(probs), num_samples=1) return sample_wire def _sample_pure(cov, mean, wires, nmode, cutoff, detector): """Sample for a pure state.""" wires = torch.tensor(wires, device=cov.device) sample = [] for i in range(1, len(wires) + 1): idx = torch.cat([wires[:i], wires[:i] + nmode]) cov_sub = cov[idx[:, None], idx] mean_sub = mean[idx, :] sample_wire = _sample_wire(sample, cov_sub, mean_sub, cutoff, detector) sample.append(sample_wire) return torch.cat(sample) def _sample_mixed(cov, mean, wires, nmode, cutoff, detector, eps=5e-5): """Sample for a mixed state.""" wires = torch.tensor(wires, device=cov.device) _, s = williamson(cov) cov_t = s @ s.mT * dqp.hbar / (4 * dqp.kappa**2) cov_w = cov - cov_t # cov_mix = cov_t + cov_w cov_w += cov.new_ones(cov_w.shape[-1]).diag_embed() * eps mean0 = MultivariateNormal(mean.squeeze(-1), cov_w).sample([1])[0] # may be numerically unstable sample = [] mean_m = None for i in range(1, len(wires) + 1): wires_i = wires[i:].tolist() cov_m = cov.new_ones(2 * len(wires_i)).diag_embed() * dqp.hbar / (4 * dqp.kappa**2) # See Eq.(5.18) heterodyne = Generaldyne(cov_m=cov_m, nmode=nmode, wires=wires_i) # collapse the state state = [cov_t.unsqueeze(0), mean0.reshape(1, -1, 1)] if i < len(wires): cov_out, mean_out = heterodyne(state, mean_m) mean_m = heterodyne.samples[0] # with batch mask = torch.ones_like(mean_m, dtype=bool) idx_discard = torch.tensor([0, len(mean_m) // 2], device=mask.device) mask[idx_discard] = False mean_m = mean_m[mask] # discard the first mode else: cov_out, mean_out = state idx = torch.cat([wires[:i], wires[:i] + nmode]) cov_sub = cov_out[0, idx[:, None], idx] mean_sub = mean_out[0, idx, :] sample_wire = _sample_wire(sample, cov_sub, mean_sub, cutoff, detector) sample.append(sample_wire) return torch.cat(sample) sample = [] purity = GaussianState(self.state).is_pure cov, mean = self.state batch = cov.shape[0] cutoff = 2 if detector == 'threshold' else self.cutoff if purity: for i in range(batch): sample.append(_sample_pure(cov[i], mean[i], wires, self.nmode, cutoff, detector)) else: for i in range(batch): sample.append(_sample_mixed(cov[i], mean[i], wires, self.nmode, cutoff, detector)) sample = torch.stack(sample) return sample
[docs] def photon_number_mean_var(self, wires: int | list[int] | None = None) -> tuple[torch.Tensor, torch.Tensor] | None: """Get the expectation value and variance of the photon number operator. Args: wires: The wires to measure. It can be an integer or a list of integers specifying the indices of the wires. Default: ``None`` (which means all wires are measured) """ if self.state is None: return if wires is None: wires = self.wires wires = sorted(self._convert_indices(wires)) if self._with_delay: wires = [self._unroll_dict[wire][-1] for wire in wires] if self.backend == 'fock': assert not self.basis exp, var = photon_number_mean_var_fock(self.state, self.nmode, self.cutoff, wires, self.den_mat) elif self.backend in ('gaussian', 'bosonic'): if self.backend == 'gaussian': cov, mean = self.state elif self.backend == 'bosonic': cov, mean, weight = self.state shape_cov = cov.shape shape_mean = mean.shape batch = shape_cov[0] nwire = len(wires) cov = cov.reshape(-1, *shape_cov[-2:]) mean = mean.reshape(-1, *shape_mean[-2:]) covs, means = self._get_local_covs_means(cov, mean, wires) if self.backend == 'gaussian': weights = None elif self.backend == 'bosonic': covs = covs.reshape(*shape_cov[:2], nwire, 2, 2).transpose(1, 2) covs = covs.reshape(-1, shape_cov[-3], 2, 2) # (batch*nwire, ncomb, 2, 2) means = means.reshape(*shape_mean[:2], nwire, 2, 1).transpose(1, 2) means = means.reshape(-1, shape_mean[-3], 2, 1) if weight.shape[0] == 1: weights = weight else: weights = torch.stack([weight] * nwire, dim=-2).reshape(batch * nwire, weight.shape[-1]) ncomb = weights.shape[-1] if covs.shape[1] == 1: covs = covs.expand(-1, ncomb, -1, -1) if means.shape[1] == 1: means = means.expand(-1, ncomb, -1, -1) exp, var = photon_number_mean_var_cv(covs, means, weights) exp = exp.reshape(batch, nwire).squeeze() var = var.reshape(batch, nwire).squeeze() return exp, var
[docs] def quadrature_mean(self, wires: int | list[int] | None = None, phi: Any | None = None) -> torch.Tensor: r"""Get the expectation value of the quadratuere operator :math:`\hat{X}\cos\phi + \hat{P}\sin\phi`. If ``self.measurements`` is empty, this method directly computes the quadrature expectation values for the specified ``wires`` and ``phi``. If ``self.measurements`` is specified via ``self.homodyne``, the wires and corresponding homodyne angles are inferred from the stored measurement instructions, and the input arguments ``wires`` and ``phi`` are ignored. Args: wires: The wires to measure. It can be an integer or a list of integers specifying the indices of the wires. Default: ``None`` (which means all wires are measured). phi: The phi angles for quadrature operator :math:`\hat{X}\cos\phi + \hat{P}\sin\phi`. Default: ``None`` """ if len(self.measurements) == 0: if wires is None: wires = self.wires if phi is None: phi = torch.zeros(len(wires)) elif not isinstance(phi, torch.Tensor): phi = torch.tensor(phi) if phi.numel() == 1: phi = phi.reshape(-1).expand(len(wires)) else: wires = [] phi = [] for mea in self.measurements: wires = wires + mea.wires phi.append(mea.phi) phi = torch.cat(phi) wires = self._convert_indices(wires) assert len(wires) == len(phi), f'phi length {len(phi)} must match wires length {len(wires)}' if self.backend == 'fock': assert not self.basis state = self.state if self.den_mat: state = state.reshape([-1] + [self.cutoff] * 2 * self.nmode) for i in range(len(wires)): if not torch.isclose(phi[i], phi.new_tensor(0.0)): r = PhaseShift( inputs=-phi[i], nmode=self.nmode, wires=wires[i], cutoff=self.cutoff, den_mat=self.den_mat ) state = r(state) mean = quadrature_mean_fock(state, self.nmode, self.cutoff, wires, self.den_mat) return mean elif self.backend in ('gaussian', 'bosonic'): wires = torch.tensor(wires) idx = torch.cat([wires, wires + self.nmode]) # xxpp order means = self.state[1][..., idx, :] phi = phi.reshape(-1, 1) mean = means[..., : len(wires), :] * torch.cos(phi) + means[..., len(wires) :, :] * torch.sin(phi) if self.backend == 'bosonic': weight = self.state[2].unsqueeze(-1).unsqueeze(-2) mean = torch.sum(weight * mean, dim=1) return mean
def _get_local_covs_means( self, cov: torch.Tensor, mean: torch.Tensor, wires: list[int] ) -> tuple[torch.Tensor, torch.Tensor]: """Get the local covariance matrices and mean vectors of a Gaussian state according to the wires to measure.""" def extract_blocks(mat: torch.Tensor, idx: torch.Tensor) -> torch.Tensor: """Extract specified blocks from the input tensor. Args: mat: Input tensor. idx: Index tensor of shape (nblock, block_size), where each row contains the row/column indices for a block. Returns: Output tensor of shape (batch, nblock, block_size, -1) containing all extracted blocks. """ nblock, block_size = idx.shape if mat.shape[-2] == mat.shape[-1]: # cov rows = idx[:, :, None].expand(-1, -1, block_size) # (nblock, block_size, block_size) cols = idx[:, None, :].expand(-1, block_size, -1) # (nblock, block_size, block_size) all_rows = rows.reshape(-1) all_cols = cols.reshape(-1) out = mat[:, all_rows, all_cols] elif mat.shape[-1] == 1: # mean out = mat[:, idx, :] return out.reshape(mat.shape[0], nblock, block_size, -1) indices = [] nmode = self._nmode_tdm if self._with_delay else self.nmode for wire in wires: indices.append([wire] + [wire + nmode]) indices = torch.tensor(indices, device=cov.device) covs = extract_blocks(cov, indices).reshape(-1, 2, 2) # batch * nwire means = extract_blocks(mean, indices).reshape(-1, 2, 1) return covs, means
[docs] def measure_homodyne(self, shots: int = 10, wires: int | list[int] | None = None) -> torch.Tensor | None: """Get the homodyne measurement results. If ``self.measurements`` is specified via ``self.homodyne``, return the results of the conditional homodyne measurement. Otherwise, return the results of the ideal homodyne measurement. The Gaussian states after measurements are stored in ``self.state_measured``. Note: ``batch`` * ``shots`` can not be too large for Fock backend. Args: shots: The number of times to sample from the quantum state. Default: 10 wires: The wires to measure for the ideal homodyne. It can be an integer or a list of integers specifying the indices of the wires. Default: ``None`` (which means all wires are measured) """ if self.state is None: return assert isinstance(self.state, (list, torch.Tensor)), 'NOT valid when "is_prob" is True' if len(self.measurements) > 0: measurements = self._measurements_tdm if self._with_delay else self.measurements samples = [] if self.backend == 'fock': assert not self.basis assert not self.mps, 'Currently NOT supported.' shape = self.state.shape batch = shape[0] self.state_measured = torch.stack([self.state] * shots).reshape(-1, *shape[1:]) else: batch = self.state[0].shape[0] self.state_measured = [] state = align_shape(*self.state) if self.backend == 'bosonic' else self.state for s in state: # [cov, mean, weight] shape = s.shape self.state_measured.append(torch.stack([s] * shots).reshape(-1, *shape[1:])) for op_m in measurements: self.state_measured = op_m(self.state_measured) nwire = len(op_m.wires) samples.append(op_m.samples[:, :nwire].reshape(shots, batch, nwire).permute(1, 0, 2)) return torch.cat(samples, dim=-1).squeeze() # (batch, shots, nwire) else: if wires is None: wires = self.wires wires = torch.tensor(sorted(self._convert_indices(wires))) if self.backend == 'fock': assert not self.basis assert len(wires) == 1 # (batch, shots, 1) samples = sample_homodyne_fock(self.state, wires[0], self.nmode, self.cutoff, shots, self.den_mat) else: cov, mean = self.state[:2] if not is_positive_definite(cov): size = cov.size() if cov.dtype == torch.double: epsilon = 1e-16 elif cov.dtype == torch.float: epsilon = 1e-8 else: raise ValueError('Unsupported dtype.') cov += epsilon * cov.new_ones(size[:-1].numel()).reshape(size[:-1]).diag_embed() idx = torch.cat([wires, wires + self.nmode]) cov_sub = cov[..., idx[:, None], idx] mean_sub = mean[..., idx, :] if len(self.state) == 2: # (shots, batch, 2 * nwire) samples = MultivariateNormal(mean_sub.squeeze(-1), cov_sub).sample([shots]) samples = samples.permute(1, 0, 2) elif len(self.state) == 3: cov_sub, mean_sub, weight = align_shape(cov_sub, mean_sub, self.state[2]) samples = sample_reject_bosonic(cov_sub, mean_sub, weight, cov_sub.new_zeros(1), shots) return samples.squeeze()
@property def max_depth(self) -> int: """Get the max number of gates on the wires.""" return max(self.depth)
[docs] def draw(self, filename: str | None = None, unroll: bool = False): """Visualize the photonic quantum circuit. Args: filename: The path for saving the figure. unroll: Whether to draw the unrolled circuit. """ if self._with_delay and unroll: self._prepare_unroll_dict() self._unroll_circuit() nmode = self._nmode_tdm operators = self._operators_tdm measurements = self._measurements_tdm else: nmode = self.nmode operators = self.operators measurements = self.measurements self.draw_circuit = DrawCircuit(self.name, nmode, operators, measurements) self.draw_circuit.draw() if filename is not None: self.draw_circuit.save(filename) else: if self.nmode > 50: print('Too many modes in the circuit, please set filename to save the figure.') return self.draw_circuit.draw_
[docs] def cat(self, wires: int, r: Any = None, theta: Any = None, p: int = 1) -> None: """Prepare a cat state. ``r`` and ``theta`` are the displacement magnitude and angle respectively. ``p`` is the parity, corresponding to an even or odd cat state when ``p=0`` or ``p=1`` respectively. """ if self._bosonic_states is None: self._bosonic_states = nn.ModuleList([BosonicState(state='vac', nmode=1, cutoff=self.cutoff)] * self.nmode) cat = CatState(r=r, theta=theta, p=p, cutoff=self.cutoff) self._bosonic_states[wires] = cat
[docs] def gkp( self, wires: int, theta: Any = None, phi: Any = None, amp_cutoff: float = 0.1, epsilon: float = 0.05 ) -> None: """Prepare a GKP state. ``theta`` and ``phi`` are angles in Bloch sphere. ``amp_cutoff`` is the amplitude threshold for keeping the terms. ``epsilon`` is the finite energy damping parameter. """ if self._bosonic_states is None: self._bosonic_states = nn.ModuleList([BosonicState(state='vac', nmode=1, cutoff=self.cutoff)] * self.nmode) gkp = GKPState(theta=theta, phi=phi, amp_cutoff=amp_cutoff, epsilon=epsilon, cutoff=self.cutoff) self._bosonic_states[wires] = gkp
[docs] def add(self, op: Operation, encode: bool = False, wires: int | list[int] | None = None) -> None: """A method that adds an operation to the photonic quantum circuit. The operation can be a gate or another photonic quantum circuit. The method also updates the attributes of the photonic quantum circuit. If ``wires`` is specified, the parameters of gates are shared. Args: op: The operation to add. It is an instance of ``Operation`` class or its subclasses, such as ``Gate``, or ``QumodeCircuit``. encode: Whether the gate is to encode data. Default: ``False`` wires: The wires to apply the gate on. It can be an integer or a list of integers specifying the indices of the wires. Default: ``None`` (which means the gate has its own wires) Raises: AssertionError: If the input arguments are invalid or incompatible with the quantum circuit. """ assert isinstance(op, Operation) if wires is not None: assert isinstance(op, Gate) wires = self._convert_indices(wires) assert len(wires) == len(op.wires), 'Invalid input' op = copy(op) op.wires = wires if isinstance(op, QumodeCircuit): assert self.nmode == op.nmode self.operators += op.operators self.encoders += op.encoders self.measurements = op.measurements self.wires_homodyne = op.wires_homodyne self.npara += op.npara self.ndata += op.ndata self.depth += op.depth self._lossy = self._lossy or op._lossy self._nloss += op._nloss self._with_delay = self._with_delay or op._with_delay self._nmode_tdm += op._nmode_tdm - self.nmode for key, value in op._ntau_dict.items(): self._ntau_dict[key].extend(value) self._unroll_dict = None self._operators_tdm = None self._measurements_tdm = None elif isinstance(op, (Gate, Channel, Delay)): self.operators.append(op) for i in op.wires: self.depth[i] += 1 if encode: assert not op.requires_grad, 'Please set requires_grad of the operation to be False' self.encoders.append(op) self.ndata += op.npara else: self.npara += op.npara if isinstance(op, Delay): self._with_delay = True self._nmode_tdm += op.ntau self._ntau_dict[op.wires[0]].append(op.ntau) elif isinstance(op, Homodyne): self.measurements.append(op) self.wires_homodyne.append(op.wires[0])
[docs] def ps( self, wires: int, inputs: Any = None, encode: bool = False, mu: float | None = None, sigma: float | None = None ) -> None: """Add a phase shifter.""" requires_grad = not encode if inputs is not None: requires_grad = False if mu is None: mu = self.mu if sigma is None: sigma = self.sigma ps = PhaseShift( inputs=inputs, nmode=self.nmode, wires=wires, cutoff=self.cutoff, den_mat=self.den_mat, requires_grad=requires_grad, noise=self.noise, mu=mu, sigma=sigma, ) self.add(ps, encode=encode)
[docs] def bs( self, wires: list[int], inputs: Any = None, encode: bool = False, mu: float | None = None, sigma: float | None = None, ) -> None: """Add a beam splitter.""" requires_grad = not encode if inputs is not None: requires_grad = False if mu is None: mu = self.mu if sigma is None: sigma = self.sigma bs = BeamSplitter( inputs=inputs, nmode=self.nmode, wires=wires, cutoff=self.cutoff, den_mat=self.den_mat, requires_grad=requires_grad, noise=self.noise, mu=mu, sigma=sigma, ) self.add(bs, encode=encode)
[docs] def mzi( self, wires: list[int], inputs: Any = None, phi_first: bool = True, encode: bool = False, mu: float | None = None, sigma: float | None = None, ) -> None: """Add a Mach-Zehnder interferometer.""" requires_grad = not encode if inputs is not None: requires_grad = False if mu is None: mu = self.mu if sigma is None: sigma = self.sigma mzi = MZI( inputs=inputs, nmode=self.nmode, wires=wires, cutoff=self.cutoff, den_mat=self.den_mat, phi_first=phi_first, requires_grad=requires_grad, noise=self.noise, mu=mu, sigma=sigma, ) self.add(mzi, encode=encode)
[docs] def bs_theta( self, wires: list[int], inputs: Any = None, encode: bool = False, mu: float | None = None, sigma: float | None = None, ) -> None: r"""Add a beam splitter with fixed :math:`\phi` at :math:`\pi/2`.""" requires_grad = not encode if inputs is not None: requires_grad = False if mu is None: mu = self.mu if sigma is None: sigma = self.sigma bs = BeamSplitterTheta( inputs=inputs, nmode=self.nmode, wires=wires, cutoff=self.cutoff, den_mat=self.den_mat, requires_grad=requires_grad, noise=self.noise, mu=mu, sigma=sigma, ) self.add(bs, encode=encode)
[docs] def bs_phi( self, wires: list[int], inputs: Any = None, encode: bool = False, mu: float | None = None, sigma: float | None = None, ) -> None: r"""Add a beam splitter with fixed :math:`\theta` at :math:`\pi/4`.""" requires_grad = not encode if inputs is not None: requires_grad = False if mu is None: mu = self.mu if sigma is None: sigma = self.sigma bs = BeamSplitterPhi( inputs=inputs, nmode=self.nmode, wires=wires, cutoff=self.cutoff, den_mat=self.den_mat, requires_grad=requires_grad, noise=self.noise, mu=mu, sigma=sigma, ) self.add(bs, encode=encode)
[docs] def bs_rx( self, wires: list[int], inputs: Any = None, encode: bool = False, mu: float | None = None, sigma: float | None = None, ) -> None: """Add an Rx-type beam splitter.""" requires_grad = not encode if inputs is not None: requires_grad = False if mu is None: mu = self.mu if sigma is None: sigma = self.sigma bs = BeamSplitterSingle( inputs=inputs, nmode=self.nmode, wires=wires, cutoff=self.cutoff, den_mat=self.den_mat, convention='rx', requires_grad=requires_grad, noise=self.noise, mu=mu, sigma=sigma, ) self.add(bs, encode=encode)
[docs] def bs_ry( self, wires: list[int], inputs: Any = None, encode: bool = False, mu: float | None = None, sigma: float | None = None, ) -> None: """Add an Ry-type beam splitter.""" requires_grad = not encode if inputs is not None: requires_grad = False if mu is None: mu = self.mu if sigma is None: sigma = self.sigma bs = BeamSplitterSingle( inputs=inputs, nmode=self.nmode, wires=wires, cutoff=self.cutoff, den_mat=self.den_mat, convention='ry', requires_grad=requires_grad, noise=self.noise, mu=mu, sigma=sigma, ) self.add(bs, encode=encode)
[docs] def bs_h( self, wires: list[int], inputs: Any = None, encode: bool = False, mu: float | None = None, sigma: float | None = None, ) -> None: """Add an H-type beam splitter.""" requires_grad = not encode if inputs is not None: requires_grad = False if mu is None: mu = self.mu if sigma is None: sigma = self.sigma bs = BeamSplitterSingle( inputs=inputs, nmode=self.nmode, wires=wires, cutoff=self.cutoff, den_mat=self.den_mat, convention='h', requires_grad=requires_grad, noise=self.noise, mu=mu, sigma=sigma, ) self.add(bs, encode=encode)
[docs] def dc(self, wires: list[int], mu: float | None = None, sigma: float | None = None) -> None: """Add a directional coupler.""" theta = torch.pi / 2 if mu is None: mu = self.mu if sigma is None: sigma = self.sigma bs = BeamSplitterSingle( inputs=theta, nmode=self.nmode, wires=wires, cutoff=self.cutoff, den_mat=self.den_mat, convention='rx', requires_grad=False, noise=self.noise, mu=mu, sigma=sigma, ) self.add(bs)
[docs] def h(self, wires: list[int], mu: float | None = None, sigma: float | None = None) -> None: """Add a photonic Hadamard gate.""" theta = torch.pi / 2 if mu is None: mu = self.mu if sigma is None: sigma = self.sigma bs = BeamSplitterSingle( inputs=theta, nmode=self.nmode, wires=wires, cutoff=self.cutoff, den_mat=self.den_mat, convention='h', requires_grad=False, noise=self.noise, mu=mu, sigma=sigma, ) self.add(bs)
[docs] def any( self, unitary: Any, wires: int | list[int] | None = None, minmax: list[int] | None = None, name: str = 'uany' ) -> None: """Add an arbitrary unitary gate.""" uany = UAnyGate( unitary=unitary, nmode=self.nmode, wires=wires, minmax=minmax, cutoff=self.cutoff, den_mat=self.den_mat, name=name, ) self.add(uany)
[docs] def clements( self, unitary: Any, wires: int | list[int] | None = None, minmax: list[int] | None = None, mu: float | None = None, sigma: float | None = None, ) -> None: """Add the Clements architecture of the unitary matrix. This is equivalent to ``any``, using `'cssr'`-type Clements decomposition. When ``basis`` is ``False``, this implementation is much faster. """ if wires is None: if minmax is None: minmax = [0, self.nmode - 1] self._check_minmax(minmax) wires = list(range(minmax[0], minmax[1] + 1)) else: wires = self._convert_indices(wires) if mu is None: mu = self.mu if sigma is None: sigma = self.sigma # clements decomposition ud = UnitaryDecomposer(unitary, 'cssr') mzi_info = ud.decomp() dic_mzi = mzi_info[1] phase_angle = mzi_info[0]['phase_angle'] assert len(phase_angle) == len(wires), 'Please check wires' wires1 = wires[1::2] wires2 = wires[2::2] shift = wires[0] # clements decomposition starts from 0 for i in range(len(wires)): if i % 2 == 0: idx = i // 2 for j in range(len(wires1)): phi, theta = dic_mzi[(wires1[j] - 1 - shift, wires1[j] - shift)][idx] self.mzi(wires=[wires1[j] - 1, wires1[j]], inputs=[theta, phi], mu=mu, sigma=sigma) else: idx = (i - 1) // 2 for j in range(len(wires2)): phi, theta = dic_mzi[(wires2[j] - 1 - shift, wires2[j] - shift)][idx] self.mzi(wires=[wires2[j] - 1, wires2[j]], inputs=[theta, phi], mu=mu, sigma=sigma) for wire in wires: self.ps(wires=wire, inputs=phase_angle[wire - shift], mu=mu, sigma=sigma)
[docs] def s( self, wires: int, r: Any = None, theta: Any = None, encode: bool = False, mu: float | None = None, sigma: float | None = None, ) -> None: """Add a squeezing gate.""" requires_grad = not encode if r is None and theta is None: inputs = None else: requires_grad = False if r is None: inputs = [torch.rand(1)[0], theta] elif theta is None: inputs = [r, 0] else: inputs = [r, theta] if mu is None: mu = self.mu if sigma is None: sigma = self.sigma s = Squeezing( inputs=inputs, nmode=self.nmode, wires=wires, cutoff=self.cutoff, den_mat=self.den_mat, requires_grad=requires_grad, noise=self.noise, mu=mu, sigma=sigma, ) self.add(s, encode=encode)
[docs] def s2( self, wires: list[int], r: Any = None, theta: Any = None, encode: bool = False, mu: float | None = None, sigma: float | None = None, ) -> None: """Add a two-mode squeezing gate.""" requires_grad = not encode if r is None and theta is None: inputs = None else: requires_grad = False if r is None: inputs = [torch.rand(1)[0], theta] elif theta is None: inputs = [r, 0] else: inputs = [r, theta] if mu is None: mu = self.mu if sigma is None: sigma = self.sigma s2 = Squeezing2( inputs=inputs, nmode=self.nmode, wires=wires, cutoff=self.cutoff, den_mat=self.den_mat, requires_grad=requires_grad, noise=self.noise, mu=mu, sigma=sigma, ) self.add(s2, encode=encode)
[docs] def d( self, wires: int, r: Any = None, theta: Any = None, encode: bool = False, mu: float | None = None, sigma: float | None = None, ) -> None: """Add a displacement gate.""" requires_grad = not encode if r is None and theta is None: inputs = None else: requires_grad = False if r is None: inputs = [torch.rand(1)[0], theta] elif theta is None: inputs = [r, 0] else: inputs = [r, theta] if mu is None: mu = self.mu if sigma is None: sigma = self.sigma d = Displacement( inputs=inputs, nmode=self.nmode, wires=wires, cutoff=self.cutoff, den_mat=self.den_mat, requires_grad=requires_grad, noise=self.noise, mu=mu, sigma=sigma, ) self.add(d, encode=encode)
[docs] def x( self, wires: int, inputs: Any = None, encode: bool = False, mu: float | None = None, sigma: float | None = None ) -> None: """Add a position displacement gate.""" requires_grad = not encode if inputs is not None: requires_grad = False if mu is None: mu = self.mu if sigma is None: sigma = self.sigma dx = DisplacementPosition( inputs=inputs, nmode=self.nmode, wires=wires, cutoff=self.cutoff, den_mat=self.den_mat, requires_grad=requires_grad, noise=self.noise, mu=mu, sigma=sigma, ) self.add(dx, encode=encode)
[docs] def z( self, wires: int, inputs: Any = None, encode: bool = False, mu: float | None = None, sigma: float | None = None ) -> None: """Add a momentum displacement gate.""" requires_grad = not encode if inputs is not None: requires_grad = False if mu is None: mu = self.mu if sigma is None: sigma = self.sigma dp = DisplacementMomentum( inputs=inputs, nmode=self.nmode, wires=wires, cutoff=self.cutoff, den_mat=self.den_mat, requires_grad=requires_grad, noise=self.noise, mu=mu, sigma=sigma, ) self.add(dp, encode=encode)
[docs] def r( self, wires: int, inputs: Any = None, encode: bool = False, inv_mode: bool = False, mu: float | None = None, sigma: float | None = None, ) -> None: """Add a rotation gate.""" requires_grad = not encode if inputs is not None: requires_grad = False if mu is None: mu = self.mu if sigma is None: sigma = self.sigma r = PhaseShift( inputs=inputs, nmode=self.nmode, wires=wires, cutoff=self.cutoff, den_mat=self.den_mat, requires_grad=requires_grad, noise=self.noise, mu=mu, sigma=sigma, inv_mode=inv_mode, ) self.add(r, encode=encode)
[docs] def f(self, wires: int, mu: float | None = None, sigma: float | None = None) -> None: """Add a Fourier gate.""" theta = torch.pi / 2 if mu is None: mu = self.mu if sigma is None: sigma = self.sigma f = PhaseShift( inputs=theta, nmode=self.nmode, wires=wires, cutoff=self.cutoff, den_mat=self.den_mat, requires_grad=False, noise=self.noise, mu=mu, sigma=sigma, ) self.add(f)
[docs] def qp( self, wires: int, inputs: Any = None, encode: bool = False, mu: float | None = None, sigma: float | None = None ) -> None: """Add a quadratic phase gate.""" requires_grad = not encode if inputs is not None: requires_grad = False if mu is None: mu = self.mu if sigma is None: sigma = self.sigma qp = QuadraticPhase( inputs=inputs, nmode=self.nmode, wires=wires, cutoff=self.cutoff, den_mat=self.den_mat, requires_grad=requires_grad, noise=self.noise, mu=mu, sigma=sigma, ) self.add(qp, encode=encode)
[docs] def cx( self, wires: list[int], inputs: Any = None, encode: bool = False, mu: float | None = None, sigma: float | None = None, ) -> None: """Add a controlled-X gate.""" requires_grad = not encode if inputs is not None: requires_grad = False if mu is None: mu = self.mu if sigma is None: sigma = self.sigma cx = ControlledX( inputs=inputs, nmode=self.nmode, wires=wires, cutoff=self.cutoff, den_mat=self.den_mat, requires_grad=requires_grad, noise=self.noise, mu=mu, sigma=sigma, ) self.add(cx, encode=encode)
[docs] def cz( self, wires: list[int], inputs: Any = None, encode: bool = False, mu: float | None = None, sigma: float | None = None, ) -> None: """Add a controlled-Z gate.""" requires_grad = not encode if inputs is not None: requires_grad = False if mu is None: mu = self.mu if sigma is None: sigma = self.sigma cz = ControlledZ( inputs=inputs, nmode=self.nmode, wires=wires, cutoff=self.cutoff, den_mat=self.den_mat, requires_grad=requires_grad, noise=self.noise, mu=mu, sigma=sigma, ) self.add(cz, encode=encode)
[docs] def cp( self, wires: int, inputs: Any = None, encode: bool = False, mu: float | None = None, sigma: float | None = None ) -> None: """Add a cubic phase gate.""" requires_grad = not encode if inputs is not None: requires_grad = False if mu is None: mu = self.mu if sigma is None: sigma = self.sigma cp = CubicPhase( inputs=inputs, nmode=self.nmode, wires=wires, cutoff=self.cutoff, den_mat=self.den_mat, requires_grad=requires_grad, noise=self.noise, mu=mu, sigma=sigma, ) self.add(cp, encode=encode)
[docs] def k( self, wires: int, inputs: Any = None, encode: bool = False, mu: float | None = None, sigma: float | None = None ) -> None: """Add a Kerr gate.""" requires_grad = not encode if inputs is not None: requires_grad = False if mu is None: mu = self.mu if sigma is None: sigma = self.sigma k = Kerr( inputs=inputs, nmode=self.nmode, wires=wires, cutoff=self.cutoff, den_mat=self.den_mat, requires_grad=requires_grad, noise=self.noise, mu=mu, sigma=sigma, ) self.add(k, encode=encode)
[docs] def ck( self, wires: list[int], inputs: Any = None, encode: bool = False, mu: float | None = None, sigma: float | None = None, ) -> None: """Add a cross-Kerr gate.""" requires_grad = not encode if inputs is not None: requires_grad = False if mu is None: mu = self.mu if sigma is None: sigma = self.sigma ck = CrossKerr( inputs=inputs, nmode=self.nmode, wires=wires, cutoff=self.cutoff, den_mat=self.den_mat, requires_grad=requires_grad, noise=self.noise, mu=mu, sigma=sigma, ) self.add(ck, encode=encode)
[docs] def delay( self, wires: int, ntau: int = 1, inputs: Any = None, convention: str = 'bs', encode: bool = False, loop_gates: list | None = None, mu: float | None = None, sigma: float | None = None, ) -> None: """Add a delay loop.""" requires_grad = not encode if inputs is not None: requires_grad = False if mu is None: mu = self.mu if sigma is None: sigma = self.sigma if convention == 'bs': delay = DelayBS( inputs=inputs, ntau=ntau, nmode=self.nmode, wires=wires, cutoff=self.cutoff, den_mat=self.den_mat, requires_grad=requires_grad, loop_gates=loop_gates, noise=self.noise, mu=mu, sigma=sigma, ) elif convention == 'mzi': delay = DelayMZI( inputs=inputs, ntau=ntau, nmode=self.nmode, wires=wires, cutoff=self.cutoff, den_mat=self.den_mat, requires_grad=requires_grad, loop_gates=loop_gates, noise=self.noise, mu=mu, sigma=sigma, ) self.add(delay, encode=encode)
[docs] def homodyne( self, wires: int, phi: Any = None, eps: float = 2e-4, mu: float | None = None, sigma: float | None = None ) -> None: """Add a homodyne measurement.""" if mu is None: mu = self.mu if sigma is None: sigma = self.sigma homodyne = Homodyne( phi=phi, nmode=self.nmode, wires=wires, cutoff=self.cutoff, den_mat=self.den_mat, eps=eps, requires_grad=False, noise=self.noise, mu=mu, sigma=sigma, ) self.add(homodyne)
[docs] def homodyne_x(self, wires: int, eps: float = 2e-4, mu: float | None = None, sigma: float | None = None) -> None: """Add a homodyne measurement for quadrature x.""" phi = 0.0 if mu is None: mu = self.mu if sigma is None: sigma = self.sigma homodyne = Homodyne( phi=phi, nmode=self.nmode, wires=wires, cutoff=self.cutoff, den_mat=self.den_mat, eps=eps, requires_grad=False, noise=self.noise, mu=mu, sigma=sigma, ) self.add(homodyne)
[docs] def homodyne_p(self, wires: int, eps: float = 2e-4, mu: float | None = None, sigma: float | None = None) -> None: """Add a homodyne measurement for quadrature p.""" phi = np.pi / 2 if mu is None: mu = self.mu if sigma is None: sigma = self.sigma homodyne = Homodyne( phi=phi, nmode=self.nmode, wires=wires, cutoff=self.cutoff, den_mat=self.den_mat, eps=eps, requires_grad=False, noise=self.noise, mu=mu, sigma=sigma, ) self.add(homodyne)
[docs] def loss(self, wires: int, inputs: Any = None, encode: bool = False) -> None: """Add a photon loss channel. The `inputs` corresponds to `theta` of the loss channel. """ if self.backend == 'fock' and not self.basis: assert self.den_mat, 'Please use the density matrix representation' self._lossy = True self._nloss += 1 requires_grad = not encode if inputs is not None: requires_grad = False loss = PhotonLoss(inputs=inputs, nmode=self.nmode, wires=wires, cutoff=self.cutoff, requires_grad=requires_grad) self.add(loss, encode=encode)
[docs] def loss_t(self, wires: int, inputs: Any = None, encode: bool = False) -> None: """Add a photon loss channel. The `inputs` corresponds to the transmittance of the loss channel. """ if self.backend == 'fock' and not self.basis: assert self.den_mat, 'Please use the density matrix representation' self._lossy = True self._nloss += 1 requires_grad = not encode if inputs is not None: requires_grad = False if not isinstance(inputs, torch.Tensor): inputs = torch.tensor(inputs, dtype=torch.float) theta = torch.arccos(inputs**0.5) * 2 loss = PhotonLoss(inputs=theta, nmode=self.nmode, wires=wires, cutoff=self.cutoff, requires_grad=requires_grad) self.add(loss, encode=encode)
[docs] def loss_db(self, wires: int, inputs: Any = None, encode: bool = False) -> None: """Add a photon loss channel. The `inputs` corresponds to the probability of loss with the unit of dB and is positive. """ if self.backend == 'fock' and not self.basis: assert self.den_mat, 'Please use the density matrix representation' self._lossy = True self._nloss += 1 requires_grad = not encode if inputs is not None: requires_grad = False if not isinstance(inputs, torch.Tensor): inputs = torch.tensor(inputs, dtype=torch.float) t = 10 ** (-inputs / 10) theta = torch.arccos(t**0.5) * 2 loss = PhotonLoss(inputs=theta, nmode=self.nmode, wires=wires, cutoff=self.cutoff, requires_grad=requires_grad) self.add(loss, encode=encode)
[docs] def barrier(self, wires: int | list[int] | None = None) -> None: """Add a barrier.""" br = Barrier(nmode=self.nmode, wires=wires, cutoff=self.cutoff) self.add(br)
[docs] class DistributedQumodeCircuit(QumodeCircuit): """Photonic quantum circuit for a distributed Fock state. Args: nmode: The number of modes in the circuit. init_state: The initial state of the circuit. It can be a vacuum state with ``'vac'`` or ``'zeros'``. It can be a Fock basis state, e.g., ``[1,0,0]``, or a Fock state tensor, e.g., ``[(1/2**0.5, [1,0]), (1/2**0.5, [0,1])]``. cutoff: The Fock space truncation. Default: ``None`` name: The name of the circuit. Default: ``None`` """ def __init__(self, nmode: int, init_state: Any, cutoff: int | None = None, name: str | None = None) -> None: super().__init__( nmode, init_state, cutoff, backend='fock', basis=False, den_mat=False, detector='pnrd', name=name, mps=False, chi=None, noise=False, mu=0, sigma=0.1, )
[docs] def set_init_state(self, init_state: Any = None) -> None: """Set the initial state of the circuit.""" if isinstance(init_state, DistributedFockState): self.init_state = init_state else: self.init_state = DistributedFockState(init_state, self.nmode, self.cutoff) self.cutoff = self.init_state.cutoff
[docs] @torch.no_grad() def forward( self, data: torch.Tensor | None = None, state: DistributedFockState | None = None ) -> DistributedFockState: """Perform a forward pass of the photonic quantum circuit and return the final state. This method applies the ``operators`` of the photonic quantum circuit to the initial state or the given state and returns the resulting state. If ``data`` is given, it is used as the input for the ``encoders``. The ``data`` must be a 1D tensor. Args: data: The input data for the ``encoders``. Default: ``None`` state: The initial state for the photonic quantum circuit. Default: ``None`` """ if state is None: self.init_state.reset() else: self.init_state = state self.encode(data) self.state = self.operators(self.init_state) return self.state
[docs] def measure( self, shots: int = 1024, with_prob: bool = False, wires: int | list[int] | None = None, block_size: int = 2**24 ) -> dict | None: """Measure the final state. Args: shots: The number of times to sample from the quantum state. Default: 1024 with_prob: A flag that indicates whether to return the probabilities along with the number of occurrences. Default: ``False`` wires: The wires to measure. It can be an integer or a list of integers specifying the indices of the wires. Default: ``None`` (which means all wires are measured) block_size: The block size for sampling. Default: 2**24 """ if wires is None: wires = list(range(self.nmode)) wires = sorted(self._convert_indices(wires)) if self.state is None: return else: if self.state.world_size == 1: return measure_fock_tensor(self.state.amps.unsqueeze(0), shots, with_prob, wires, block_size) else: return measure_dist(self.state, shots, with_prob, wires, block_size)