Source code for adpeps.utils.nested

""" Contains a utility class that represents a collection of tensors of
    different types, with operations that can be applied to all
    contained tensors at once
"""

import cmath
import jax.numpy as np

from adpeps.types import TensorType


class Nested:
    """ This is a helper class for the efficient contraction of variants of
        tensors, used in the energy evaluation of excited states

        A Nested tensor contains the following variants (some may be empty):

        - :attr:`tensors[0]`: regular tensor (no B or Bd)
        - :attr:`tensors[1]`: (terms with) a single B tensor
        - :attr:`tensors[2]`: (terms with) a single Bd tensor
        - :attr:`tensors[3]`: (terms with) both a B and a Bd tensor

        When two Nested tensors x,y are contracted, all combinations are
        taken into account and the result is again a Nested tensor, filled
        with the following variants:

        - :attr:`tensors[0]: x[0] * y[0]`
        - :attr:`tensors[1]: x[1] * y[0] + x[0] * y[1]`
        - :attr:`tensors[2]: x[2] * y[0] + x[0] * y[2]`
        - :attr:`tensors[3]: x[3] * y[0] + x[2] * y[1] + x[1] * y[2] + x[0] * y[3]`

        By using Nested tensors in a (large) contraction, the many different
        terms are resummed on the fly, leading to a potentially reduced
        computational cost

        Note:
            Most implemented functions act as wrappers for the corresponding
            `numpy` functions on the individual tensors
    """

    def __init__(self, tensors):
        """
            Args:
                tensors: sequence holding the four variants, in the order
                    described in the class docstring
        """
        self.tensors = tensors

    def normalize(self):
        """ Normalize the contained tensors by the largest value of the
            first element of :attr:`self.tensors`

            Returns:
                A tuple :code:`(normalized, factor)`: a new Nested with every
                variant scaled by :code:`1 / factor`, and the factor itself
                (the largest absolute value found in :attr:`tensors[0]`)
        """
        factor = np.abs(self[0]).max()
        return self * (1 / factor), factor

    def mult(self, other: 'TensorType', *args) -> 'Nested':
        """
            Args:
                other: other tensor-like object to contract with
                *args: arguments to be passed to the contraction method
                    (:code:`np.tensordot`)

            Returns:
                res: result of the contraction
        """

        def _mult_function(A, B, *args):
            # Prefer an operand's own `mult` when it provides one
            if hasattr(A, 'mult'):
                return A.mult(B, *args)
            elif len(B) == 0:
                # A zero-length B is handed the contraction itself
                # NOTE(review): presumably B is an "empty" placeholder type
                # that implements `mult` - confirm against callers
                return B.mult(A, *args)
            return np.tensordot(A, B, *args)

        if isinstance(other, np.ndarray):
            # A plain array carries no B/Bd variants: contract it with each
            # of the four variants separately
            return Nested([
                _mult_function(self.tensors[i], other, *args)
                for i in range(4)
            ])

        # Nested x Nested: keep every combination that contains at most one
        # B and at most one Bd (see the class docstring)
        x, y = self.tensors, other.tensors
        new_data = [
            _mult_function(x[0], y[0], *args),
            _mult_function(x[1], y[0], *args)
            + _mult_function(x[0], y[1], *args),
            _mult_function(x[2], y[0], *args)
            + _mult_function(x[0], y[2], *args),
            _mult_function(x[3], y[0], *args)
            + _mult_function(x[2], y[1], *args)
            + _mult_function(x[1], y[2], *args)
            + _mult_function(x[0], y[3], *args),
        ]
        return Nested(new_data)

    def transpose(self, *args) -> 'Nested':
        """ Applies :code:`transpose` to each contained tensor """
        return Nested([self.tensors[i].transpose(*args) for i in range(4)])

    def __mul__(self, other):
        """ Multiplies each contained tensor by :code:`other` (from the right) """
        return Nested([self.tensors[i] * other for i in range(4)])

    def __rmul__(self, other):
        """ Multiplies each contained tensor by :code:`other` (from the left) """
        return Nested([other * self.tensors[i] for i in range(4)])

    def __truediv__(self, other):
        """ Divides each contained tensor by :code:`other` """
        return Nested([self.tensors[i] / other for i in range(4)])

    def __add__(self, other):
        """ Adds variant by variant when :code:`other` is a Nested, otherwise
            adds :code:`other` to each contained tensor
        """
        if isinstance(other, Nested):
            new_data = [self.tensors[i] + other.tensors[i] for i in range(4)]
        else:
            new_data = [self.tensors[i] + other for i in range(4)]
        return Nested(new_data)

    def __radd__(self, other):
        """ Reflected addition; delegates to :code:`__add__` """
        return self + other

    def __getitem__(self, ix):
        return self.tensors[ix]

    def __setitem__(self, ix, value):
        self.tensors[ix] = value

    def __repr__(self):
        return "(Nested) " + self.tensors.__repr__()

    def __neg__(self):
        """ Negates each contained tensor """
        return Nested([-self.tensors[i] for i in range(4)])

    def shift(self, phi):
        """ Multiplies :attr:`tensors[1]` by :code:`exp(phi)` and
            :attr:`tensors[2]` by :code:`exp(-phi)`, using the module-level
            :code:`exp` helper; the other two variants are passed through
            unchanged
        """
        new_data = [
            self.tensors[0],
            self.tensors[1] * exp(phi),
            self.tensors[2] * exp(-phi),
            self.tensors[3],
        ]
        return Nested(new_data)

    def __len__(self):
        """ Length of the first contained tensor, falling back to its
            :code:`size` when :code:`len` fails on it
        """
        try:
            return len(self.tensors[0])
        except Exception:
            # Deliberately broad: the contained tensor type is not fixed, so
            # the error raised by an unsized object is not known here
            return self.tensors[0].size

    @property
    def real(self):
        """ Nested holding the :code:`real` attribute of each contained tensor """
        return Nested([self.tensors[i].real for i in range(4)])

    @property
    def shape(self):
        """ Shape of the first contained tensor """
        return self.tensors[0].shape

    @property
    def dims(self):
        """ :code:`dims` attribute of the first contained tensor """
        return self.tensors[0].dims

    def check_contr_inds(self, other, *args, **kwargs):
        """ Delegates to :code:`check_contr_inds` of the first contained
            tensor, passing the first element of :code:`other`
        """
        return self[0].check_contr_inds(other[0], *args, **kwargs)

    def numel(self):
        """ Delegates to :code:`numel` of the first contained tensor """
        return self[0].numel()

    @classmethod
    def only_gs(cls, tensor, empty_obj=None):
        """ Builds a Nested with only :attr:`tensors[0]` filled

            Args:
                tensor: object to store as :attr:`tensors[0]`
                empty_obj: placeholder stored in the three remaining slots;
                    when omitted a new empty list is created on every call,
                    so the placeholder is never shared between separate
                    Nested objects
        """
        if empty_obj is None:
            empty_obj = []
        return cls([tensor, empty_obj, empty_obj, empty_obj])
def exp(phi):
    """ Returns the complex phase factor :code:`exp(1j * phi)`

        Args:
            phi: phase angle; it is multiplied by the imaginary unit before
                being passed to :code:`cmath.exp`
    """
    imaginary_phase = 1j * phi
    return cmath.exp(imaginary_phase)