Source code for deepquantum.utils

"""Utilities"""

import time
from collections.abc import Callable
from functools import wraps
from typing import Any

import torch

import deepquantum as dq


def record_time(func: Callable) -> Callable:
    """A decorator that records the running time of a function.

    Each call to the decorated function is timed with :func:`time.perf_counter`
    and the elapsed time in seconds is printed as
    ``running time of "<name>": <seconds>``.

    Args:
        func: The function to time.

    Returns:
        A wrapper that forwards all positional and keyword arguments to ``func``
        and returns its result. ``functools.wraps`` copies ``func``'s metadata
        (name, docstring, ...) onto the wrapper. If ``func`` raises, the
        exception propagates and nothing is printed.
    """

    @wraps(func)
    def wrapped_function(*args, **kwargs):
        # perf_counter is monotonic and high-resolution, unlike time.time(),
        # which follows the (adjustable) system wall clock.
        t1 = time.perf_counter()
        rst = func(*args, **kwargs)
        t2 = time.perf_counter()
        print(f'running time of "{func.__name__}": {t2 - t1}')
        return rst

    return wrapped_function
class Time:
    """A decorator that records the running time of a function.

    Class-based form of the timing decorator: create an instance and apply it,
    e.g. ``@Time()``. Each call to the decorated function is timed with
    :func:`time.perf_counter` and the elapsed time in seconds is printed as
    ``running time of "<name>": <seconds>``.
    """

    def __call__(self, func: Callable) -> Callable:
        """Wrap ``func`` so that every call prints its running time.

        Args:
            func: The function to time.

        Returns:
            A wrapper that forwards all arguments to ``func`` and returns its
            result. ``functools.wraps`` copies ``func``'s metadata onto the
            wrapper. If ``func`` raises, the exception propagates and nothing
            is printed.
        """

        @wraps(func)
        def wrapped_function(*args, **kwargs):
            # perf_counter is monotonic and high-resolution, unlike time.time(),
            # which follows the (adjustable) system wall clock.
            t1 = time.perf_counter()
            rst = func(*args, **kwargs)
            t2 = time.perf_counter()
            print(f'running time of "{func.__name__}": {t2 - t1}')
            return rst

        return wrapped_function
def apply_complex_fix(fn: Any, tensors_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    """Convert the tensors in the dictionary to the device/dtype targeted by ``fn``, using a complex dtype.

    ``fn`` is NOT applied to the tensors themselves. It is called once on an
    empty probe tensor, created with the real dtype and the device of the first
    tensor in ``tensors_dict``, to find out which device and dtype ``fn``
    converts to. The probe's dtype is looked up in ``dq.dtype_map``; if it has
    no entry, the probe dtype itself is used. Every tensor is then converted
    with ``Tensor.to(probe.device, target_dtype)``.

    Args:
        fn: A callable that takes a tensor and returns a tensor (its result's
            ``dtype`` and ``device`` are read).
        tensors_dict: Mapping of names to the tensors to convert.

    Returns:
        A new dict with the same keys and the converted tensors. An empty input
        returns an empty dict without calling ``fn``.
    """
    # Nothing to convert. This also avoids next() raising a bare StopIteration
    # on an empty dict.
    if not tensors_dict:
        return {}
    first_tensor = next(iter(tensors_dict.values()))
    # Call fn on an empty real-valued probe, so its target device/dtype can be
    # read off without calling fn on the actual data.
    probe = fn(torch.empty(0, dtype=first_tensor.real.dtype, device=first_tensor.device))
    # NOTE(review): presumably dq.dtype_map maps real dtypes to their complex
    # counterparts (e.g. float32 -> complex64). A dtype with no entry is used
    # unchanged, which would give a real result; confirm this is intended.
    target_dtype = dq.dtype_map.get(probe.dtype, probe.dtype)
    # NOTE(review): all tensors get the single dtype derived from the first
    # tensor; this assumes the dict holds tensors of one precision.
    return {name: tensor.to(probe.device, target_dtype) for name, tensor in tensors_dict.items()}