# Source code for deepquantum.mbqc.state

"""Quantum states"""

from collections import defaultdict
from typing import Any

import networkx as nx
import numpy as np
import torch
from torch import nn, vmap

from ..circuit import QubitCircuit
from ..qmath import inverse_permutation, multi_kron
from ..state import QubitState
from ..utils import apply_complex_fix


[docs] class SubGraphState(nn.Module): """A subgraph state of a quantum state. Args: nodes_state: The nodes of the input state in the subgraph state. It can be an integer representing the number of nodes or a list of node indices. Default: ``None`` state: The input state of the subgraph state. The string representation of state could be ``'plus'``, ``'minus'``, ``'zero'``, and ``'one'``. Default: ``'plus'`` edges: Additional edges connecting the nodes in the subgraph state. Default: ``None`` nodes: Additional nodes to include in the subgraph state. Default: ``None`` """ def __init__( self, nodes_state: int | list[int] | None = None, state: Any = 'plus', edges: list | None = None, nodes: int | list[int] | None = None, # primarily, for the single-node case ) -> None: super().__init__() self.nodes_out_seq = None self.set_graph(nodes_state, edges, nodes) self.set_state(state) self.measure_dict = defaultdict(list) # record the measurement results: {node: batched_bit} def _apply(self, fn: Any) -> 'SubGraphState': tensors_dict = {} name = 'state' tensor = self._buffers.pop(name) if tensor is not None: tensors_dict[name] = tensor super()._apply(fn) corrected = apply_complex_fix(fn, tensors_dict) for key, value in corrected.items(): self.register_buffer(key, value) return self @property def nodes(self, **kwargs): """Nodes of the graph.""" return self.graph.nodes(**kwargs) @property def edges(self, **kwargs): """Edges of the graph.""" return self.graph.edges(**kwargs) @property def full_state(self) -> torch.Tensor: """Compute and return the full quantum state of the subgraph state.""" nqubit = len(self.nodes) nodes_bg = list(self.nodes) for i in self.nodes_state: nodes_bg.remove(i) nodes = self.nodes_state + nodes_bg wires = [0] + list(map(lambda node: self.node2wire_dict[node] + 1, nodes)) # [0] for batch plus = torch.tensor([[1], [1]], dtype=self.state.dtype, device=self.state.device) / 2**0.5 init_state = multi_kron([self.state] + [plus] * len(nodes_bg)).reshape([-1] + 
[2] * nqubit) init_state = init_state.permute(inverse_permutation(wires)).reshape([-1, 2**nqubit]) cir = QubitCircuit(nqubit=nqubit, init_state=init_state) edges = list(filter(lambda edge: edge[2]['cz'], self.edges(data=True))) for edge in edges: cir.cz(self.node2wire_dict[edge[0]], self.node2wire_dict[edge[1]]) cir.to(init_state.device, init_state.real.dtype) return cir()
[docs] def set_graph( self, nodes_state: int | list[int] | None = None, edges: list | None = None, nodes: int | list[int] | None = None, ) -> None: """Set the graph structure for the subgraph state.""" if nodes_state is None: nodes_state = [] elif isinstance(nodes_state, int): nodes_state = list(range(nodes_state)) if edges is None: edges = [] if nodes is None: nodes = [] elif isinstance(nodes, int): nodes = [nodes] graph = nx.Graph() if len(nodes_state) > 1: nx.add_cycle(graph, nodes_state, cz=False) # 'cz' is the label for entanglement else: graph.add_nodes_from(nodes_state) graph.add_edges_from(edges, cz=True) graph.add_nodes_from(nodes) self.graph = graph self.nodes_state = nodes_state self.update_node2wire_dict()
[docs] def set_state(self, state: Any = 'plus') -> None: """Set the input state of the subgraph state.""" nqubit = len(self.nodes_state) if isinstance(state, str): if state == 'plus': state = torch.tensor([1, 1]) / 2**0.5 + 0j elif state == 'minus': state = torch.tensor([1, -1]) / 2**0.5 + 0j elif state == 'zero': state = torch.tensor([1, 0]) + 0j elif state == 'one': state = torch.tensor([0, 1]) + 0j if nqubit > 0: state = multi_kron([state] * nqubit) elif not isinstance(state, torch.Tensor): state = torch.tensor(state, dtype=torch.cfloat) if nqubit > 0: self.register_buffer('state', QubitState(nqubit, state).state) else: self.register_buffer('state', torch.tensor(1, dtype=state.dtype, device=state.device))
[docs] def set_nodes_out_seq(self, nodes: list[int] | None = None) -> None: """Set the output sequence of the nodes.""" if nodes is not None: assert len(nodes) == len(self.nodes) assert set(nodes) == set(self.nodes) self.nodes_out_seq = nodes self.update_node2wire_dict()
[docs] def add_nodes(self, nodes: int | list[int]) -> None: """Add nodes to the subgraph state.""" if isinstance(nodes, int): nodes = [nodes] self.graph.add_nodes_from(nodes) self.update_node2wire_dict()
[docs] def add_edges(self, edges: list) -> None: """Add edges to the subgraph state.""" self.graph.add_edges_from(edges, cz=True) self.update_node2wire_dict()
[docs] def shift_labels(self, n: int) -> None: """Shift the labels of the nodes in the graph by a given integer.""" self.graph = nx.relabel_nodes(self.graph, lambda x: x + n) self.nodes_state = (np.array(self.nodes_state) + n).tolist() self.measure_dict = {k + n: v for k, v in self.measure_dict.items()} self.update_node2wire_dict()
[docs] def compose(self, other: 'SubGraphState', relabel: bool = True) -> 'SubGraphState': """Compose this subgraph state with another subgraph state. Args: other: The other subgraph state to compose with. relabel: Whether to relabel nodes to avoid conflicts. Default: ``True`` Returns: A new subgraph state that is the composition of the two. """ if relabel and (set(self.nodes) & set(other.nodes)): shift = max(self.nodes) - min(other.nodes) + 1 other.shift_labels(shift) graph = nx.compose(self.graph, other.graph) for i in other.nodes_state: assert i not in self.nodes_state, 'Do NOT use repeated nodes for states' nodes_state = self.nodes_state + other.nodes_state if self.state.ndim == other.state.ndim == 3: if self.state.shape[0] == 1 or other.state.shape[0] == 1: state = torch.kron(self.state, other.state) else: state = vmap(torch.kron)(self.state, other.state) else: state = torch.kron(self.state, other.state) sgs = SubGraphState(nodes_state, state, graph.edges(data=True), graph.nodes) sgs.measure_dict = defaultdict(list) sgs.measure_dict.update(self.measure_dict) sgs.measure_dict.update(other.measure_dict) return sgs
[docs] def update_node2wire_dict(self) -> dict: """Update the mapping from nodes to wire indices. Returns: A dictionary mapping nodes to their corresponding wire indices. """ if self.nodes_out_seq is None: wires = inverse_permutation(np.argsort(self.nodes).tolist()) self.node2wire_dict = {node: wire for node, wire in zip(self.nodes, wires, strict=True)} else: self.node2wire_dict = {node: i for i, node in enumerate(self.nodes_out_seq)} return self.node2wire_dict
[docs] def draw(self, **kwargs): """Draw the graph using NetworkX.""" nx.draw(self.graph, with_labels=True, **kwargs)
[docs] def extra_repr(self) -> str: return f'nodes_state={self.nodes_state}, nodes={self.nodes}'
class GraphState(nn.Module):
    """A graph state composed by several SubGraphStates.

    Args:
        nodes_state: The nodes of the input state in the initial graph state. It can be an integer
            representing the number of nodes or a list of node indices. Default: ``None``
        state: The input state of the initial graph state. The string representation of state could be
            ``'plus'``, ``'minus'``, ``'zero'``, and ``'one'``. Default: ``'plus'``
        edges: Additional edges connecting the nodes in the initial graph state. Default: ``None``
        nodes: Additional nodes to include in the initial graph state. Default: ``None``
    """

    def __init__(
        self,
        nodes_state: int | list[int] | None = None,
        state: Any = 'plus',
        edges: list | None = None,
        nodes: int | list[int] | None = None,
    ) -> None:
        super().__init__()
        initial = SubGraphState(nodes_state, state, edges, nodes)
        self.subgraphs = nn.ModuleList([initial])
        self.nodes_out_seq = None

    def add_subgraph(
        self,
        nodes_state: int | list[int] | None = None,
        state: Any = 'plus',
        edges: list | None = None,
        nodes: int | list[int] | None = None,
        measure_dict: dict | None = None,
        index: int | None = None,
    ) -> None:
        """Add a subgraph state to the graph state.

        When no ``index`` is given, the new subgraph state is moved to the device and real dtype of
        the first stored subgraph state and appended; otherwise it is inserted at ``index`` as built.

        Args:
            nodes_state: The nodes of the input state in the subgraph state. It can be an integer
                representing the number of nodes or a list of node indices. Default: ``None``
            state: The input state of the subgraph state. The string representation of state could be
                ``'plus'``, ``'minus'``, ``'zero'``, and ``'one'``. Default: ``'plus'``
            edges: Additional edges connecting the nodes in the subgraph state. Default: ``None``
            nodes: Additional nodes to include in the subgraph state. Default: ``None``
            measure_dict: A dictionary containing all measurement results. Default: ``None``
            index: The index where to insert the subgraph state. Default: ``None``
        """
        new_sgs = SubGraphState(nodes_state, state, edges, nodes)
        appending = index is None
        if appending:
            reference = self.subgraphs[0].state
            new_sgs.to(reference.device, reference.real.dtype)
        if measure_dict is not None:
            new_sgs.measure_dict = measure_dict
        if appending:
            self.subgraphs.append(new_sgs)
            return
        self.subgraphs.insert(index, new_sgs)

    @property
    def graph(self) -> SubGraphState:
        """The combined graph state of all subgraph states."""
        combined = None
        for sgs in self.subgraphs:
            if combined is None:
                # The first subgraph state is used as-is as the starting point.
                combined = sgs
            else:
                combined = combined.compose(sgs, relabel=True)
        combined.set_nodes_out_seq(self.nodes_out_seq)
        return combined

    @property
    def full_state(self) -> torch.Tensor:
        """Compute and return the full quantum state of the graph state."""
        combined = self.graph
        return combined.full_state

    @property
    def measure_dict(self) -> dict:
        """A dictionary containing all measurement results for the graph state."""
        combined = self.graph
        return combined.measure_dict

    def set_nodes_out_seq(self, nodes: list[int] | None = None) -> None:
        """Set the output sequence of the nodes.

        The sequence is only stored here; it is passed to the combined subgraph state each time
        ``graph`` is evaluated.
        """
        self.nodes_out_seq = nodes