Source code for deepquantum.state

"""Quantum states"""

from typing import Any, Union

import torch
from torch import nn

from .bitmath import is_power_of_2, log_base2, power_of_2
from .communication import comm_get_rank, comm_get_world_size
from .qmath import amplitude_encoding, inner_product_mps, is_density_matrix, qr, svd
from .utils import apply_complex_fix


class QubitState(nn.Module):
    """A quantum state of n qubits, including both pure states and density matrices.

    Args:
        nqubit: The number of qubits in the state. Default: 1
        state: The representation of the state. It can be one of the following strings: ``'zeros'``,
            ``'equal'``, ``'entangle'``, ``'GHZ'``, or ``'ghz'``. Alternatively, it can be a tensor
            (or anything ``torch.tensor`` accepts) that represents a custom state vector or density matrix.
            Default: ``'zeros'``
        den_mat: Whether the state is a density matrix or not. Default: ``False``
    """

    def __init__(self, nqubit: int = 1, state: Any = 'zeros', den_mat: bool = False) -> None:
        super().__init__()
        self.nqubit = nqubit
        self.den_mat = den_mat
        # Only compare with the preset names when ``state`` really is a string: ``==`` against an
        # array-like (e.g. a NumPy array) is elementwise and its truth value is ambiguous in an ``if``.
        is_name = isinstance(state, str)
        if is_name and state == 'zeros':
            # |0...0> as a (2**nqubit, 1) column vector
            state = torch.zeros((2**nqubit, 1), dtype=torch.cfloat)
            state[0] = 1
        elif is_name and state == 'equal':
            # uniform superposition, normalized along the amplitude axis
            state = torch.ones((2**nqubit, 1), dtype=torch.cfloat)
            state = nn.functional.normalize(state, p=2, dim=-2)
        elif is_name and state in ('entangle', 'GHZ', 'ghz'):
            # (|0...0> + |1...1>) / sqrt(2)
            state = torch.zeros((2**nqubit, 1), dtype=torch.cfloat)
            state[0] = 1 / 2**0.5
            state[-1] = 1 / 2**0.5
        else:
            if not isinstance(state, torch.Tensor):
                state = torch.tensor(state, dtype=torch.cfloat)
            if den_mat and state.shape[-1] == 2**nqubit and is_density_matrix(state):
                # Already a valid density matrix: store it unchanged.
                self.register_buffer('state', state)
                return
            ndim = state.ndim
            state = amplitude_encoding(data=state, nqubit=nqubit)
            if state.ndim > ndim:
                # amplitude_encoding returned an extra leading dimension; drop it to keep the input's rank
                state = state.squeeze(0)
        if den_mat:
            # density matrix from the state vector: state @ state^dagger
            state = state @ state.mH
        self.register_buffer('state', state)

    def _apply(self, fn: Any) -> 'QubitState':
        """Apply ``fn`` to the module, routing the ``state`` buffer through ``apply_complex_fix``.

        The buffer is popped before ``nn.Module._apply`` runs, handed to ``apply_complex_fix`` together
        with ``fn`` (presumably to handle the complex dtype specially — confirm in ``utils``), and the
        result is re-registered as a buffer.
        """
        tensors_dict = {}
        tensor = self._buffers.pop('state')
        if tensor is not None:
            tensors_dict['state'] = tensor
        super()._apply(fn)
        corrected = apply_complex_fix(fn, tensors_dict)
        for key, value in corrected.items():
            self.register_buffer(key, value)
        return self

    def forward(self) -> None:
        """Pass."""
        pass
