"""Measurement pattern"""
from collections.abc import Iterable
from copy import copy, deepcopy
from typing import Any
import matplotlib.pyplot as plt
import numpy as np
import torch
from networkx import MultiDiGraph, draw_networkx_edges, draw_networkx_labels, draw_networkx_nodes, multipartite_layout
from torch import nn
from .command import Correction, Entanglement, Measurement, Node
from .operation import Operation
from .state import GraphState, SubGraphState
[docs]
class Pattern(Operation):
"""Measurement-based quantum computing (MBQC) pattern.
A pattern represents a measurement-based quantum computation, which consists of a sequence of
commands (node preparation, entanglement, measurement, and correction) applied to qubits arranged
in a graph structure.
Args:
nodes_state: The nodes of the input state in the initial graph state.
It can be an integer representing the number of nodes or a list of node indices. Default: ``None``
state: The input state of the initial graph state. The string representation of state could be
``'plus'``, ``'minus'``, ``'zero'``, and ``'one'``. Default: ``'plus'``
edges: Additional edges connecting the nodes in the initial graph state. Default: ``None``
nodes: Additional nodes to include in the initial graph state. Default: ``None``
name: The name of the pattern. Default: ``None``
reupload: Whether to use data re-uploading. Default: ``False``
Ref: V. Danos, E. Kashefi and P. Panangaden. J. ACM 54.2 8 (2007)
"""
def __init__(
self,
nodes_state: int | list[int] | None = None,
state: Any = 'plus',
edges: list | None = None,
nodes: int | list[int] | None = None,
name: str | None = None,
reupload: bool = False,
) -> None:
super().__init__(name=name, nodes=None)
self.reupload = reupload
self.init_state = GraphState(nodes_state, state, edges, nodes)
self.commands = nn.Sequential()
self.encoders = []
self.state = None
self.ndata = 0
self.nodes_out_seq = None
[docs]
def forward(self, data: torch.Tensor | None = None, state: GraphState | None = None) -> GraphState:
"""Perform a forward pass of the MBQC pattern and return the final graph state.
Args:
data: The input data for the ``encoders``. Default: ``None``
state: The initial graph state for the pattern. Default: ``None``
Returns:
The final graph state of the pattern after applying the ``commands``.
"""
if state is None:
self.state = deepcopy(self.init_state)
else:
self.state = state
self.encode(data)
self.state = self.commands(self.state)
self.state.set_nodes_out_seq(self.nodes_out_seq)
if data is not None and data.ndim == 2:
# for plotting the last data
self.encode(data[-1])
return self.state
[docs]
def encode(self, data: torch.Tensor | None) -> None:
"""Encode the input data into the measurement angles as parameters.
This method iterates over the ``encoders`` of the MBQC pattern and initializes their parameters
with the input data. If ``reupload`` is ``False``, the input data must be at least as long as the number
of parameters in the ``encoders``. If ``reupload`` is ``True``, the input data can be repeated to fill up
the parameters.
Args:
data: The input data for the ``encoders``, could be a 1D or 2D tensor.
Raises:
AssertionError: If input data is shorter than the number of parameters in the ``encoders``.
"""
if data is None:
return
if not self.reupload:
assert data.shape[-1] >= self.ndata, 'The pattern needs more data, or consider data re-uploading'
count = 0
if self.reupload and self.ndata > data.shape[-1]:
n = int(np.ceil(self.ndata / data.shape[-1]))
data = torch.cat([data] * n, dim=-1)
for op in self.encoders:
count_up = count + op.npara
if data.ndim == 2:
op.init_para(data[:, count:count_up])
else:
op.init_para(data[count:count_up])
count = count_up
[docs]
def add_graph(
self,
nodes_state: int | list[int] | None = None,
state: Any = 'plus',
edges: list | None = None,
nodes: int | list[int] | None = None,
index: int | None = None,
) -> None:
"""Add a subgraph state to the graph state.
Args:
nodes_state: The nodes of the input state in the subgraph state.
It can be an integer representing the number of nodes or a list of node indices. Default: ``None``
state: The input state of the subgraph state. The string representation of state could be
``'plus'``, ``'minus'``, ``'zero'``, and ``'one'``. Default: ``'plus'``
edges: Additional edges connecting the nodes in the subgraph state. Default: ``None``
nodes: Additional nodes to include in the subgraph state. Default: ``None``
index: The index where to insert the subgraph state. Default: ``None``
"""
self.init_state.add_subgraph(nodes_state=nodes_state, state=state, edges=edges, nodes=nodes, index=index)
@property
def graph(self) -> SubGraphState:
"""The combined graph state of the initial or final graph state."""
if self.state is None:
return self.init_state.graph
else:
return self.state.graph
[docs]
def set_nodes_out_seq(self, nodes: list[int] | None = None) -> None:
"""Set the output sequence of the nodes."""
self.nodes_out_seq = nodes
[docs]
def add(self, op: Operation, encode: bool = False) -> None:
"""A method that adds an operation to the MBQC pattern.
Args:
op: The operation to add. It is an instance of ``Operation`` class or its subclasses,
such as ``Node``, ``Entanglement``, ``Measurement``, or ``Correction``.
encode: Whether the command is to encode data. Default: ``False``
"""
assert isinstance(op, Operation)
self.commands.append(op)
if encode:
assert not op.requires_grad, 'Please set requires_grad of the operation to be False'
self.encoders.append(op)
self.ndata += op.npara
else:
self.npara += op.npara
[docs]
def n(self, nodes: int | list[int]) -> None:
"""Add a node command."""
n = Node(nodes=nodes)
self.add(n)
[docs]
def e(self, node1: int, node2: int) -> None:
"""Add an entanglement command."""
e = Entanglement(node1=node1, node2=node2)
self.add(e)
[docs]
def m(
self,
node: int,
angle: float = 0.0,
plane: str = 'xy',
t_domain: int | Iterable[int] | None = None,
s_domain: int | Iterable[int] | None = None,
encode: bool = False,
) -> None:
"""Add a measurement command."""
requires_grad = not encode
if angle is not None:
requires_grad = False
m = Measurement(
nodes=node, angle=angle, plane=plane, t_domain=t_domain, s_domain=s_domain, requires_grad=requires_grad
)
self.add(m, encode=encode)
[docs]
def x(self, node: int, domain: int | Iterable[int] | None = None) -> None:
"""Add an X-correction command."""
x = Correction(nodes=node, basis='x', domain=domain)
self.add(x)
[docs]
def z(self, node: int, domain: int | Iterable[int] | None = None) -> None:
"""Add a Z-correction command."""
z = Correction(nodes=node, basis='z', domain=domain)
self.add(z)
[docs]
def draw(self):
"""Draw the MBQC pattern."""
g = MultiDiGraph(self.init_state.graph.graph)
nodes_init = deepcopy(g.nodes())
for i in nodes_init:
g.nodes[i]['layer'] = 0
nodes_measured = []
edges_t_domain = []
edges_s_domain = []
for op in self.commands:
if isinstance(op, Node):
g.add_nodes_from(op.nodes, layer=2)
elif isinstance(op, Entanglement):
g.add_edge(*op.nodes)
elif isinstance(op, Measurement):
nodes_measured.append(op.nodes[0])
if op.nodes[0] not in nodes_init:
g.nodes[op.nodes[0]]['layer'] = 1
for i in op.t_domain:
edges_t_domain.append(tuple([i, op.nodes[0]]))
for i in op.s_domain:
edges_s_domain.append(tuple([i, op.nodes[0]]))
pos = multipartite_layout(g, subset_key='layer')
draw_networkx_nodes(g, pos, nodelist=nodes_init, node_color='#1f78b4', node_shape='s')
draw_networkx_nodes(g, pos, nodelist=nodes_measured, node_color='#1f78b4')
draw_networkx_nodes(
g, pos, nodelist=list(set(g.nodes()) - set(nodes_measured)), node_color='#d7dde0', node_shape='o'
)
draw_networkx_edges(g, pos, g.edges(), arrows=False)
draw_networkx_edges(
g, pos, edges_t_domain, arrows=True, style=':', edge_color='#4cd925', connectionstyle='arc3,rad=-0.2'
)
draw_networkx_edges(
g, pos, edges_s_domain, arrows=True, style=':', edge_color='#db1d2c', connectionstyle='arc3,rad=0.2'
)
draw_networkx_labels(g, pos)
plt.plot([], [], color='k', label='graph edge')
plt.plot([], [], ':', color='#4cd925', label='zflow')
plt.plot([], [], ':', color='#db1d2c', label='xflow')
plt.plot([], [], 's', color='#1f78b4', label='input nodes')
plt.plot([], [], 'o', color='#d7dde0', label='output nodes')
# plt.xlim(-width / 2, width / 2)
# plt.ylim(-width / 2, width / 2)
plt.legend(loc='upper right', fontsize=10)
plt.tight_layout()
plt.show()
[docs]
def is_standard(self) -> bool:
"""Determine whether the command sequence is standard.
Returns:
``True`` if the pattern follows NEMC standardization, ``False`` otherwise
"""
it = iter(self.commands)
try:
# Check if operations follow NEMC order
op = next(it)
while isinstance(op, Node): # First all Node operations
op = next(it)
while isinstance(op, Entanglement): # Then all Entanglement operations
op = next(it)
while isinstance(op, Measurement): # Then all Measurement operations
op = next(it)
while isinstance(op, Correction): # Finally all Correction operations
op = next(it)
return False # If we get here, there were operations after NEMC sequence
except StopIteration:
return True # If we run out of operations, pattern is standard
# -----------------------------------------------------------------------------
# Adapted from Graphix
# Original Copyright (c) 2022 Team Graphix
# Modified work Copyright (c) 2025-2026 TuringQ
# Licensed under the Apache License, Version 2.0
# Source: https://github.com/TeamGraphix/graphix/blob/0ca40c196c55da6bbb0488a8ea1045f2572fd0b6/graphix/pattern.py#L287
#
# Modifications:
# - Refactored to fit internal data structures.
# -----------------------------------------------------------------------------
[docs]
def standardize(self) -> None:
"""Standardize the command sequence into NEMC form.
This function reorders operations into the standard form:
- Node preparations (N)
- Entanglement operations (E)
- Measurement operations (M)
- Correction operations (C)
It handles the propagation of correction operations by:
1. Moving X-corrections through entanglements (generating Z-corrections)
2. Moving corrections through measurements (modifying measurement signal domains)
3. Collecting remaining corrections at the end
See https://arxiv.org/pdf/0704.1263 Ch.(5.4)
"""
# Initialize lists for each operation type
n_list = [] # Node operations
e_list = [] # Entanglement operations
m_list = [] # Measurement operations
z_dict = {} # Tracks Z corrections by node
x_dict = {} # Tracks X corrections by node
def add_correction_domain(domain_dict: dict, node, domain) -> None:
"""Helper function to update correction domains with XOR operation"""
if previous_domain := domain_dict.get(node):
previous_domain ^= domain
else:
domain_dict[node] = domain.copy()
# Process each operation and reorganize into standard form
for op in self.commands:
if isinstance(op, Node):
n_list.append(op)
elif isinstance(op, Entanglement):
for side in (0, 1):
# Propagate X corrections through entanglement (generates Z corrections)
if s_domain := x_dict.get(op.nodes[side]):
add_correction_domain(z_dict, op.nodes[1 - side], s_domain)
e_list.append(op)
elif isinstance(op, Measurement):
# Apply pending corrections to measurement parameters
new_op = copy(op)
if t_domain := z_dict.pop(op.nodes[0], None):
new_op.t_domain = new_op.t_domain ^ t_domain
if s_domain := x_dict.pop(op.nodes[0], None):
new_op.s_domain = new_op.s_domain ^ s_domain
m_list.append(new_op)
elif isinstance(op, Correction):
if op.basis == 'z':
add_correction_domain(z_dict, op.nodes[0], op.domain)
elif op.basis == 'x':
add_correction_domain(x_dict, op.nodes[0], op.domain)
# Reconstruct command sequence in standard order
self.commands = nn.Sequential(
*n_list,
*e_list,
*m_list,
*(Correction(nodes=node, basis='z', domain=domain) for node, domain in z_dict.items()),
*(Correction(nodes=node, basis='x', domain=domain) for node, domain in x_dict.items()),
)
# -----------------------------------------------------------------------------
# Adapted from Graphix
# Original Copyright (c) 2022 Team Graphix
# Modified work Copyright (c) 2025-2026 TuringQ
# Licensed under the Apache License, Version 2.0
# Source: https://github.com/TeamGraphix/graphix/blob/0ca40c196c55da6bbb0488a8ea1045f2572fd0b6/graphix/pattern.py#L426
#
# Modifications:
# - Refactored to fit internal data structures and conventions.
# -----------------------------------------------------------------------------
[docs]
def shift_signals(self) -> dict:
"""Perform signal shifting procedure.
This allows one to dispose of dependencies induced by the Z-action,
and obtain sometimes standard patterns with smaller computational depth complexity.
It handles the propagation of signal shifting commands by:
1. Extracting signals via t_domain (in XY plane cases) of measurements.
2. Moving signals to the left, through modifying other measurements and corrections.
See https://arxiv.org/pdf/0704.1263 Ch.(5.5)
Returns:
A signal dictionary including all the signal shifting commands.
"""
signal_dict = {}
def expand_domain(domain: set[int]) -> None:
for node in domain & signal_dict.keys():
domain ^= signal_dict[node]
for op in self.commands:
if isinstance(op, Measurement):
s_domain = set(op.s_domain)
t_domain = set(op.t_domain)
expand_domain(s_domain)
expand_domain(t_domain)
if op.plane in ['xy', 'yx']:
# M^{XY,α} X^s Z^t = M^{XY,(-1)^s·α+tπ}
# = S^t M^{XY,(-1)^s·α}
# = S^t M^{XY,α} X^s
if t_domain:
signal_dict[op.nodes[0]] = t_domain
t_domain = set()
elif op.plane in ['zx', 'xz']:
# M^{XZ,α} X^s Z^t = M^{XZ,(-1)^t((-1)^s·α+sπ)}
# = M^{XZ,(-1)^{s+t}·α+(-1)^t·sπ}
# = M^{XZ,(-1)^{s+t}·α+sπ} (since (-1)^t·π ≡ π (mod 2π))
# = S^s M^{XZ,(-1)^{s+t}·α}
# = S^s M^{XZ,α} Z^{s+t}
if s_domain:
signal_dict[op.nodes[0]] = s_domain
t_domain ^= s_domain
s_domain = set()
elif op.plane in ['yz', 'zy']: # noqa: SIM102
# positive Y axis as 0 angle
# M^{YZ,α} X^s Z^t = M^{YZ,(-1)^t·α+(s+t)π)}
# = S^s M^{YZ,(-1)^t·α+tπ}
# = S^s M^{YZ,α} Z^t
# still remains M^{YZ,(-1)^t·α+tπ)} after signal shifting,
# but dependency on s_domain has been reduced
if s_domain:
signal_dict[op.nodes[0]] = s_domain
s_domain = set()
op.s_domain = s_domain
op.t_domain = t_domain
elif isinstance(op, Correction):
domain = set(op.domain)
expand_domain(domain)
op.domain = domain
return signal_dict