"""Draw photonic quantum circuit"""
from collections import defaultdict
import matplotlib.pyplot as plt
import numpy as np
import svgwrite
from matplotlib import patches
from torch import nn
from .channel import PhotonLoss
from .gate import (
Barrier,
BeamSplitter,
BeamSplitterSingle,
ControlledX,
ControlledZ,
CrossKerr,
CubicPhase,
Displacement,
Kerr,
MZI,
PhaseShift,
QuadraticPhase,
Squeezing,
Squeezing2,
UAnyGate,
)
from .measurement import Homodyne
from .operation import Delay
# Drawing style of the box-shaped gates, keyed by the label printed next to the gate.
# Each value is ``[fill color of the gate box, horizontal offset added to the label position]``.
info_dic = {
    # phase shifter
    'PS': ['teal', 0],
    # squeezing (single-mode and two-mode) and displacement
    'S': ['royalblue', 3],
    'S2': ['royalblue', 0],
    'D': ['green', 3],
    # arbitrary unitary
    'U': ['cadetblue', 0],
    # quadratic phase, cubic phase and Kerr
    'QP': ['peru', 0],
    'CP': ['peru', 0],
    'K': ['pink', 3],
    # controlled-X, controlled-Z and cross-Kerr
    'CX': ['gold', 0],
    'CZ': ['gold', 0],
    'CK': ['pink', 0],
}
class DrawCircuit:
    """Draw the photonic quantum circuit.

    Gates are placed on a grid: column ``order`` starts at ``x = 90 * order + 40``
    and is 90 units wide, and mode ``i`` is the horizontal line ``y = 30 * i + 30``.

    Args:
        circuit_name: The name of the circuit.
        circuit_nmode: The number of modes in the circuit.
        circuit_operators: The operators of the circuit.
        measurements: The measurements of the circuit.
    """

    def __init__(
        self, circuit_name: str, circuit_nmode: int, circuit_operators: nn.Sequential, measurements: nn.ModuleList
    ) -> None:
        if circuit_name is None:
            circuit_name = 'circuit'
        nmode = circuit_nmode
        name = circuit_name + '.svg'
        self.draw_ = svgwrite.Drawing(name, profile='full')
        self.draw_['height'] = f'{10.5 / 11 * nmode}cm'
        self.nmode = nmode
        self.name = name
        self.ops = circuit_operators
        self.mea = measurements

    def draw(self, depth=None, ops=None, measurements=None):
        """Draw circuit.

        Args:
            depth: The current column index of every mode. A list passed in is updated in place.
                Default: ``None`` (every mode starts at column 0)
            ops: The operators to draw. Default: ``None`` (the operators given at construction)
            measurements: The measurements to draw. Default: ``None`` (the measurements given
                at construction)

        After the call, ``self.order_dic`` maps every used column to the wires occupied in it,
        ``self.depth`` holds the final depth of each mode and the drawing width is set.
        """
        order_dic = defaultdict(list)  # a missing key maps to []
        nmode = self.nmode
        if depth is None:
            depth = [0] * nmode  # record the depth of each mode
        if ops is None:
            ops = self.ops
        if measurements is None:
            measurements = self.mea
        for op in ops:
            if isinstance(op, BeamSplitter):
                if isinstance(op, MZI):
                    name = 'MZI-PT' if op.phi_first else 'MZI-TP'
                elif isinstance(op, BeamSplitterSingle):
                    name = 'BS-' + op.convention.upper()
                else:
                    name = 'BS'
                theta = op.theta.item()
                try:
                    phi = op.phi.item()
                except Exception:
                    # best effort: without a usable ``phi`` only θ is printed
                    phi = None
                order = max(depth[op.wires[0]], depth[op.wires[1]])
                self.draw_bs(name, order, op.wires, theta, phi)
                order_dic[order] = order_dic[order] + op.wires
                for i in op.wires:
                    depth[i] = depth[i] + 1
                # both modes leave the beam splitter at the same depth
                bs_depth = max(depth[op.wires[0]], depth[op.wires[1]])
                depth[op.wires[0]] = bs_depth
                depth[op.wires[1]] = bs_depth
            elif isinstance(op, PhaseShift):
                name_ = 'PS'
                theta = op.theta.item()
                order = depth[op.wires[0]]
                self.draw_ps(order, op.wires, theta, name_)
                order_dic[order] = order_dic[order] + op.wires
                for i in op.wires:
                    depth[i] = depth[i] + 1
            elif isinstance(op, (UAnyGate, Squeezing2)):
                # the gate goes in the first column that is free on every mode it spans
                order = max(depth[min(op.wires) : max(op.wires) + 1])
                if isinstance(op, UAnyGate):
                    name_ = 'U'
                    self.draw_any(order, op.wires, name_)
                else:
                    name_ = 'S2'
                    para_dic = {'r': op.r.item(), 'θ': op.theta.item()}
                    self.draw_sq(order, op.wires, para_dic, name_)
                order_dic[order] = order_dic[order] + op.wires
                for i in op.wires:
                    depth[i] = order + 1
            elif isinstance(op, (Squeezing, Displacement)):
                para_dic = {'r': op.r.item(), 'θ': op.theta.item()}
                order = depth[op.wires[0]]
                name_ = 'S' if isinstance(op, Squeezing) else 'D'
                self.draw_sq(order, op.wires, para_dic, name=name_)
                order_dic[order] = order_dic[order] + op.wires
                for i in op.wires:
                    depth[i] = depth[i] + 1
            elif isinstance(op, Delay):
                order = depth[op.wires[0]]
                inputs = [op.ntau, op.theta.item(), op.phi.item()]
                self.draw_delay(order, op.wires, inputs=inputs)
                order_dic[order] = order_dic[order] + op.wires
                for i in op.wires:
                    depth[i] = depth[i] + 1
            elif isinstance(op, PhotonLoss):
                name_ = 'loss'
                order = depth[op.wires[0]]
                t = op.t.item()
                self.draw_loss(order, op.wires, name_, t)
                order_dic[order] = order_dic[order] + op.wires
                for i in op.wires:
                    depth[i] = depth[i] + 1
            elif isinstance(op, Barrier):
                wires = op.wires
                order = int(max(np.array(depth)[wires]))
                self.barrier(order=order, wires=wires)
                # a barrier occupies no column, it only aligns the depth of its wires
                for i in wires:
                    depth[i] = order
            elif isinstance(op, (QuadraticPhase, ControlledX, ControlledZ, CubicPhase, Kerr, CrossKerr)):
                if isinstance(op, (QuadraticPhase, CubicPhase, Kerr)):
                    order = depth[op.wires[0]]
                    if isinstance(op, QuadraticPhase):
                        para_dic = {'s': op.s.item()}
                        name_ = 'QP'
                    elif isinstance(op, CubicPhase):
                        para_dic = {'γ': op.gamma.item()}
                        name_ = 'CP'
                    elif isinstance(op, Kerr):
                        para_dic = {'κ': op.kappa.item()}
                        name_ = 'K'
                elif isinstance(op, (ControlledX, ControlledZ, CrossKerr)):
                    order = max(depth[min(op.wires) : max(op.wires) + 1])
                    if isinstance(op, ControlledX):
                        para_dic = {'s': op.s.item()}
                        name_ = 'CX'
                    elif isinstance(op, ControlledZ):
                        para_dic = {'s': op.s.item()}
                        name_ = 'CZ'
                    elif isinstance(op, CrossKerr):
                        para_dic = {'κ': op.kappa.item()}
                        name_ = 'CK'
                self.draw_sq(order, op.wires, para_dic, name=name_)
                order_dic[order] = order_dic[order] + op.wires
                for i in op.wires:
                    depth[i] = order + 1
        if len(measurements) > 0:
            for mea in measurements:
                if isinstance(mea, Homodyne):
                    name_ = 'M'
                    phi = mea.phi.detach()
                    for i in mea.wires:
                        order = depth[i]
                        self.draw_homodyne(order, i, phi.item(), name_)
                        order_dic[order] = order_dic[order] + [i]
                        depth[i] = depth[i] + 1
        # in every used column, wires without an operation get a plain line
        for key, value in order_dic.items():
            op_line = value
            line_wires = [i for i in range(nmode) if i not in op_line]
            if len(line_wires) > 0:
                self.draw_lines(key, line_wires)
        self.draw_mode_num()  # label the modes
        self.order_dic = order_dic
        self.depth = depth
        wid = 3 * (90 * (max(self.depth)) + 40) / 100
        self.draw_['width'] = f'{wid}cm'

    def save(self, filename):
        """Save the circuit as svg."""
        self.draw_.saveas(filename)

    def draw_mode_num(self):
        """Write the index of every mode at the left end of its line."""
        nmode = self.nmode
        for i in range(nmode):
            self.draw_.add(self.draw_.text(str(i), insert=(25, i * 30 + 30), font_size=12))

    def draw_bs(self, name, order, wires, theta, phi=None):
        """Draw beamsplitter.

        Args:
            name: The label printed above the crossing.
            order: The column index.
            wires: The two modes of the beam splitter.
            theta: The value printed as ``θ``.
            phi: The value printed as ``ϕ``. Nothing is printed for ``None``. Default: ``None``
        """
        x = 90 * order + 40
        wires = sorted(wires)
        y_up = wires[0]
        y_down = wires[1]
        y_delta = abs(y_down - y_up)
        shift = -10  # moves the crossing and its labels to the left
        # line going from the upper mode to the lower mode
        self.draw_.add(
            self.draw_.polyline(
                points=[
                    (x, y_up * 30 + 30),
                    (x + 30 + shift, y_up * 30 + 30),
                    (x + 60 + shift, y_up * 30 + 30 + 30 * y_delta),
                    (x + 90, y_up * 30 + 30 + 30 * y_delta),
                ],
                fill='none',
                stroke='black',
                stroke_width=2,
            )
        )
        # line going from the lower mode to the upper mode
        self.draw_.add(
            self.draw_.polyline(
                points=[
                    (x, y_up * 30 + 30 + 30 * y_delta),
                    (x + 30 + shift, y_up * 30 + 30 + 30 * y_delta),
                    (x + 60 + shift, y_up * 30 + 30),
                    (x + 90, y_up * 30 + 30),
                ],
                fill='none',
                stroke='black',
                stroke_width=2,
            )
        )
        self.draw_.add(
            self.draw_.text(name, insert=(x + 40 - (len(name) - 2) * 3 + shift, y_up * 30 + 25), font_size=9)
        )
        self.draw_.add(
            self.draw_.text(
                'θ=' + str(np.round(theta, 3)), insert=(x + 55 + shift, y_up * 30 + 30 + 20 - 6), font_size=7
            )
        )
        if phi is not None:
            self.draw_.add(
                self.draw_.text(
                    'ϕ=' + str(np.round(phi, 3)), insert=(x + 55 + shift, y_up * 30 + 30 + 26 - 6), font_size=7
                )
            )

    def draw_ps(self, order, wires, theta=0, name=None):
        """Draw phaseshift (rotation) gate.

        Args:
            order: The column index.
            wires: The modes of the gate; only the first one is drawn.
            theta: The value printed as ``θ``. Default: 0
            name: The label of the gate, which must be a key of ``info_dic``.
        """
        fill_c = info_dic[name][0]
        shift = info_dic[name][1]
        x = 90 * order + 40
        y_up = wires[0]
        self.draw_.add(
            self.draw_.polyline(
                points=[(x, y_up * 30 + 30), (x + 90, y_up * 30 + 30)], fill='none', stroke='black', stroke_width=2
            )
        )
        self.draw_.add(
            self.draw_.rect(
                insert=(x + 42.5, y_up * 30 + 25),
                size=(6, 12),
                rx=0,
                ry=0,
                fill=fill_c,
                stroke='black',
                stroke_width=1.5,
            )
        )
        self.draw_.add(self.draw_.text(name, insert=(x + 40 + shift, y_up * 30 + 20), font_size=9))
        self.draw_.add(self.draw_.text('θ=' + str(np.round(theta, 3)), insert=(x + 55, y_up * 30 + 20), font_size=7))

    def draw_homodyne(self, order, wire, phi, name=None):
        """Draw homodyne measurement.

        Args:
            order: The column index.
            wire: The measured mode.
            phi: The value printed as ``ϕ``.
            name: The label printed above the symbol.
        """
        fill_c = 'black'
        shift = 5
        x = 90 * order + 40
        y_up = wire
        self.draw_.add(
            self.draw_.polyline(
                points=[(x, y_up * 30 + 30), (x + 90, y_up * 30 + 30)], fill='none', stroke='black', stroke_width=2
            )
        )
        self.draw_.add(
            self.draw_.rect(
                insert=(x + 42.5, y_up * 30 + 25),
                size=(14, 14),
                rx=0,
                ry=0,
                fill=fill_c,
                stroke='black',
                stroke_width=1.5,
            )
        )
        self.draw_.add(self.draw_.text(name, insert=(x + 40 + shift, y_up * 30 + 20), font_size=9))
        # white arc inside the black box
        arc_radius = 6
        arc_center_x = x + 42.5 + 14 / 2
        arc_center_y = y_up * 30 + 25 + 14 / 2
        start_x = arc_center_x - arc_radius
        start_y = arc_center_y + 3
        end_x = arc_center_x + arc_radius
        end_y = arc_center_y + 3
        arc_path = f'M {start_x} {start_y} A {arc_radius} {arc_radius} 0 0 1 {end_x} {end_y}'
        self.draw_.add(self.draw_.path(d=arc_path, stroke='white', fill='none', stroke_width=1.5))
        # white needle, rotated around the centre of the box
        line_start_x = arc_center_x
        line_start_y = arc_center_y + 3
        line_end_x = arc_center_x
        line_end_y = arc_center_y - arc_radius
        line_path = f'M {line_start_x} {line_start_y} L {line_end_x} {line_end_y}'
        rotation = 45
        self.draw_.add(
            self.draw_.path(
                d=line_path,
                stroke='white',
                fill='none',
                stroke_width=1.5,
                transform=f'rotate({rotation} {arc_center_x} {arc_center_y})',
            )
        )
        self.draw_.add(self.draw_.text('ϕ=' + str(np.round(phi, 3)), insert=(x + 55, y_up * 30 + 20), font_size=7))

    def draw_sq(self, order, wires, para_dic, name=None):
        """Draw a box-shaped gate with its parameters, e.g. squeezing gate, displacement gate.

        Args:
            order: The column index.
            wires: The modes of the gate.
            para_dic: The parameters to print, as a mapping from symbol to value.
            name: The label of the gate, which must be a key of ``info_dic``.
        """
        x = 90 * order + 40
        wires = sorted(wires)
        y_up = wires[0]
        for wire_i in wires:
            self.draw_.add(
                self.draw_.polyline(
                    points=[(x, wire_i * 30 + 30), (x + 90, wire_i * 30 + 30)],
                    fill='none',
                    stroke='black',
                    stroke_width=2,
                )
            )
        fill_c = info_dic[name][0]
        shift = info_dic[name][1]
        if len(wires) == 1:
            height = 12
        else:
            # The box covers every mode from the first to the last wire. For two
            # neighbouring modes this is 39, the value used before for all two-mode gates.
            height = 30 * (wires[-1] - wires[0]) + 9
        self.draw_.add(
            self.draw_.rect(
                insert=(x + 42.5, y_up * 30 + 25),
                size=(10, height),
                rx=0,
                ry=0,
                fill=fill_c,
                stroke='black',
                stroke_width=1.5,
            )
        )
        self.draw_.add(self.draw_.text(name, insert=(x + 40 + shift, y_up * 30 + 20), font_size=9))
        for k, key in enumerate(para_dic):
            self.draw_.add(
                self.draw_.text(
                    key + '=' + str(np.round(para_dic[key], 3)), insert=(x + 55, y_up * 30 + 18 + 6 * k), font_size=7
                )
            )

    def draw_delay(self, order, wires, inputs=None):
        """Draw delay loop.

        Args:
            order: The column index.
            wires: The modes of the delay loop; the loop is drawn on the first one.
            inputs: The three values printed as ``N``, ``θ`` and ``ϕ``, in this order.
        """
        x = 90 * order + 40
        y_up = wires[0]
        for wire_i in wires:
            self.draw_.add(
                self.draw_.polyline(
                    points=[(x, wire_i * 30 + 30), (x + 90, wire_i * 30 + 30)],
                    fill='none',
                    stroke='black',
                    stroke_width=2,
                )
            )
        self.draw_.add(
            self.draw_.circle(center=(x + 46, y_up * 30 + 25 - 4), r=9, stroke='black', fill='white', stroke_width=1.2)
        )
        self.draw_.add(self.draw_.text('N=' + str(inputs[0]), insert=(x + 40, y_up * 30 + 18), font_size=5))
        self.draw_.add(
            self.draw_.text('θ=' + str(np.round(inputs[1], 2)), insert=(x + 58, y_up * 30 + 18), font_size=6)
        )
        self.draw_.add(
            self.draw_.text('ϕ=' + str(np.round(inputs[2], 2)), insert=(x + 58, y_up * 30 + 24), font_size=6)
        )

    def draw_loss(self, order, wires, name, t):
        """Draw loss gate.

        Args:
            order: The column index.
            wires: The modes of the channel; only the first one is drawn.
            name: Not used for drawing.
            t: The value printed as ``T``.
        """
        x = 90 * order + 40
        y_up = wires[0]
        self.draw_.add(
            self.draw_.polyline(
                points=[(x, y_up * 30 + 30), (x + 90, y_up * 30 + 30)], fill='none', stroke='black', stroke_width=2
            )
        )
        # zigzag line followed by a straight tail that carries the arrow head
        start = (x + 18, y_up * 30 + 23)
        end = (x + 38, y_up * 30 + 23)
        num_waves = 4
        wave_amplitude = [1.5] * 3 + [3] * 2 + [1.5] * 3
        wave_length = (end[0] - start[0]) / num_waves
        path_d = f'M {start[0]},{start[1]} '
        wave_x = start[0]
        for i in range(num_waves * 2):
            wave_x = start[0] + i * wave_length / 2
            wave_y = start[1] + (-1) ** i * wave_amplitude[i]
            path_d += f'L {wave_x},{wave_y} '
        path_d += f'L {end[0]},{end[1]}'
        path_d += f'L {end[0] + 12},{end[1]}'
        path = self.draw_.path(d=path_d, fill='none', stroke='gray', stroke_width=2)
        arrow_marker = self.draw_.marker(insert=(3.5, 1.8), size=(10, 5), orient='auto')
        arrow_marker.add(self.draw_.path(d='M 0 0 L 5 1.5 L 0 4 Z', fill='gray'))
        self.draw_.defs.add(arrow_marker)
        path.set_markers((None, None, arrow_marker))
        # The rotation centre and the label are placed relative to the x of the last
        # zigzag point (not the column start), which keeps the existing layout.
        anchor_x = wave_x
        path.rotate(angle=-45, center=(anchor_x + 10, y_up * 30 + 18))
        self.draw_.add(path)
        self.draw_.add(
            self.draw_.text('T=' + str(np.round(t, 3)), insert=(anchor_x - 14, y_up * 30 + 25), font_size=7)
        )

    def draw_any(self, order, wires, name, para_dict=None):
        """Draw arbitrary unitary gate.

        Args:
            order: The column index.
            wires: The modes of the gate.
            name: The label of the gate, which must be a key of ``info_dic``.
            para_dict: The parameters to print below the label, as a mapping from symbol to value.
                Default: ``None``
        """
        fill_c = info_dic[name][0]
        x = 90 * order + 40
        wires = sorted(wires)
        y_up = wires[0]
        h = (int(len(wires)) - 1) * 30 + 20
        width = 50
        for k in wires:
            # short line into the box and short line out of it
            self.draw_.add(
                self.draw_.polyline(
                    points=[(x, k * 30 + 30), (x + 20, k * 30 + 30)], fill='none', stroke='black', stroke_width=2
                )
            )
            self.draw_.add(
                self.draw_.polyline(
                    points=[(x + 70, k * 30 + 30), (x + 90, k * 30 + 30)], fill='none', stroke='black', stroke_width=2
                )
            )
        self.draw_.add(
            self.draw_.rect(
                insert=(x + 20, y_up * 30 + 20),
                size=(width, h),
                rx=0,
                ry=0,
                fill=fill_c,
                stroke='black',
                stroke_width=2,
            )
        )
        self.draw_.add(self.draw_.text(name, insert=((x + 2 * (10 + width) / 3), y_up * 30 + 15 + h / 2), font_size=10))
        if para_dict is not None:
            for i, key in enumerate(para_dict):
                self.draw_.add(
                    self.draw_.text(
                        key + '=' + str(np.round(para_dict[key], 3)),
                        insert=((x + 2 * (10 + width) / 3 - 2), y_up * 30 + 15 + h / 2 + 8 * (i + 1)),
                        font_size=7,
                    )
                )

    def draw_lines(self, order, wires):
        """Draw plain lines for the wires without an operation in the given column."""
        x = 90 * order + 40
        for k in wires:
            self.draw_.add(
                self.draw_.polyline(
                    points=[(x, k * 30 + 30), (x + 90, k * 30 + 30)], fill='none', stroke='black', stroke_width=2
                )
            )

    def barrier(self, order, wires, cl='black'):
        """Draw a barrier as a dashed vertical line at the start of a column.

        Args:
            order: The column index.
            wires: The modes covered by the barrier, in any order.
            cl: The color of the line. Default: ``'black'``
        """
        x = 90 * order + 40
        y_min = 15
        y_max = self.nmode * 30 + 25
        # min/max instead of first/last element, so that unsorted wires give the same line
        y_up = min(wires) * (y_max - y_min) / self.nmode + y_min
        y_down = (1 + max(wires)) * (y_max - y_min) / self.nmode + y_min
        self.draw_.add(
            self.draw_.polyline(
                points=[(x, y_up), (x, y_down)], fill='none', stroke_dasharray='5,5', stroke=cl, stroke_width=2
            )
        )
