# Source code for deepquantum.adjoint
"""Adjoint differentiation"""
from copy import deepcopy
from typing import TYPE_CHECKING

import torch
from torch import nn
from torch.autograd import Function

from .distributed import dist_many_ctrl_one_targ_gate, dist_many_targ_gate, dist_one_targ_gate, inner_product_dist
from .gate import CombinedSingleGate, SingleGate
from .operation import Gate
from .state import DistributedQubitState

if TYPE_CHECKING:
    from .layer import Observable
class AdjointExpectation(Function):
    """Adjoint differentiation for a distributed statevector simulation.

    Computes the expectation value ``Re(<lambda|phi>)``, where ``|phi>`` is the
    final circuit state and ``|lambda> = O|phi>``, and differentiates it with
    respect to the gate parameters by sweeping the circuit backwards
    (the adjoint method). See https://arxiv.org/pdf/2009.02823

    Args:
        state: The final quantum state.
        operators: The quantum operations.
        observable: The observable.
        *parameters: The parameters of the quantum circuit.
    """

    @staticmethod
    def forward(
        ctx, state: DistributedQubitState, operators: nn.Sequential, observable: 'Observable', *parameters: torch.Tensor
    ) -> torch.Tensor:
        """Return the real part of ``<lambda|phi>`` with ``|lambda> = observable(|phi>)``.

        The two states and the operators are stashed on ``ctx`` so that
        ``backward`` can replay the circuit in reverse.
        """
        ctx.state_phi = state
        ctx.operators = operators
        ctx.observable = observable
        # |lambda> = O|phi>; deepcopy so applying the observable does not
        # clobber the forward state |phi>.
        ctx.state_lambda = observable(deepcopy(state))
        ctx.save_for_backward(*parameters)
        return inner_product_dist(ctx.state_lambda, ctx.state_phi).real

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor) -> tuple[None, ...]:
        """Compute gradients via the adjoint method.

        Sweeps the operators in reverse, un-applying each gate from both
        ``|phi>`` and ``|lambda>``; for every parametrized gate the gradient is
        ``grad_out * 2 * Re(<lambda| dU/dx |phi>)``.

        Returns:
            One entry per ``forward`` input: ``None`` for ``state``,
            ``operators`` and ``observable``, then for each parameter either a
            gradient tensor (reshaped to the parameter's shape) or ``None``
            when that parameter does not require grad.
        """
        parameters = [*ctx.saved_tensors]
        grads = []
        # 1-based offset from the END of the parameter list: parameters are
        # consumed in reverse as the circuit is swept backwards.
        idx = 1
        for op in ctx.operators[::-1]:
            if isinstance(op, CombinedSingleGate):
                gates = op.gates
            elif isinstance(op, Gate):
                gates = [op]
            # NOTE(review): if an op were neither CombinedSingleGate nor Gate,
            # `gates` would be stale (or unbound on the first iteration) —
            # presumably `operators` only ever holds Gate instances; confirm.
            for gate in gates[::-1]:
                if gate.npara > 0:
                    # Rebind the gate's parameters from the tensor saved in forward.
                    gate.init_para(parameters[-idx])
                gate_dagger = gate.inverse()
                # Un-apply this gate: |phi> <- U^dagger |phi>.
                ctx.state_phi = gate_dagger(ctx.state_phi)
                if gate.npara > 0:
                    if parameters[-idx].requires_grad:
                        du_dx = gate.get_derivative(parameters[-idx]).unsqueeze(0).flatten(0, -3) # (npara, 2**n, 2**n)
                        wires = gate.controls + gate.wires
                        # Map wire labels to distributed qubit indices (bit order reversed).
                        targets = [gate.nqubit - wire - 1 for wire in wires]
                        grads_gate = []
                        # One derivative matrix per scalar parameter of the gate.
                        for mat in du_dx:
                            # |mu> = (dU/dx)|phi>, applied to a copy so |phi> survives.
                            state_mu = deepcopy(ctx.state_phi)
                            if isinstance(gate, SingleGate):
                                if len(gate.controls) == 0:
                                    state_mu = dist_one_targ_gate(state_mu, targets[0], mat)
                                else:
                                    state_mu = dist_many_ctrl_one_targ_gate(
                                        state_mu, targets[:-1], targets[-1], mat, True
                                    )
                            else:
                                # Embed dU/dx into the full (controlled) matrix: a zero
                                # block for the control-off subspace, dU/dx on the
                                # control-on subspace.
                                zeros = mat.new_zeros(2 ** len(wires) - 2 ** len(gate.wires)).diag_embed()
                                matrix = torch.block_diag(zeros, mat)
                                state_mu = dist_many_targ_gate(state_mu, targets, matrix)
                            # d<O>/dx = 2 * Re(<lambda|mu>), chained with the upstream grad.
                            grad = grad_out * 2 * inner_product_dist(ctx.state_lambda, state_mu).real
                            grads_gate.append(grad)
                        grads.append(torch.stack(grads_gate).reshape(parameters[-idx].shape))
                    else:
                        # Parameter saved but frozen: no gradient for it.
                        grads.append(None)
                    idx += 1
                # Also un-apply the gate from |lambda> before moving to the
                # next (earlier) gate in the circuit.
                ctx.state_lambda = gate_dagger(ctx.state_lambda)
        # grads were collected in reverse; flip to line up with *parameters.
        return None, None, None, *grads[::-1]