"""Common functions"""
import copy
from collections import Counter, defaultdict
from collections.abc import Callable
from typing import Any, TYPE_CHECKING
import numpy as np
import torch
from torch import nn, vmap
from tqdm import tqdm
if TYPE_CHECKING:
from .layer import Observable
[docs]
def is_power_of_two(n: int) -> bool:
"""Check if an integer is a power of two."""
def f(x):
if x < 2:
return False
elif x & (x - 1) == 0:
return True
return False
return np.vectorize(f)(n)
[docs]
def is_power(n: int, base: int) -> bool:
"""Check if an integer is a power of the given base."""
if n <= 0 or base <= 0 or base == 1:
return False
if n == 1:
return True
while n % base == 0:
n //= base
return n == 1
[docs]
def int_to_bitstring(x: int, n: int, debug: bool = False) -> str:
"""Convert from integer to bit string."""
assert isinstance(x, int)
assert isinstance(n, int)
if x < 2**n:
# remove '0b'
s = bin(x)[2:]
if len(s) <= n:
s = '0' * (n - len(s)) + s
else:
if debug:
print(f'Quantum register ({n}) overflowed for {x}.')
s = bin(x)[-n:]
return s
[docs]
def list_to_decimal(digits: list[int], base: int) -> int:
"""Convert from list of digits to decimal integer."""
result = 0
for digit in digits:
assert 0 <= digit < base, 'Invalid digit for the given base'
result = result * base + digit
return result
[docs]
def decimal_to_list(n: int, base: int, ndigit: int | None = None) -> list[int]:
"""Convert from decimal integer to list of digits."""
assert base >= 2, 'Base must be at least 2'
if n == 0:
if isinstance(ndigit, int):
return [0] * ndigit
else:
return [0]
digits = []
num = abs(n)
while num > 0:
num, remainder = divmod(num, base)
digits.insert(0, remainder)
if ndigit is not None:
digits = [0] * (ndigit - len(digits)) + digits
return digits
[docs]
def inverse_permutation(permute_shape: list[int]) -> list[int]:
"""Calculate the inversed permutation.
Args:
permute_shape: Shape of permutation.
Returns:
A list of integers that is the inverse of ``permute_shape``.
"""
# find the index of each element in the range of the list length
return [permute_shape.index(i) for i in range(len(permute_shape))]
[docs]
def is_unitary(matrix: torch.Tensor, rtol: float = 1e-5, atol: float = 1e-4) -> bool:
"""Check if a tensor is a unitary matrix.
Args:
matrix: Square matrix.
rtol: Relative tolerance. Default: 1e-5
atol: Absolute tolerance. Default: 1e-4
Returns:
``True`` if ``matrix`` is unitary, ``False`` otherwise.
"""
if matrix.shape[-1] != matrix.shape[-2]:
return False
conj_trans = matrix.t().conj()
product = torch.matmul(matrix, conj_trans)
return torch.allclose(
product, torch.eye(matrix.shape[0], dtype=matrix.dtype, device=matrix.device), rtol=rtol, atol=atol
)
[docs]
def is_density_matrix(rho: torch.Tensor) -> bool:
"""Check if a tensor is a valid density matrix.
A density matrix is a positive semi-definite Hermitian matrix with trace one.
Args:
rho: The tensor to check. It can be either 2D or 3D. If 3D, the first dimension is assumed to be
the batch dimension.
Returns:
``True`` if the tensor is a density matrix, ``False`` otherwise.
"""
if not isinstance(rho, torch.Tensor):
return False
if rho.ndim not in (2, 3):
return False
if not is_power_of_two(rho.shape[-2]):
return False
if not is_power_of_two(rho.shape[-1]):
return False
if rho.ndim == 2:
rho = rho.unsqueeze(0)
# Check if the tensor is Hermitian
hermitian = torch.allclose(rho, rho.mH)
if not hermitian:
return False
# Check if the trace of each matrix is one
trace_one = torch.allclose(vmap(torch.trace)(rho), torch.tensor(1.0, dtype=rho.dtype, device=rho.device))
if not trace_one:
return False
# Check if the eigenvalues of each matrix are non-negative
positive_semi_definite = torch.all(torch.linalg.eig(rho)[0].real >= 0).item()
return positive_semi_definite
[docs]
def is_positive_definite(mat: torch.Tensor) -> bool:
"""Check if the matrix is positive definite."""
is_herm = torch.equal(mat, mat.mH)
diag = torch.linalg.eigvalsh(mat)
return is_herm and torch.all(diag > 0).item()
[docs]
def safe_inverse(x: Any, epsilon: float = 1e-12) -> Any:
"""Safe inversion."""
return x / (x**2 + epsilon)
# -----------------------------------------------------------------------------
# Adapted from tensorgrad
# Original Copyright (c) 2019 tensorgrad Contributors
# Modified work Copyright (c) 2023-2026 TuringQ
# Licensed under the Apache License, Version 2.0
# Source: https://github.com/wangleiphy/tensorgrad/blob/732bd6430e86f1b69a8615045ca5b6399ad73061/tensornets/adlib/svd.py#L12
#
# Modifications:
# - Support complex matrices.
# - Refactored for framework integration and naming consistency.
# -----------------------------------------------------------------------------
[docs]
class SVD(torch.autograd.Function):
"""Customized backward of SVD for better numerical stability.
See https://readpaper.com/paper/2971614414
"""
generate_vmap_rule = True
[docs]
@staticmethod
def forward(a):
u, s, vh = torch.linalg.svd(a, full_matrices=False)
s = s.to(u.dtype)
return u, s, vh
# setup_context is responsible for calling methods and/or assigning to
# the ctx object. Please do not do additional compute (e.g. add
# Tensors together) in setup_context.
# https://pytorch.org/docs/master/notes/extending.func.html
[docs]
@staticmethod
def setup_context(ctx, inputs, output):
# a = inputs
u, s, vh = output
ctx.save_for_backward(u, s, vh)
[docs]
@staticmethod
def backward(ctx, du, ds, dvh):
u, s, vh = ctx.saved_tensors
uh = u.mH
v = vh.mH
dv = dvh.mH
m = u.shape[-2]
n = v.shape[-2]
ns = s.shape[-1]
f = s.unsqueeze(-2) ** 2 - s.unsqueeze(-1) ** 2
f = safe_inverse(f)
f.diagonal(dim1=-2, dim2=-1).fill_(0)
j = f * (uh @ du)
k = f * (vh @ dv)
l = (vh @ dv).diagonal(dim1=-2, dim2=-1).diag_embed() # noqa: E741
s_inv = safe_inverse(s).diag_embed()
mat_s = s.diag_embed()
da = u @ (ds.diag_embed() + (j + j.mH) @ mat_s + mat_s @ (k + k.mH) + s_inv @ (l.mH - l) / 2) @ vh
if m > ns:
da += (torch.eye(m, dtype=du.dtype, device=du.device) - u @ uh) @ du @ s_inv @ vh
if n > ns:
da += u @ s_inv @ dvh @ (torch.eye(n, dtype=du.dtype, device=du.device) - v @ vh)
return da
# -----------------------------------------------------------------------------
# Adapted from TensorCircuit
# Original Copyright (c) 2020 TensorCircuit Contributors
# Modified work Copyright (c) 2023-2026 TuringQ
# Licensed under the Apache License, Version 2.0
# Source: https://github.com/tencent-quantum-lab/tensorcircuit/blob/cac74977f628e6e623bd34af95454fe55af399c2/tensorcircuit/backends/pytorch_ops.py#L15
#
# Modifications:
# - Refactored for framework integration and naming consistency.
# -----------------------------------------------------------------------------
[docs]
def torchqr_grad(a, q, r, dq, dr):
"""Get the gradient for QR."""
qr_epsilon = 1e-8
if r.shape[-2] > r.shape[-1] and q.shape[-2] == q.shape[-1]:
raise NotImplementedError(
'QrGrad not implemented when nrows > ncols '
'and full_matrices is true. Received r.shape='
f'{r.shape} with nrows={r.shape[-2]}'
f'and ncols={r.shape[-1]}.'
)
def _triangular_solve(x, r):
"""Equivalent to matmul(x, adjoint(matrix_inverse(r))) if r is upper-tri."""
return torch.linalg.solve_triangular(r, x.adjoint(), upper=True, unitriangular=False).adjoint()
def _qr_grad_square_and_deep_matrices(q, r, dq, dr):
"""Get the gradient for matrix orders num_rows >= num_cols and full_matrices is false."""
# Modification begins
rdiag = torch.linalg.diagonal(r)
# if abs(rdiag[i]) < qr_epsilon then rdiag[i] = qr_epsilon otherwise keep the old value
qr_epsilon_diag = torch.ones_like(rdiag) * qr_epsilon
rdiag = torch.where(rdiag.abs() < qr_epsilon, qr_epsilon_diag, rdiag)
r = torch.diagonal_scatter(r, rdiag, dim1=-2, dim2=-1)
# delta_dq = math_ops.matmul(q, math_ops.matmul(dr, tf.linalg.adjoint(delta_r)))
# dq = dq + delta_dq
# Modification ends
qdq = torch.matmul(q.adjoint(), dq)
qdq_ = qdq - qdq.adjoint()
rdr = torch.matmul(r, dr.adjoint())
rdr_ = rdr - rdr.adjoint()
tril = torch.tril(qdq_ + rdr_)
grad_a = torch.matmul(q, dr + _triangular_solve(tril, r))
grad_b = _triangular_solve(dq - torch.matmul(q, qdq), r)
ret = grad_a + grad_b
if q.is_complex():
m = rdr - qdq.adjoint()
eyem = torch.diagonal_scatter(torch.zeros_like(m), torch.linalg.diagonal(m), dim1=-2, dim2=-1)
correction = eyem - torch.real(eyem).to(dtype=q.dtype)
ret = ret + _triangular_solve(torch.matmul(q, correction.adjoint()), r)
return ret
num_rows, num_cols = q.shape[-2], r.shape[-1]
if num_rows >= num_cols:
return _qr_grad_square_and_deep_matrices(q, r, dq, dr)
y = a[..., :, num_rows:]
u = r[..., :, :num_rows]
dv = dr[..., :, num_rows:]
du = dr[..., :, :num_rows]
dy = torch.matmul(q, dv)
dx = _qr_grad_square_and_deep_matrices(q, u, dq + torch.matmul(y, dv.adjoint()), du)
return torch.cat([dx, dy], dim=-1)
# -----------------------------------------------------------------------------
# Adapted from TensorCircuit
# Original Copyright (c) 2020 TensorCircuit Contributors
# Modified work Copyright (c) 2023-2026 TuringQ
# Licensed under the Apache License, Version 2.0
# Source: https://github.com/tencent-quantum-lab/tensorcircuit/blob/cac74977f628e6e623bd34af95454fe55af399c2/tensorcircuit/backends/pytorch_ops.py#L87
#
# Modifications:
# - Refactored for framework integration and naming consistency.
# -----------------------------------------------------------------------------
[docs]
class QR(torch.autograd.Function):
"""Customized backward of QR for better numerical stability."""
generate_vmap_rule = True
[docs]
@staticmethod
def forward(a):
q, r = torch.linalg.qr(a, mode='reduced')
# ctx.save_for_backward(a, q, r)
return q, r
# setup_context is responsible for calling methods and/or assigning to
# the ctx object. Please do not do additional compute (e.g. add
# Tensors together) in setup_context.
# https://pytorch.org/docs/master/notes/extending.func.html
[docs]
@staticmethod
def setup_context(ctx, inputs, output):
(a,) = inputs
q, r = output
# Tensors must be saved via ctx.save_for_backward. Please do not
# assign them directly onto the ctx object.
ctx.save_for_backward(a, q, r)
# Non-tensors may be saved by assigning them as attributes on the ctx object.
# ctx.dim = dim
[docs]
@staticmethod
def backward(ctx, dq, dr):
a, q, r = ctx.saved_tensors
return torchqr_grad(a, q, r, dq, dr)
svd = SVD.apply
qr = QR.apply
[docs]
def split_tensor(tensor: torch.Tensor, center_left: bool = True) -> tuple[torch.Tensor, torch.Tensor]:
"""Split a tensor by QR."""
if center_left:
q, r = qr(tensor.mH)
return r.mH, q.mH
else:
return qr(tensor)
[docs]
def state_to_tensors(state: torch.Tensor, nsite: int, qudit: int = 2) -> list[torch.Tensor]:
"""Convert a quantum state to a list of tensors."""
state = state.reshape([qudit] * nsite)
tensors = []
nleft = 1
for _ in range(nsite - 1):
u, state = split_tensor(state.reshape(nleft * qudit, -1), center_left=False)
tensors.append(u.reshape(nleft, qudit, -1))
nleft = state.shape[0]
u, state = split_tensor(state.reshape(nleft * qudit, -1), center_left=False)
assert state.shape == (1, 1)
tensors.append(u.reshape(nleft, qudit, -1) * state[0, 0])
return tensors
[docs]
def slice_state_vector(
state: torch.Tensor, nqubit: int, wires: list[int], bits: str, normalize: bool = True
) -> torch.Tensor:
"""Get the sliced state vectors according to ``wires`` and ``bits``."""
if len(bits) == 1:
bits = bits * len(wires)
assert len(wires) == len(bits)
wires = [i + 1 for i in wires]
state = state.reshape([-1] + [2] * nqubit)
batch = state.shape[0]
permute_shape = list(range(nqubit + 1))
for i in wires:
permute_shape.remove(i)
permute_shape = wires + permute_shape
state = state.permute(permute_shape)
for b in bits:
b = int(b)
assert b in (0, 1)
state = state[b]
state = state.reshape(batch, -1)
if normalize:
state = nn.functional.normalize(state, p=2, dim=-1)
return state
[docs]
def multi_kron(lst: list[torch.Tensor]) -> torch.Tensor:
"""Calculate the Kronecker/tensor/outer product for a list of tensors.
Args:
lst: A list of tensors.
Returns:
The Kronecker/tensor/outer product of the input.
"""
n = len(lst)
if n == 1:
return lst[0].contiguous()
else:
mid = n // 2
rst = torch.kron(multi_kron(lst[0:mid]), multi_kron(lst[mid:]))
return rst.contiguous()
[docs]
def partial_trace(rho: torch.Tensor, nqudit: int, trace_lst: list[int], qudit: int = 2) -> torch.Tensor:
r"""Calculate the partial trace for a batch of density matrices.
Args:
rho: Density matrices with the shape of
:math:`(\text{batch}, \text{qudit}^{\text{nqudit}}, \text{qudit}^{\text{nqudit}})`.
nqudit: Total number of qudits.
trace_lst: A list of qudits to be traced.
qudit: The dimension of the qudits. Default: 2
Returns:
Reduced density matrices.
"""
if rho.ndim == 2:
rho = rho.unsqueeze(0)
assert rho.ndim == 3
assert rho.shape[1] == rho.shape[2] == qudit**nqudit
b = rho.shape[0]
n = len(trace_lst)
trace_lst = [i + 1 for i in trace_lst]
trace_lst2 = [i + nqudit for i in trace_lst]
trace_lst += trace_lst2
permute_shape = list(range(2 * nqudit + 1))
for i in trace_lst:
permute_shape.remove(i)
permute_shape += trace_lst
rho = rho.reshape([b] + [qudit] * 2 * nqudit).permute(permute_shape).reshape(-1, qudit**n, qudit**n)
rho = rho.diagonal(dim1=-2, dim2=-1).sum(-1)
return rho.reshape(b, qudit ** (nqudit - n), qudit ** (nqudit - n)).squeeze(0)
[docs]
def amplitude_encoding(data: Any, nqubit: int) -> torch.Tensor:
r"""Encode data into quantum states using amplitude encoding.
This function takes a batch of data and encodes each sample into a quantum state using amplitude encoding.
The quantum state is represented by a complex-valued tensor of shape :math:`(\text{batch}, 2^{\text{nqubit}})`.
The data is normalized to have unit norm along the last dimension before encoding. If the data size is smaller
than :math:`2^{\text{nqubit}}`, the remaining amplitudes are set to zero. If the data size is larger than
:math:`2^{\text{nqubit}}`, only the first :math:`2^{\text{nqubit}}` elements are used.
Args:
data: The input data to be encoded. It should have shape :math:`(\text{batch}, ...)`
where :math:`...` can be any dimensions. If it is not a torch.Tensor object, it will be converted to one.
nqubit: The number of qubits to use for encoding.
Returns:
The encoded quantum states as complex-valued tensors of shape
:math:`(\text{batch}, 2^{\text{nqubit}}, 1)`.
Examples:
>>> data = [[0.5, 0.5], [0.7, 0.3]]
>>> amplitude_encoding(data, nqubit=2)
tensor([[[0.7071+0.j],
[0.7071+0.j],
[0.0000+0.j],
[0.0000+0.j]],
[[0.9487+0.j],
[0.3162+0.j],
[0.0000+0.j],
[0.0000+0.j]]])
"""
if not isinstance(data, torch.Tensor):
data = torch.tensor(data)
is_single_state = data.ndim == 1 or (data.ndim == 2 and data.shape[-1] == 1)
batch = 1 if is_single_state else data.shape[0]
data = data.reshape(batch, -1)
size = data.shape[1]
n = 2**nqubit
state = torch.zeros(batch, n, dtype=data.dtype, device=data.device) + 0j
data = nn.functional.normalize(data[:, :n], p=2, dim=-1)
if n > size:
state[:, :size] = data[:, :]
else:
state[:, :] = data[:, :]
return state.unsqueeze(-1)
[docs]
def evolve_state(
state: torch.Tensor, matrix: torch.Tensor, nqudit: int, wires: list[int], qudit: int = 2
) -> torch.Tensor:
"""Perform the evolution of quantum states.
Args:
state: The batched state tensor.
matrix: The evolution matrix.
nqudit: The number of the qudits.
wires: The indices of the qudits that the quantum operation acts on.
qudit: The dimension of the qudits. Default: 2
"""
nt = len(wires)
wires = [i + 1 for i in wires]
pm_shape = list(range(nqudit + 1))
for i in wires:
pm_shape.remove(i)
pm_shape = wires + pm_shape
state = state.permute(pm_shape).reshape(qudit**nt, -1)
state = (matrix @ state).reshape([qudit] * nt + [-1] + [qudit] * (nqudit - nt))
state = state.permute(inverse_permutation(pm_shape))
return state
[docs]
def evolve_den_mat(
state: torch.Tensor, matrix: torch.Tensor, nqudit: int, wires: list[int], qudit: int = 2
) -> torch.Tensor:
"""Perform the evolution of density matrices.
Args:
state: The batched state tensor.
matrix: The evolution matrix.
nqudit: The number of the qudits.
wires: The indices of the qudits that the quantum operation acts on.
qudit: The dimension of the qudits. Default: 2
"""
nt = len(wires)
# left multiply
wires1 = [i + 1 for i in wires]
pm_shape = list(range(2 * nqudit + 1))
for i in wires1:
pm_shape.remove(i)
pm_shape = wires1 + pm_shape
state = state.permute(pm_shape).reshape(qudit**nt, -1)
state = (matrix @ state).reshape([qudit] * nt + [-1] + [qudit] * (2 * nqudit - nt))
state = state.permute(inverse_permutation(pm_shape))
# right multiply
wires2 = [i + 1 + nqudit for i in wires]
pm_shape = list(range(2 * nqudit + 1))
for i in wires2:
pm_shape.remove(i)
pm_shape = wires2 + pm_shape
state = state.permute(pm_shape).reshape(qudit**nt, -1)
state = (matrix.conj() @ state).reshape([qudit] * nt + [-1] + [qudit] * (2 * nqudit - nt))
state = state.permute(inverse_permutation(pm_shape))
return state
[docs]
def block_sample(probs: torch.Tensor, shots: int = 1024, block_size: int = 2**24) -> list:
"""Sample from a probability distribution using block sampling.
Args:
probs: The probability distribution to sample from.
shots: The number of samples to draw. Default: 1024
block_size: The block size for sampling. Default: 2**24
"""
samples = []
num_blocks = int(np.ceil(len(probs) / block_size))
probs_block = torch.zeros(num_blocks, device=probs.device)
start = (num_blocks - 1) * block_size
end = min(num_blocks * block_size, len(probs))
probs_block[:-1] = probs[:start].reshape(num_blocks - 1, block_size).sum(1)
probs_block[-1] = probs[start:end].sum()
blocks = torch.multinomial(probs_block, shots, replacement=True).cpu().numpy()
block_dict = Counter(blocks)
for idx_block, shots_block in block_dict.items():
start = idx_block * block_size
end = min((idx_block + 1) * block_size, len(probs))
samples_block = torch.multinomial(probs[start:end], shots_block, replacement=True)
samples.extend((samples_block + start).cpu().numpy())
return samples
[docs]
def measure(
state: torch.Tensor,
shots: int = 1024,
with_prob: bool = False,
wires: int | list[int] | None = None,
den_mat: bool = False,
block_size: int = 2**24,
) -> dict | list[dict]:
r"""A function that performs a measurement on a quantum state and returns the results.
The measurement is done by sampling from the probability distribution of the quantum state. The results
are given as a dictionary or a list of dictionaries, where each key is a bit string representing the
measurement outcome, and each value is either the number of occurrences or a tuple of the number of
occurrences and the probability.
Args:
state: The quantum state to measure. It can be a tensor of shape :math:`(2^n,)` or
:math:`(2^n, 1)` representing a state vector, or a tensor of shape :math:`(\text{batch}, 2^n)`
or :math:`(\text{batch}, 2^n, 1)` representing a batch of state vectors. It can also be a tensor
of shape :math:`(2^n, 2^n)` representing a density matrix or :math:`(\text{batch}, 2^n, 2^n)`
representing a batch of density matrices.
shots: The number of times to sample from the quantum state. Default: 1024
with_prob: A flag that indicates whether to return the probabilities along with the number of occurrences.
Default: ``False``
wires: The wires to measure. It can be an integer or a list of integers specifying the indices of the wires.
Default: ``None`` (which means all wires are measured)
den_mat: Whether the state is a density matrix or not. Default: ``False``
block_size: The block size for sampling. Default: 2**24
Returns:
The measurement results. If the state is a single state vector, it returns a dictionary
where each key is a bit string representing the measurement outcome, and each value is
either the number of occurrences or a tuple of the number of occurrences and the probability.
If the state is a batch of state vectors, it returns a list of dictionaries with the same format
for each state vector in the batch.
"""
if den_mat:
assert is_density_matrix(state), 'Please input density matrices'
state = state.diagonal(dim1=-2, dim2=-1)
is_single_state = state.ndim == 1 or (state.ndim == 2 and state.shape[-1] == 1)
batch = 1 if is_single_state else state.shape[0]
state = state.reshape(batch, -1)
assert is_power_of_two(state.shape[-1]), 'The length of the quantum state is not in the form of 2^n'
n = int(np.log2(state.shape[-1]))
if wires is not None:
if isinstance(wires, int):
wires = [wires]
assert isinstance(wires, list)
wires = sorted(wires)
pm_shape = list(range(n))
for w in wires:
pm_shape.remove(w)
pm_shape = wires + pm_shape
num_bits = len(wires) if wires else n
results_tot = []
for i in range(batch):
probs = torch.abs(state[i]) if den_mat else torch.abs(state[i]) ** 2
if wires is not None:
probs = probs.reshape([2] * n).permute(pm_shape).reshape([2] * len(wires) + [-1]).sum(-1).reshape(-1)
# Perform block sampling to reduce memory consumption
samples = Counter(block_sample(probs, shots, block_size))
results = {bin(key)[2:].zfill(num_bits): value for key, value in samples.items()}
if with_prob:
for k in results:
index = int(k, 2)
results[k] = results[k], probs[index]
results_tot.append(results)
if batch == 1:
return results_tot[0]
else:
return results_tot
[docs]
def sample_sc_mcmc(
prob_func: Callable, proposal_sampler: Callable, shots: int = 1024, num_chain: int = 5
) -> defaultdict:
"""Get the samples of the probability distribution function via SC-MCMC method."""
samples_chain = []
merged_samples = defaultdict(int)
cache_prob = {}
if shots <= 0:
return merged_samples
elif shots < num_chain:
num_chain = shots
shots_lst = [shots // num_chain] * num_chain
shots_lst[-1] += shots % num_chain
for trial in range(num_chain):
cache = []
len_cache = min(shots_lst)
if shots_lst[trial] > 1e5:
len_cache = 4000
# random start
sample_0 = proposal_sampler()
if not isinstance(sample_0, str):
if prob_func(sample_0) < 1e-12: # avoid the samples with almost-zero probability
sample_0 = tuple([0] * len(sample_0))
while prob_func(sample_0) < 1e-9:
sample_0 = proposal_sampler()
cache.append(sample_0)
sample_max = sample_0
if sample_max in cache_prob:
prob_max = cache_prob[sample_max]
else:
prob_max = prob_func(sample_0)
cache_prob[sample_max] = prob_max
dict_sample = defaultdict(int)
for i in tqdm(range(1, shots_lst[trial]), desc=f'chain {trial + 1}', ncols=80, colour='green'):
sample_i = proposal_sampler()
if sample_i in cache_prob:
prob_i = cache_prob[sample_i]
else:
prob_i = prob_func(sample_i)
cache_prob[sample_i] = prob_i
rand_num = torch.rand(1, device=prob_i.device)
# MCMC transfer to new state
if prob_i / prob_max > rand_num:
sample_max = sample_i
prob_max = prob_i
if i < len_cache: # cache not full
cache.append(sample_max)
else: # full
idx = np.random.randint(0, len_cache)
out_sample = copy.deepcopy(cache[idx])
cache[idx] = sample_max
out_sample_key = out_sample
if out_sample_key in dict_sample:
dict_sample[out_sample_key] = dict_sample[out_sample_key] + 1
else:
dict_sample[out_sample_key] = 1
# clear the cache
for i in range(len_cache):
out_sample = cache[i]
out_sample_key = out_sample
if out_sample_key in dict_sample:
dict_sample[out_sample_key] = dict_sample[out_sample_key] + 1
else:
dict_sample[out_sample_key] = 1
samples_chain.append(dict_sample)
for key, value in dict_sample.items():
merged_samples[key] += value
return merged_samples
[docs]
def get_prob_mps(mps_lst: list[torch.Tensor], wire: int) -> torch.Tensor:
"""Calculate the probability distribution (|0⟩ and |1⟩ probabilities) for a specific wire in an MPS.
This function computes the probability of measuring |0⟩ and |1⟩ for the k-th qubit in a quantum state
represented as a Matrix Product State (MPS). It does this by:
1. Contracting the tensors to the left of the target tensor
2. Contracting the tensors to the right of the target tensor
3. Computing the final contraction with the target tensor
Args:
mps_lst: A list of MPS tensors representing the quantum state.
Each 3-dimensional tensor should have shape (bond_dim_left, physical_dim, bond_dim_right).
wire: The index of the target qubit to compute probabilities for.
Returns:
A tensor containing [P(|0⟩), P(|1⟩)] probabilities for the target qubit.
"""
def contract_conjugate_pair(tensors: list[torch.Tensor]) -> torch.Tensor:
"""Contract a list of MPS tensors with their conjugates.
This helper function performs the contraction between a list of MPS tensors
and their complex conjugates, which is needed for probability calculation.
Args:
tensors: A list of MPS tensors to contract.
Returns:
Contracted tensor.
"""
if not tensors: # Handle empty tensor list case
return torch.tensor(1).reshape(1, 1, 1, 1).to(mps_lst[0].device, mps_lst[0].dtype)
# Contract first tensor with its conjugate
contracted = torch.tensordot(tensors[0].conj(), tensors[0], dims=([1], [1]))
contracted = contracted.permute(0, 2, 1, 3) # (left_c, left, right_c, right)
# Iteratively contract remaining tensors
for tensor in tensors[1:]:
pair_contracted = torch.tensordot(tensor.conj(), tensor, dims=([1], [1]))
pair_contracted = pair_contracted.permute(0, 2, 1, 3)
contracted = torch.tensordot(contracted, pair_contracted, dims=([2, 3], [0, 1]))
return contracted
# Split MPS into left and right parts relative to target qubit
left_tensors = mps_lst[:wire] if wire > 0 else []
right_tensors = mps_lst[wire + 1 :] if wire < len(mps_lst) - 1 else []
target_tensor = mps_lst[wire]
# Contract left and right parts separately
left_contracted = contract_conjugate_pair(left_tensors)
right_contracted = contract_conjugate_pair(right_tensors)
# Perform final contractions with target qubit tensor
temp1 = torch.tensordot(left_contracted, target_tensor.conj(), dims=([2], [0]))
temp2 = torch.tensordot(temp1, target_tensor, dims=([2], [0]))
final_tensor = torch.tensordot(right_contracted, temp2, dims=([0, 1], [3, 5])).squeeze()
# Extract probabilities from diagonal elements
probabilities = final_tensor.diagonal().real
return torch.clamp(probabilities, min=0) # Returns [P(|0⟩), P(|1⟩)]
[docs]
def inner_product_mps(
tensors0: list[torch.Tensor], tensors1: list[torch.Tensor], form: str = 'norm'
) -> torch.Tensor | list[torch.Tensor]:
r"""Computes the inner product of two matrix product states.
Args:
tensors0: The tensors of the first MPS, each with shape :math:`(..., d_0, d_1, d_2)`,
where :math:`d_0` is the bond dimension of the left site, :math:`d_1` is the physical dimension,
and :math:`d_2` is the bond dimension of the right site.
tensors1: The tensors of the second MPS, each with shape :math:`(..., d_0, d_1, d_2)`,
where :math:`d_0` is the bond dimension of the left site, :math:`d_1` is the physical dimension,
and :math:`d_2` is the bond dimension of the right site.
form: The form of the output. If ``'log'``, returns the logarithm of the absolute value of the inner product.
If ``'list'``, returns a list of norms at each step. Otherwise, returns the inner product as a scalar.
Default: ``'norm'``
Returns:
The inner product of the two MPS, or a list of norms at each step.
Raises:
AssertionError: If the tensors have incompatible shapes or lengths.
"""
assert tensors0[0].shape[-3] == tensors0[-1].shape[-1]
assert tensors1[0].shape[-3] == tensors1[-1].shape[-1]
assert len(tensors0) == len(tensors1)
v0 = torch.eye(tensors0[0].shape[-3], dtype=tensors0[0].dtype, device=tensors0[0].device)
v1 = torch.eye(tensors1[0].shape[-3], dtype=tensors0[0].dtype, device=tensors0[0].device)
v = torch.kron(v0, v1).reshape(
[tensors0[0].shape[-3], tensors1[0].shape[-3], tensors0[0].shape[-3], tensors1[0].shape[-3]]
)
norm_list = []
for n in range(len(tensors0)):
v = torch.einsum('...uvap,...adb,...pdq->...uvbq', v, tensors0[n].conj(), tensors1[n])
norm_v = v.norm(p=2, dim=[-4, -3, -2, -1], keepdim=True)
v = v / norm_v
norm_list.append(norm_v.squeeze())
if v.numel() > 1:
norm1 = torch.einsum('...acac->...', v)
norm_list.append(norm1)
else:
norm_list.append(v[0, 0, 0, 0])
if form == 'log':
norm = 0.0
for x in norm_list:
norm = norm + torch.log(x.abs())
elif form == 'list':
return norm_list
else:
norm = 1.0
for x in norm_list:
norm = norm * x
return norm
[docs]
def expectation(
state: torch.Tensor | list[torch.Tensor], observable: 'Observable', den_mat: bool = False, chi: int | None = None
) -> torch.Tensor:
"""A function that calculates the expectation value of an observable on a quantum state.
The expectation value is the average measurement outcome of the observable on the quantum state.
It is a real number that represents the mean of the probability distribution of the measurement outcomes.
Args:
state: The quantum state to measure. It can be a list of tensors representing a matrix product state,
or a tensor representing a density matrix or a state vector.
observable: The observable to measure. It is an instance of ``Observable`` class that
implements the measurement basis and the corresponding gates.
den_mat: Whether to use density matrix representation. Default: ``False``
chi: The bond dimension of the matrix product state. It is only used when the state is a list of tensors.
Default: ``None`` (which means no truncation)
Returns:
The expectation value of the observable on the quantum state. It is a scalar tensor with real values.
"""
if isinstance(state, list):
from .state import MatrixProductState
mps = MatrixProductState(nsite=len(state), state=state, chi=chi)
return inner_product_mps(state, observable(mps).tensors).real
if den_mat:
expval = (observable.get_unitary() @ state).diagonal(dim1=-2, dim2=-1).sum(-1).real
else:
expval = state.mH @ observable(state)
expval = expval.squeeze(-1).squeeze(-1).real
return expval
[docs]
def sample2expval(sample: dict) -> torch.Tensor:
"""Get the expectation value according to the measurement results."""
total = 0
exp = 0
for bitstring, ncount in sample.items():
coeff = (-1) ** (bitstring.count('1') % 2)
exp += ncount * coeff
total += ncount
return torch.tensor([exp / total])
[docs]
def meyer_wallach_measure(state_tsr: torch.Tensor) -> torch.Tensor:
r"""Calculate Meyer-Wallach entanglement measure.
See https://readpaper.com/paper/2945680873 Eq.(19)
Args:
state_tsr: Input with the shape of :math:`(\text{batch}, 2, ..., 2)`.
Returns:
The value of Meyer-Wallach measure.
"""
nqubit = len(state_tsr.shape) - 1
batch = state_tsr.shape[0]
rst = 0
for i in range(nqubit):
s1 = linear_map_mw(state_tsr, i, 0).reshape(batch, -1, 1)
s2 = linear_map_mw(state_tsr, i, 1).reshape(batch, -1, 1)
rst += generalized_distance(s1, s2).reshape(-1)
return rst * 4 / nqubit
[docs]
def linear_map_mw(state_tsr: torch.Tensor, j: int, b: int) -> torch.Tensor:
r"""Calculate the linear mapping for Meyer-Wallach measure.
See https://readpaper.com/paper/2945680873 Eq.(18)
Note:
Project on state with local projectors on the ``j`` th qubit.
See https://arxiv.org/pdf/quant-ph/0305094.pdf Eq.(2)
Args:
state_tsr: Input with the shape of :math:`(\text{batch}, 2, ..., 2)`.
j: The ``j`` th qubit to project on, from :math:`0` to :math:`\text{nqubit}-1`.
b: The basis of projection, :math:`\ket{0}` or :math:`\ket{1}`.
Returns:
Non-normalized state tensor after the linear mapping.
"""
assert b in (0, 1), 'b must be 0 or 1'
n = len(state_tsr.shape)
assert j < n - 1, 'j can not exceed nqubit'
permute_shape = list(range(n))
permute_shape.remove(j + 1)
permute_shape = [0] + [j + 1] + permute_shape[1:]
return state_tsr.permute(permute_shape)[:, b]
[docs]
def generalized_distance(state1: torch.Tensor, state2: torch.Tensor) -> torch.Tensor:
r"""Calculate the generalized distance.
See https://readpaper.com/paper/2945680873 Eq.(20)
Note:
Implemented according to https://arxiv.org/pdf/quant-ph/0310137.pdf Eq.(4)
Args:
state1: Input with the shape of :math:`(\text{batch}, 2^n, 1)`.
state2: Input with the shape of :math:`(\text{batch}, 2^n, 1)`.
Returns:
The generalized distance.
"""
return ((state1.mH @ state1) * (state2.mH @ state2) - (state1.mH @ state2) * (state2.mH @ state1)).real
[docs]
def meyer_wallach_measure_brennen(state_tsr: torch.Tensor) -> torch.Tensor:
r"""Calculate Meyer-Wallach entanglement measure, proposed by Brennen.
See https://arxiv.org/pdf/quant-ph/0305094.pdf Eq.(6)
Note:
This implementation is slower than ``meyer_wallach_measure`` when :math:`\text{nqubit} \ge 8`.
Args:
state_tsr: Input with the shape of :math:`(\text{batch}, 2, ..., 2)`.
Returns:
The value of Meyer-Wallach measure.
"""
nqubit = len(state_tsr.shape) - 1
batch = state_tsr.shape[0]
rho = state_tsr.reshape(batch, -1, 1) @ state_tsr.conj().reshape(batch, 1, -1)
rst = 0
for i in range(nqubit):
trace_list = list(range(nqubit))
trace_list.remove(i)
rho_i = partial_trace(rho, nqubit, trace_list)
rho_i = rho_i @ rho_i
trace_rho_i = rho_i.diagonal(offset=0, dim1=-2, dim2=-1).sum(-1).real
rst += trace_rho_i
return 2 * (1 - rst / nqubit)