class DrawClements:
    """Draw the n-mode Clements architecture.

    Args:
        nmode: The number of modes of the Clements architecture.
        mzi_info: The dictionary for mzi parameters, resulting from the decompose function.
            The keys ``'MZI_list'`` and ``'phase_angle'`` are read.
        cl: The color for plotting. Default: ``'dodgerblue'``
        fs: The fontsize. Default: 30
        method: The way for Clements decomposition, ``'cssr'`` or ``'cssl'``. Default: ``'cssr'``
    """

    def __init__(self, nmode: int, mzi_info: dict, cl: str = 'dodgerblue', fs: int = 30, method: str = 'cssr') -> None:
        self.nmode = nmode
        self.method = method
        self.mzi_info = mzi_info
        self.color = cl
        self.fontsize = fs
        self.wid = 0.1  # width of the phase shifter rectangles
        self.height = 0.08  # height of the phase shifter rectangles
        self.axis_off = 'off'
        self.phase_angle = self.mzi_info['phase_angle']  # for phase shifter
        self.dic_mzi = self.sort_mzi()  # for mzi parameters in the same array
        self.ps_position = self.ps_pos()

    def plotting_clements(self):
        """Plot Clements structure with ``'cssr'`` or ``'cssl'`` type."""
        if self.method == 'cssr':
            assert self.nmode % 2 == 0, 'plotting only valid for even modes'
            self.plotting_clements_1()
        if self.method == 'cssl':
            self.plotting_clements_2()

    def plotting_clements_1(self):
        """Plot ``'cssr'`` with left to right order."""
        fig, ax = plt.subplots(1, 1)
        fig.set_size_inches(8 * 3, 5 * 3)
        nmode = self.nmode
        phase_angle = self.phase_angle
        fs = self.fontsize
        cl = self.color
        wid = self.wid
        height = self.height
        for i in range(nmode):
            plt.annotate(
                '',
                xy=(-0.1, 1 - 0.25 * i),
                xytext=(-0.5, 1 - 0.25 * i),
                arrowprops={'arrowstyle': '-|>', 'lw': 5},
                va='center',
            )
            plt.text(-0.8, 1 - 0.25 * i, f'{i}', fontsize=fs)
            plt.plot([0, 1.2], [1 - 0.25 * i, 1 - 0.25 * i], color=cl)
            # phase angle, written at the right end of the mesh
            plt.text(
                3.2 * (nmode / 2 - 1) + 2.2 + 2.1, 1 - 0.25 * i + 0.05, f'{phase_angle[i]:.3f}', fontsize=fs - 8
            )
            # rectangle for the phase shifter
            ax.add_patch(
                patches.Rectangle(
                    (3.2 * (nmode / 2 - 1) + 2.2 + 2.1, 1 - 0.25 * i - 0.05),
                    wid,
                    height,
                    edgecolor='green',
                    facecolor='green',
                    fill=True,
                    zorder=3,
                )
            )
            if nmode % 2 == 1:
                plt.plot(
                    [2.2 + 3.2 * (int((nmode + 1) / 2) - 1), 3.2 * int((nmode + 1) / 2 - 1) + 2.2 + 2.2],
                    [1 - 0.25 * i, 1 - 0.25 * i],
                    color=cl,
                )
        coords1, coords2 = self._plot_mesh()
        # connecting lines i, i+1
        for coords in (coords1, coords2):
            for i, coordinate in enumerate(coords):
                if i % 2 == 0:
                    self.connect1(coordinate, ax, a=-0.5 - 0.4, c=0.7 - 0.7, cl=self.color)
                else:
                    self.connect2(coordinate, cl=self.color)
        # plotting paras
        self.plot_paras_1(self.dic_mzi, fs=self.fontsize - 8)
        plt.axis(self.axis_off)
        plt.show()

    def plotting_clements_2(self):
        """Plot ``cssl`` with right to left order."""
        fig, ax = plt.subplots(1, 1)
        fig.set_size_inches(8 * 3, 5 * 3)
        nmode = self.nmode
        phase_angle = self.phase_angle
        fs = self.fontsize
        cl = self.color
        wid = self.wid
        height = self.height
        for i in range(nmode):
            plt.annotate(
                '',
                xy=(-0.1, 1 - 0.25 * i),
                xytext=(-0.5, 1 - 0.25 * i),
                arrowprops={'arrowstyle': '-|>', 'lw': 5},
                va='center',
            )
            plt.text(-0.8, 1 - 0.25 * i, f'{i}', fontsize=fs)
            plt.plot([0, 1.2], [1 - 0.25 * i, 1 - 0.25 * i], color=cl)
            # phase angle, written on the input line
            plt.text(0.4, 1 - 0.25 * i + 0.05, f'{phase_angle[i]:.3f}', fontsize=fs - 8)
            ax.add_patch(
                patches.Rectangle((0.5, 1 - 0.25 * i - 0.05), wid, height, edgecolor=cl, facecolor=cl, fill=True)
            )
            if nmode % 2 == 1:
                plt.plot(
                    [2.2 + 3.2 * (int((nmode + 1) / 2) - 1), 3.2 * int((nmode + 1) / 2 - 1) + 2.2 + 2.2],
                    [1 - 0.25 * i, 1 - 0.25 * i],
                    color=cl,
                )
        coords1, coords2 = self._plot_mesh()
        # connecting lines i, i+1
        for coords in (coords1, coords2):
            for i, coordinate in enumerate(coords):
                if i % 2 == 0:
                    self.connect1(coordinate=coordinate, ax=ax, cl=self.color)
                else:
                    self.connect2(coordinate=coordinate, ax=ax, cl=self.color)
        # plotting paras
        self.plot_paras(self.dic_mzi, self.nmode, fs=self.fontsize - 8)
        plt.axis(self.axis_off)
        plt.show()

    def _plot_mesh(self):
        """Plot the horizontal segments of the mesh, shared by both plotting methods.

        Returns:
            Two lists with one ``[x0, x1, y0, y1]`` entry for every short segment in the first
            and in the second sub-column of each column. They are joined afterwards by
            ``connect1`` (even positions in the list) and ``connect2`` (odd positions).
        """
        nmode = self.nmode
        cl = self.color
        coords1 = []
        coords2 = []
        if nmode % 2 == 0:  # for even mode
            for i in range(int(nmode / 2)):
                # the first and the last mode pass the second sub-column on a straight line
                plt.plot([2.2 + 3.2 * i, 3.2 * i + 2.2 + 2.2], [1, 1], color=cl)
                plt.plot(
                    [2.2 + 3.2 * i, 3.2 * i + 2.2 + 2.2], [1 - 0.25 * (nmode - 1), 1 - 0.25 * (nmode - 1)], color=cl
                )
                for j in range(nmode):
                    plt.plot([1.5 + 3.2 * i, 3.2 * i + 1.9], [1 - 0.25 * j, 1 - 0.25 * j], color=cl)
                    coords1.append([1.5 + 3.2 * i, 3.2 * i + 1.9, 1 - 0.25 * j, 1 - 0.25 * j])
                    if 0 < j < nmode - 1:
                        plt.plot([3.1 + 3.2 * i, 3.2 * i + 3.5], [1 - 0.25 * j, 1 - 0.25 * j], color=cl)
                        coords2.append([3.1 + 3.2 * i, 3.2 * i + 3.5, 1 - 0.25 * j, 1 - 0.25 * j])
                        plt.plot([2.2 + 3.2 * i, 3.2 * i + 2.8], [1 - 0.25 * j, 1 - 0.25 * j], color=cl)
                        plt.plot([3.8 + 3.2 * i, 3.2 * i + 4.4], [1 - 0.25 * j, 1 - 0.25 * j], color=cl)
        if nmode % 2 == 1:  # for odd mode
            for i in range(int((nmode + 1) / 2)):
                plt.plot([2.2 + 3.2 * i, 3.2 * i + 2.2 + 2.2], [1, 1], color=cl)
                for j in range(nmode):
                    if j < nmode - 1:  # remove last line
                        plt.plot([1.5 + 3.2 * i, 3.2 * i + 1.9], [1 - 0.25 * j, 1 - 0.25 * j], color=cl)
                        coords1.append([1.5 + 3.2 * i, 3.2 * i + 1.9, 1 - 0.25 * j, 1 - 0.25 * j])
                    if j >= nmode - 1:
                        plt.plot([1.2 + 3.2 * i, 3.2 * i + 2.2], [1 - 0.25 * j, 1 - 0.25 * j], color=cl)
                    if i < int((nmode + 1) / 2) - 1 and 0 < j < nmode:  # remove the last column
                        plt.plot([3.1 + 3.2 * i, 3.2 * i + 3.5], [1 - 0.25 * j, 1 - 0.25 * j], color=cl)
                        coords2.append([3.1 + 3.2 * i, 3.2 * i + 3.5, 1 - 0.25 * j, 1 - 0.25 * j])
                        plt.plot([2.2 + 3.2 * i, 3.2 * i + 2.8], [1 - 0.25 * j, 1 - 0.25 * j], color=cl)
                        plt.plot([3.8 + 3.2 * i, 3.2 * i + 4.4], [1 - 0.25 * j, 1 - 0.25 * j], color=cl)
        return coords1, coords2

    def sort_mzi(self):
        """Sort mzi parameters in the same array for plotting.

        Returns:
            A ``defaultdict`` that maps the first two items of every entry of
            ``mzi_info['MZI_list']`` (as a tuple) to the list of the remaining items
            of all entries with these two items, in the order of the list.
        """
        dic_mzi = defaultdict(list)  # a missing key maps to []
        mzi_list = self.mzi_info['MZI_list']
        for i in mzi_list:
            dic_mzi[tuple(i[0:2])].append(i[2:])
        return dic_mzi

    def ps_pos(self):
        """Label the position of each phaseshifter for ``'cssr'`` case.

        Returns:
            For ``'cssr'``, a dictionary that maps ``(mode, k)`` to the ``k``-th parameter
            (rounded to 4 decimals) along ``mode``: first the flattened parameters stored for
            the pair ``(mode, mode + 1)``, then the phase angle of the mode. For the last mode
            the phase angle is stored at ``(mode, 0)``. ``None`` for any other method.
        """
        if self.method == 'cssr':
            dic_pos = {}
            nmode = self.nmode
            phase_angle = self.phase_angle
            dic_ = self.dic_mzi
            for mode in range(nmode):
                pair = (mode, mode + 1)
                # ``get`` keeps the lookup from adding empty entries to the defaultdict
                value = np.array(dic_.get(pair, [])).flatten()
                for k in range(len(value)):
                    dic_pos[(mode, k)] = np.round(value[k], 4)
                if mode == nmode - 1:
                    dic_pos[(mode, 0)] = np.round(phase_angle[mode], 4)
                else:
                    # the phase angle follows the last mzi parameter of the mode
                    dic_pos[(mode, len(value))] = np.round(phase_angle[mode], 4)
            return dic_pos
        else:
            return None

    @staticmethod
    def connect1(coordinate, ax, cl, wid=0.1, height=0.08, a=-0.05, b=-0.05, c=0.7, d=-0.05):
        """Connect odd column.

        Draws the two diagonals going down from the ends of a segment and two rectangles.

        Args:
            coordinate: The segment as ``[x0, x1, y0, y1]``.
            ax: The axes the rectangles are added to.
            cl: The color of the lines and rectangles.
            wid: The width of the rectangles. Default: 0.1
            height: The height of the rectangles. Default: 0.08
            a: The x offset of the first rectangle from the middle of the segment. Default: -0.05
            b: The y offset of the first rectangle from ``y0``. Default: -0.05
            c: The x offset of the second rectangle from the middle of the segment. Default: 0.7
            d: The y offset of the second rectangle from ``y0``. Default: -0.05
        """
        x0, x1, y0, y1 = coordinate
        plt.plot([x0, x0 - 0.3], [y0, y0 - 0.25], color=cl)
        plt.plot([x1, x1 + 0.3], [y1, y1 - 0.25], color=cl)
        ax.add_patch(patches.Rectangle(((x0 + x1) / 2 + a, y0 + b), wid, height, edgecolor=cl, facecolor=cl, fill=True))
        ax.add_patch(patches.Rectangle(((x0 + x1) / 2 + c, y0 + d), wid, height, edgecolor=cl, facecolor=cl, fill=True))

    @staticmethod
    def connect2(coordinate, cl, ax=None):
        """Connect even column.

        Draws the two diagonals going up from the ends of a segment.

        Args:
            coordinate: The segment as ``[x0, x1, y0, y1]``.
            cl: The color of the lines.
            ax: The axes to plot on. Default: ``None`` (the current pyplot axes)
        """
        x0, x1, y0, y1 = coordinate
        target = plt if ax is None else ax
        target.plot([x0, x0 - 0.3], [y0, y0 + 0.25], color=cl)
        target.plot([x1, x1 + 0.3], [y1, y1 + 0.25], color=cl)

    @staticmethod
    def plot_paras(sort_mzi_dic, nmode, fs=20):
        """Plot mzi parameters for ``'cssl'`` case.

        Args:
            sort_mzi_dic: The dictionary returned by ``sort_mzi``.
            nmode: The number of modes, which sets the horizontal offset of the texts.
            fs: The fontsize. Default: 20
        """
        for i in sort_mzi_dic:
            temp_values = sort_mzi_dic[i]
            if i[0] % 2 == 0:  # 0, 2, 4, 6..
                for j in range(len(temp_values)):
                    plt.text(
                        8.6 - 3.2 * j + 3.2 * ((nmode - 6) // 2 + nmode % 2),
                        1 - 0.25 * i[0] + 0.05,
                        f'{temp_values[j][0]:.3f}',
                        fontsize=fs,
                    )
                    plt.text(
                        7.8 - 3.2 * j + 3.2 * ((nmode - 6) // 2 + nmode % 2),
                        1 - 0.25 * i[0] + 0.05,
                        f'{temp_values[j][1]:.3f}',
                        fontsize=fs,
                    )
            elif i[0] % 2 == 1:  # 1, 3..
                for j in range(len(temp_values)):
                    plt.text(
                        8.6 - 3.2 * j + 1.6 + 3.2 * ((nmode - 6) // 2),
                        1 - 0.25 * i[0] + 0.05,
                        f'{temp_values[j][0]:.3f}',
                        fontsize=fs,
                    )
                    plt.text(
                        7.8 - 3.2 * j + 1.6 + 3.2 * ((nmode - 6) // 2),
                        1 - 0.25 * i[0] + 0.05,
                        f'{temp_values[j][1]:.3f}',
                        fontsize=fs,
                    )

    @staticmethod
    def plot_paras_1(sort_mzi_dic, fs=20):
        """Plot mzi parameters for ``'cssr'`` case.

        Args:
            sort_mzi_dic: The dictionary returned by ``sort_mzi``.
            fs: The fontsize. Default: 20
        """
        for i in sort_mzi_dic:
            temp_values = sort_mzi_dic[i]
            if i[0] % 2 == 0:  # 0, 2, 4, 6..
                for j in range(len(temp_values)):
                    plt.text(3.2 * j + 0.6, 1 - 0.25 * i[0] + 0.05, f'{temp_values[j][0]:.3f}', fontsize=fs)
                    plt.text(3.2 * j + 0.6 + 0.9, 1 - 0.25 * i[0] + 0.05, f'{temp_values[j][1]:.3f}', fontsize=fs)
            elif i[0] % 2 == 1:  # 1, 3..
                for j in range(len(temp_values)):
                    plt.text(3.2 * j + 0.6 + 1.6, 1 - 0.25 * i[0] + 0.05, f'{temp_values[j][0]:.3f}', fontsize=fs)
                    plt.text(3.2 * j + 0.6 + 2.4, 1 - 0.25 * i[0] + 0.05, f'{temp_values[j][1]:.3f}', fontsize=fs)