[docs] class MatrixProductState(nn.Module): r"""A matrix product state (MPS) for quantum systems. A matrix product state is a way of representing a quantum state as a product of local tensors. Each tensor has one physical index and one or two bond indices. The physical index corresponds to the local Hilbert space dimension of the qudit, while the bond indices correspond to the entanglement between qudits. Args: nsite: The number of sites of the MPS. Default: 1 state: The representation of the MPS. If ``'zeros'`` or ``'vac'``, the MPS is initialized to the all-zero state. If a list of tensors, the MPS is initialized to the given tensors. The tensors must have the correct shape and dtype. If a list of integers, the MPS is initialized to the corresponding basis state. Default: ``'zeros'`` chi: The maximum bond dimension of the MPS. Default: None (which means 10 * ``nsite``) qudit: The local Hilbert space dimension of each qudit. Default: 2 normalize: Whether to normalize the MPS after each operation. Default: ``True`` """ def __init__( self, nsite: int = 1, state: str | list[torch.Tensor] | list[int] = 'zeros', chi: int | None = None, qudit: int = 2, normalize: bool = True, ) -> None: super().__init__() if chi is None: chi = 10 * nsite self.nsite = nsite self.chi = chi self.qudit = qudit self.normalize = normalize self.center = -1 self.set_tensors(state) def _apply(self, fn: Any) -> 'MatrixProductState': tensors_dict = { name: tensor for i in range(self.nsite) if (tensor := self._buffers.pop(name := f'tensor{i}')) is not None } super()._apply(fn) corrected = apply_complex_fix(fn, tensors_dict) for key, value in corrected.items(): self.register_buffer(key, value) return self @property def tensors(self) -> list[torch.Tensor]: """Get the tensors of the matrix product state. Note: This output is provided for reading only. Please modify the tensors through buffers. """ tensors = [] for j in range(self.nsite): tensors.append(getattr(self, f'tensor{j}')) return tensors
[docs] def set_tensors(self, state: str | list[torch.Tensor] | list[int]) -> None: """Set the tensors of the matrix product state.""" if state in ('zeros', 'vac'): state = [0] * self.nsite assert isinstance(state, list), 'Invalid input type' if len(state) < self.nsite: state += [0] * (self.nsite - len(state)) for i in range(self.nsite): assert isinstance(state[i], (torch.Tensor, int)), 'Invalid input type' if isinstance(state[i], torch.Tensor): self.register_buffer(f'tensor{i}', state[i]) elif isinstance(state[i], int): assert 0 <= state[i] < self.qudit, 'Invalid input' tensor = torch.zeros(self.qudit, dtype=torch.cfloat) tensor[state[i]] = 1.0 # the bond dimension is 1 self.register_buffer(f'tensor{i}', tensor.reshape(1, self.qudit, 1))
[docs] def center_orthogonalization(self, c: int, dc: int = -1, normalize: bool = False) -> None: """Get the center-orthogonalization form of the MPS with center ``c``.""" if c == -1: c = self.nsite - 1 if self.center < -0.5: self.orthogonalize_n1_n2(0, c, dc, normalize) self.orthogonalize_n1_n2(self.nsite - 1, c, dc, normalize) elif self.center != c: self.orthogonalize_n1_n2(self.center, c, dc, normalize) self.center = c if normalize: self.normalize_central_tensor()
[docs] def check_center_orthogonality(self, prt: bool = False) -> list[torch.Tensor]: """Check if the MPS is in center-orthogonal form.""" tensors = self.tensors assert tensors[0].ndim == 3 if self.center < -0.5: if prt: print('MPS NOT in center-orthogonal form!') else: err = [None] * self.nsite for i in range(self.center): s = tensors[i].shape tmp = tensors[i].reshape(-1, s[-1]) tmp = tmp.mH @ tmp err[i] = (tmp - torch.eye(tmp.shape[0], device=tmp.device, dtype=tmp.dtype)).norm(p=1).item() for i in range(self.nsite - 1, self.center, -1): s = tensors[i].shape tmp = tensors[i].reshape(s[0], -1) tmp = tmp @ tmp.mH err[i] = (tmp - torch.eye(tmp.shape[0], device=tmp.device, dtype=tmp.dtype)).norm(p=1).item() if prt: print('Orthogonality check:') print('=' * 35) err_av = 0.0 for i in range(self.nsite): if err[i] is None: print('Site ' + str(i) + ': center') else: print('Site ' + str(i) + ': ', err[i]) err_av += err[i] print('-' * 35) print(f'Average error = {err_av / (self.nsite - 1)}') print('=' * 35) return err
[docs] def full_tensor(self) -> torch.Tensor: """Get the full tensor product of the state.""" tensors = self.tensors psi = tensors[0] for i in range(1, self.nsite): psi = torch.einsum('...abc,...cde->...abde', psi, tensors[i]) s = psi.shape psi = psi.reshape(-1, s[-4], s[-3] * s[-2], s[-1]) return psi.squeeze()
[docs] def inner( self, tensors: Union[list[torch.Tensor], 'MatrixProductState'], form: str = 'norm' ) -> torch.Tensor | list[torch.Tensor]: """Get the inner product with another matrix product state.""" # form: 'log' or 'list' if isinstance(tensors, list): return inner_product_mps(self.tensors, tensors, form=form) else: return inner_product_mps(self.tensors, tensors.tensors, form=form)
[docs] def normalize_central_tensor(self) -> None: """Normalize the center tensor.""" assert self.center in list(range(self.nsite)) tensors = self.tensors if tensors[self.center].ndim == 3: norm = tensors[self.center].norm() elif tensors[self.center].ndim == 4: norm = tensors[self.center].norm(p=2, dim=[1, 2, 3], keepdim=True) self._buffers[f'tensor{self.center}'] = self._buffers[f'tensor{self.center}'] / norm
[docs] def orthogonalize_left2right(self, site: int, dc: int = -1, normalize: bool = False) -> None: r"""Orthogonalize the tensor at ``site`` and update the next one at ``site`` + 1. It uses the QR decomposition or SVD, i.e., :math:`T = UR` for the QR decomposition and :math:`T = USV^{\dagger} = UR` for SVD. The tensor at ``site`` is replaced by :math:`U`. The tensor at ``site`` + 1 is updated by :math:`R`. Args: site: The site of tensor to be orthogonalized. dc: Keep the first ``dc`` singular values after truncation. Default: -1 (which means no truncation) normalize: Whether to normalize the tensor :math:`R`. Default: ``False`` """ assert site < self.nsite - 1 tensors = self.tensors shape = tensors[site].shape batch = 1 if len(shape) == 3 else shape[0] if_trun = 0 < dc < shape[-1] if if_trun: u, s, vh = svd(tensors[site].reshape(batch, -1, shape[-1])) u = u[:, :, :dc] r = s[:, :dc].diag_embed() @ vh[:, :dc, :] else: u, r = qr(tensors[site].reshape(batch, -1, shape[-1])) self._buffers[f'tensor{site}'] = u.reshape(batch, shape[-3], shape[-2], -1) if normalize: norm = r.norm(dim=[-2, -1], keepdim=True) r = r / norm self._buffers[f'tensor{site + 1}'] = torch.einsum('...ab,...bcd->...acd', r, tensors[site + 1]) if len(shape) == 3: tensors = self.tensors self._buffers[f'tensor{site}'] = tensors[site].squeeze(0) self._buffers[f'tensor{site + 1}'] = tensors[site + 1].squeeze(0)
# ruff: noqa: E741
[docs] def orthogonalize_right2left(self, site: int, dc: int = -1, normalize: bool = False) -> None: r"""Orthogonalize the tensor at ``site`` and update the next one at ``site`` - 1. It uses the QR decomposition or SVD, i.e., :math:`T^{\dagger} = QR` for the QR decomposition, which gives :math:`T = R^{\dagger}Q^{\dagger} = LV^{\dagger}`, and :math:`T = USV^{\dagger} = LV^{\dagger}` for SVD. The tensor at ``site`` is replaced by :math:`V^{\dagger}`. The tensor at ``site`` - 1 is updated by :math:`L`. Args: site: The site of tensor to be orthogonalized. dc: Keep the first ``dc`` singular values after truncation. Default: -1 (which means no truncation) normalize: Whether to normalize the tensor :math:`L`. Default: ``False`` """ assert site > 0 tensors = self.tensors shape = tensors[site].shape batch = 1 if len(shape) == 3 else shape[0] if_trun = 0 < dc < shape[-3] if if_trun: u, s, vh = svd(tensors[site].reshape(batch, shape[-3], -1)) vh = vh[:, :dc, :] l = u[:, :, :dc] @ s[:, :dc].diag_embed() else: q, r = qr(tensors[site].reshape(batch, shape[-3], -1).mH) vh = q.mH l = r.mH self._buffers[f'tensor{site}'] = vh.reshape(batch, -1, shape[-2], shape[-1]) if normalize: norm = l.norm(dim=[-2, -1], keepdim=True) l = l / norm self._buffers[f'tensor{site - 1}'] = torch.einsum('...abc,...cd->...abd', tensors[site - 1], l) if len(shape) == 3: tensors = self.tensors self._buffers[f'tensor{site}'] = tensors[site].squeeze(0) self._buffers[f'tensor{site - 1}'] = tensors[site - 1].squeeze(0)
[docs] def orthogonalize_n1_n2(self, n1: int, n2: int, dc: int, normalize: bool) -> None: """Orthogonalize the MPS from site ``n1`` to site ``n2``.""" if n1 < n2: for site in range(n1, n2, 1): self.orthogonalize_left2right(site, dc, normalize) else: for site in range(n1, n2, -1): self.orthogonalize_right2left(site, dc, normalize)
[docs] def apply_mpo(self, mpo: list[torch.Tensor], sites: list[int]) -> None: """Use TEBD algorithm to contract tensors (contract local states with local operators), i.e., >>> a >>> | >>> i-----O-----j a >>> | -> | >>> b ik---X---jl >>> | >>> k-----T-----l """ assert len(mpo) == len(sites) for i, site in enumerate(sites): tensor = torch.einsum('iabj,...kbl->...ikajl', mpo[i], self.tensors[site]) s = tensor.shape if len(s) == 5: self._buffers[f'tensor{site}'] = tensor.reshape(s[-5] * s[-4], s[-3], s[-2] * s[-1]) else: self._buffers[f'tensor{site}'] = tensor.reshape(-1, s[-5] * s[-4], s[-3], s[-2] * s[-1])
[docs] def forward(self) -> None: """Pass.""" pass
class DistributedQubitState(nn.Module):
    """A quantum state of n qubits distributed between w nodes.

    Each rank holds a local slice of ``2**(nqubit - log2(world_size))`` complex amplitudes in the
    ``amps`` buffer, plus a same-shaped ``buffer`` tensor (presumably scratch space for inter-node
    exchange — its use is outside this class; confirm).

    Args:
        nqubit: The number of qubits in the state.
    """

    def __init__(self, nqubit: int) -> None:
        super().__init__()
        self.world_size = comm_get_world_size()
        self.rank = comm_get_rank()
        assert is_power_of_2(self.world_size)
        # every node must hold at least one amplitude
        assert power_of_2(nqubit) >= self.world_size
        assert 0 <= self.rank < self.world_size
        self.nqubit = nqubit
        self.log_num_nodes = log_base2(self.world_size)
        self.log_num_amps_per_node = nqubit - self.log_num_nodes
        self.num_amps_per_node = power_of_2(self.log_num_amps_per_node)
        # adding 0j promotes the real zeros to the matching complex dtype
        amps = torch.zeros(self.num_amps_per_node) + 0j
        buffer = torch.zeros_like(amps)
        self.register_buffer('amps', amps)
        self.register_buffer('buffer', buffer)
        self.reset()

    def _apply(self, fn: Any) -> 'DistributedQubitState':
        """Apply ``fn`` to the module, routing ``amps`` and ``buffer`` through ``apply_complex_fix``.

        The two buffers are popped before ``nn.Module._apply`` runs, handed to ``apply_complex_fix``
        (presumably to handle complex dtypes specially — confirm in ``utils``), and re-registered.
        """
        tensors_dict = {
            name: tensor for name in ('amps', 'buffer') if (tensor := self._buffers.pop(name)) is not None
        }
        super()._apply(fn)
        corrected = apply_complex_fix(fn, tensors_dict)
        for key, value in corrected.items():
            self.register_buffer(key, value)
        return self

    def reset(self) -> None:
        """Reset the state to the vacuum state.

        Zeroes ``amps`` and ``buffer`` in place on every rank; only rank 0 then sets its first local
        amplitude to 1.
        """
        self.amps.zero_()
        self.buffer.zero_()
        if self.rank == 0:
            self.amps[0] = 1.0