# Source code for deepquantum.photonic.decompose

"""Decompose the unitary matrix"""

from collections import defaultdict

import numpy as np
import torch


[docs] class UnitaryDecomposer: """This class is to decompose a unitary matrix into the Clements/Reck architecture. Args: unitary: The unitary matrix to be decomposed. method: The decomposition method, only 16 values (``'rssr'``, ``'rsdr'``, ``'rdsr'``, ``'rddr'``, ``'rssl'``, ``'rsdl'``, ``'rdsl'``, ``'rddl'``, ``'cssr'``, ``'csdr'``, ``'cdsr'``, ``'cddr'``, ``'cssl'``, ``'csdl'``, ``'cdsl'``, ``'cddl'``) are valid. The first char denotes the Clements or Reck architecture. The second char denotes single or double arms of outer phase shifters. The third char denotes single or double arms of inner phase shifters. The last char denotes the position of a column of phase shifters, i.e., ``'l'`` for left and ``'r'`` for right. Default: ``'cssr'`` """ def __init__(self, unitary: np.ndarray | torch.Tensor, method: str = 'cssr') -> None: if isinstance(unitary, np.ndarray): self.unitary = unitary.copy() elif isinstance(unitary, torch.Tensor): self.unitary = unitary.cpu().clone().detach().numpy() else: raise TypeError('The matrix to be decomposed must be in the type of numpy array or torch tensor.') if len(self.unitary.shape) != 2 or self.unitary.shape[0] != self.unitary.shape[1]: raise TypeError('The matrix to be decomposed must be a square matrix.') if np.abs(unitary @ unitary.conj().T - np.eye(len(unitary))).sum() / len(unitary) ** 2 > 1e-6: print('Make sure the input matrix is unitary, in case of an abnormal computation result.') self.unitary[np.abs(self.unitary) < 1e-32] = 1e-32 self.method = method
[docs] def decomp(self) -> tuple[dict, dict, dict]: """Decompose the unitary matrix. The third dictionary is the representation of the positions and the angles of all phase shifters. """ def period_cut(input_angle: float, period: float = np.pi * 2) -> float: return input_angle - np.floor(input_angle / period) * period def decomp_rr(unitary: np.ndarray, method: str) -> dict: n_dim = len(unitary) info = {} info['N'] = n_dim info['method'] = method info['MZI_list'] = [] # jj,ii,phi,theta if 'dd' in method: period_theta = 2 * np.pi period_phi = 4 * np.pi elif 'ds' in method: period_theta = 4 * np.pi period_phi = 4 * np.pi else: period_theta = 2 * np.pi period_phi = 2 * np.pi for i in range(n_dim): ii = n_dim - 1 - i # 基准列 ii for jj in range(ii)[::-1]: # print(ii,jj) # 要用 uniatry[:,ii] 把 unitary[:,jj]的第ii号元素变成0 # if uniatry[ii,jj] == 0: # continue ratio = unitary[ii, ii] / (unitary[ii, jj] + 1e-32) theta = 2 * np.arctan(np.abs(ratio)) phi = -np.angle(-ratio) multiple = get_matrix_inverse_r([jj, ii, phi, theta], n_dim, method) unitary = unitary @ multiple phi = period_cut(phi, period_phi) theta = period_cut(theta, period_theta) info['MZI_list'].append([jj, ii, phi, theta]) diagonal = np.diag(unitary) info['phase_angle'] = np.angle(diagonal) mask = np.logical_or(info['phase_angle'] >= 2 * np.pi, info['phase_angle'] < 0) info['phase_angle'][mask] -= np.floor(info['phase_angle'][mask] / np.pi / 2) * np.pi * 2 return info, unitary def decomp_cr(unitary: np.ndarray, method: str) -> dict: n_dim = len(unitary) info = {} info['N'] = n_dim info['method'] = method info['MZI_list'] = [] # jj,ii,phi,theta info['right'] = [] info['left'] = [] if 'dd' in method: period_theta = 2 * np.pi period_phi = 4 * np.pi elif 'ds' in method: period_theta = 4 * np.pi period_phi = 4 * np.pi else: period_theta = 2 * np.pi period_phi = 2 * np.pi for i in range(n_dim - 1): # 从下往上第i个反对角线 if i % 2: # 左乘, 利用TU消元; for j in range(i + 1): # 反对角线的元素计数 # 消元顺序:从左上到右下 jj = j # 当前待消元元素列号 ii = n_dim - 1 - i + j # 
当前待消元元素行号 # print(ii,jj) # if unitary[ii,jj] == 0: # continue ratio = unitary[ii - 1, jj] / (unitary[ii, jj] + 1e-32) theta = 2 * np.arctan(np.abs(ratio)) phi = -np.angle(ratio) multiple = get_matrix_constr_r([ii - 1, ii, phi, theta], n_dim, method) unitary = multiple @ unitary info['left'].append([ii - 1, ii, phi, theta]) else: # 利用UT^{-1}消元,即利用 unitary[ii,jj+1] 消去 unitary[ii,jj] for j in range(i + 1)[::-1]: # 反对角线的元素计数 # 消元顺序:从右下到左上 jj = j # 当前待消元元素列号 ii = n_dim - 1 - i + j # 当前待消元元素行号 # print(ii,jj) # if unitary[ii,jj] == 0: # continue ratio = unitary[ii, jj + 1] / (unitary[ii, jj] + 1e-32) theta = 2 * np.arctan(np.abs(ratio)) phi = -np.angle(-ratio) multiple = get_matrix_inverse_r([jj, jj + 1, phi, theta], n_dim, method) unitary = unitary @ multiple info['right'].append([jj, jj + 1, phi, theta]) phase_angle = np.angle(np.diag(unitary)) info['phase_angle_ori'] = phase_angle.copy() # unitary=LLLDRRR,本行保存D for idx in range(len(info['right'])): info['right'][idx][2] = period_cut(info['right'][idx][2], period_phi) info['right'][idx][3] = period_cut(info['right'][idx][3], period_theta) info['MZI_list'].append(info['right'][idx]) left_list = info['left'][::-1] for idx in range(len(left_list)): jj, ii, phi, theta = left_list[idx] phi_, theta_, phase_angle[jj], phase_angle[ii] = clements_diagonal_transform( phi, theta, phase_angle[jj], phase_angle[ii], method ) phi_ = period_cut(phi_, period_phi) theta_ = period_cut(theta_, period_theta) info['MZI_list'].append([jj, ii, phi_, theta_]) info['phase_angle'] = phase_angle.copy() # unitary=D'L'L'L'RRR,本行保存新的D mask = np.logical_or(info['phase_angle'] >= 2 * np.pi, info['phase_angle'] < 0) info['phase_angle'][mask] -= np.floor(info['phase_angle'][mask] / np.pi / 2) * np.pi * 2 return info, unitary def decomp_rl(unitary: np.ndarray, method: str) -> dict: n_dim = len(unitary) info = {} info['N'] = n_dim info['method'] = method info['MZI_list'] = [] # jj,ii,phi,theta if 'dd' in method: period_theta = 2 * np.pi period_phi = 4 * 
np.pi elif 'ds' in method: period_theta = 4 * np.pi period_phi = 4 * np.pi else: period_theta = 2 * np.pi period_phi = 2 * np.pi for i in range(n_dim): ii = n_dim - 1 - i # 基准行 ii for jj in range(ii)[::-1]: # print(ii,jj) # 要用 unitary[ii] 把 unitary[jj]的第ii号元素变成0 # if unitary[jj,ii] == 0: # continue ratio = unitary[ii, ii] / (unitary[jj, ii] + 1e-32) theta = 2 * np.arctan(np.abs(ratio)) phi = -np.angle(-ratio) multiple = get_matrix_inverse_l([jj, ii, phi, theta], n_dim, method) unitary = multiple @ unitary phi = period_cut(phi, period_phi) theta = period_cut(theta, period_theta) info['MZI_list'].append([jj, ii, phi, theta]) diagonal = np.diag(unitary) info['phase_angle'] = np.angle(diagonal) mask = np.logical_or(info['phase_angle'] >= 2 * np.pi, info['phase_angle'] < 0) info['phase_angle'][mask] -= np.floor(info['phase_angle'][mask] / np.pi / 2) * np.pi * 2 return info, unitary def decomp_cl(unitary: np.ndarray, method: str) -> dict: n_dim = len(unitary) info = {} info['N'] = n_dim info['method'] = method info['MZI_list'] = [] # jj,ii,phi,theta info['right'] = [] info['left'] = [] if 'dd' in method: period_theta = 2 * np.pi period_phi = 4 * np.pi elif 'ds' in method: period_theta = 4 * np.pi period_phi = 4 * np.pi else: period_theta = 2 * np.pi period_phi = 2 * np.pi for i in range(n_dim - 1): # 从下往上第i个反对角线 if i % 2: # 左乘, 利用T^{-1}U消元; for j in range(i + 1): # 反对角线的元素计数 # 消元顺序:从左上到右下 jj = j # 当前待消元元素列号 ii = n_dim - 1 - i + j # 当前待消元元素行号 # print(ii,jj) # if unitary[ii,jj] == 0: # continue ratio = unitary[ii - 1, jj] / (unitary[ii, jj] + 1e-32) theta = 2 * np.arctan(np.abs(ratio)) phi = np.angle(ratio) multiple = get_matrix_inverse_l([ii - 1, ii, phi, theta], n_dim, method) unitary = multiple @ unitary info['left'].append([ii - 1, ii, phi, theta]) else: # 利用UT消元,即利用 unitary[ii,jj+1] 消去 unitary[ii,jj] for j in range(i + 1)[::-1]: # 反对角线的元素计数 # 消元顺序:从右下到左上 jj = j # 当前待消元元素列号 ii = n_dim - 1 - i + j # 当前待消元元素行号 # print(ii,jj) # if unitary[ii,jj] == 0: # continue ratio = 
unitary[ii, jj + 1] / (unitary[ii, jj] + 1e-32) theta = 2 * np.arctan(np.abs(ratio)) phi = np.angle(-ratio) multiple = get_matrix_constr_l([jj, jj + 1, phi, theta], n_dim, method) unitary = unitary @ multiple info['right'].append([jj, jj + 1, phi, theta]) phase_angle = np.angle(np.diag(unitary)) info['phase_angle_ori'] = phase_angle.copy() # U=LLLDRRR,本行保存D for idx in range(len(info['left'])): info['left'][idx][2] = period_cut(info['left'][idx][2], period_phi) info['left'][idx][3] = period_cut(info['left'][idx][3], period_theta) info['MZI_list'].append(info['left'][idx]) left_list = info['right'][::-1] for idx in range(len(left_list)): jj, ii, phi, theta = left_list[idx] phi_, theta_, phase_angle[jj], phase_angle[ii] = clements_diagonal_transform( phi, theta, phase_angle[jj], phase_angle[ii], method ) phi_ = period_cut(phi_, period_phi) theta_ = period_cut(theta_, period_theta) info['MZI_list'].append([jj, ii, phi_, theta_]) info['phase_angle'] = phase_angle.copy() # unitary =D'L'L'L'RRR,本行保存新的D mask = np.logical_or(info['phase_angle'] >= 2 * np.pi, info['phase_angle'] < 0) info['phase_angle'][mask] -= np.floor(info['phase_angle'][mask] / np.pi / 2) * np.pi * 2 return info, unitary def calc_factor_inverse(method, phi, theta): # 计算MZI矩阵T^{-1}的系数(相当于全局相位) if 'sd' in method: return -1j elif 'ss' in method: return -1j * np.exp(-1j * theta / 2) elif 'dd' in method: return -1j * np.exp(-1j * (theta - phi) / 2) elif 'ds' in method: return -1j * np.exp(1j * phi / 2) def calc_factor_constr(method, phi, theta): # 计算MZI矩阵T的系数(相当于全局相位) return calc_factor_inverse(method, phi, theta).conjugate() def get_matrix_constr_l(info, n_dim, method): jj, ii, phi, theta = info factor = calc_factor_constr(method, phi, theta) multiple = np.eye(n_dim, dtype=complex) multiple[jj, jj] = factor * np.exp(1j * phi) * np.sin(theta / 2) multiple[jj, ii] = factor * np.exp(1j * phi) * np.cos(theta / 2) multiple[ii, jj] = factor * np.cos(theta / 2) multiple[ii, ii] = factor * -np.sin(theta / 2) return 
multiple def get_matrix_inverse_l(info, n_dim, method): jj, ii, phi, theta = info factor = calc_factor_inverse(method, phi, theta) multiple = np.eye(n_dim, dtype=complex) multiple[jj, jj] = factor * np.exp(-1j * phi) * np.sin(theta / 2) multiple[jj, ii] = factor * np.cos(theta / 2) multiple[ii, jj] = factor * np.exp(-1j * phi) * np.cos(theta / 2) multiple[ii, ii] = factor * -np.sin(theta / 2) return multiple def get_matrix_constr_r(info, n_dim, method): jj, ii, phi, theta = info factor = calc_factor_constr(method, phi, theta) multiple = np.eye(n_dim, dtype=complex) multiple[jj, jj] = factor * np.exp(1j * phi) * np.sin(theta / 2) multiple[jj, ii] = factor * np.cos(theta / 2) multiple[ii, jj] = factor * np.exp(1j * phi) * np.cos(theta / 2) multiple[ii, ii] = factor * -np.sin(theta / 2) return multiple def get_matrix_inverse_r(info, n_dim, method): jj, ii, phi, theta = info factor = calc_factor_inverse(method, phi, theta) multiple = np.eye(n_dim, dtype=complex) multiple[jj, jj] = factor * np.exp(-1j * phi) * np.sin(theta / 2) multiple[jj, ii] = factor * np.exp(-1j * phi) * np.cos(theta / 2) multiple[ii, jj] = factor * np.cos(theta / 2) multiple[ii, ii] = factor * -np.sin(theta / 2) return multiple def clements_diagonal_transform(phi, theta, a1, a2, method): if 'sd' in method: theta_ = theta phi_ = a1 - a2 b1 = a2 - phi + np.pi b2 = a2 + np.pi return phi_, theta_, b1, b2 elif 'ss' in method: theta_ = theta phi_ = a1 - a2 b1 = a2 - phi + np.pi - theta b2 = a2 + np.pi - theta return phi_, theta_, b1, b2 elif 'dd' in method: theta_ = theta phi_ = a1 - a2 b1 = a2 - phi + np.pi - theta + (phi + phi_) / 2 b2 = a2 + np.pi - theta + (phi + phi_) / 2 return phi_, theta_, b1, b2 elif 'ds' in method: theta_ = theta phi_ = a1 - a2 b1 = a2 - phi + np.pi + (phi + phi_) / 2 b2 = a2 + np.pi + (phi + phi_) / 2 return phi_, theta_, b1, b2 method = self.method if method not in [ 'rssr', 'rsdr', 'rdsr', 'rddr', 'rssl', 'rsdl', 'rdsl', 'rddl', 'cssr', 'csdr', 'cdsr', 'cddr', 'cssl', 
'csdl', 'cdsl', 'cddl', ]: raise LookupError('请检查分解方式!') elif method[0] + method[-1] == 'cr': temp_0 = decomp_cr(self.unitary, method)[0] elif method[0] + method[-1] == 'cl': temp_0 = decomp_cl(self.unitary, method)[0] elif method[0] + method[-1] == 'rr': temp_0 = decomp_rr(self.unitary, method)[0] elif method[0] + method[-1] == 'rl': temp_0 = decomp_rl(self.unitary, method)[0] temp_1 = self.sort_mzi(temp_0) temp_2 = self.ps_pos(temp_1, temp_0['phase_angle']) return temp_0, temp_1, temp_2
[docs] def sort_mzi(self, mzi_info): """Sort mzi parameters in the same array for plotting.""" dic_mzi = defaultdict(list) # 当key不存在时对应的value是[] mzi_list = mzi_info['MZI_list'] for i in mzi_list: dic_mzi[tuple(i[0:2])].append(i[2:]) return dic_mzi
[docs] def ps_pos(self, dic_mzi, phase_angle): """Label the position of each phaseshifter for ``'cssr'`` case.""" if self.method == 'cssr': dic_pos = {} nmode = self.unitary.shape[0] dic_ = dic_mzi for mode in range(nmode): pair = (mode, mode + 1) value = dic_[pair] value = np.array(value).flatten() for k in range(len(value)): dic_pos[(mode, k)] = np.round((value[k]), 4) if mode == nmode - 1: dic_pos[(mode, 0)] = np.round((phase_angle[mode]), 4) else: dic_pos[(mode, k + 1)] = np.round((phase_angle[mode]), 4) return dic_pos else: return None