Source code for adpeps.utils.ctmtensors

""" Contains a utility class that holds all iPEPS tensors """

from dataclasses import dataclass, field, fields, replace, asdict
from functools import partial
from typing import List

from .empty_tensor import EmptyT
from .nested import Nested
from .tlist import TList, hold_write

@dataclass
class CTMTensors:
    """ This is a utility class that contains all tensors related to an iPEPS.

    In effect, this forms a representation of the full state, including the
    site tensors with variational parameters, the boundary tensors generated
    by CTM and the projectors.

    There are several fields of this dataclass that are generated
    automatically, which provide convenient wrappers for the tensors.
    For example, the :attr:`CTMTensors.A` field returns the site tensors
    (contained in a :class:`adpeps.utils.tlist.TList`).

    Boundary tensors are stored in four parallel variants, which
    :meth:`all_Cs`, :meth:`all_Ts`, :attr:`all_A` and :attr:`all_Ad` combine
    (always in this order) into :class:`Nested` objects: the plain version
    (``Cs`` / ``Ts``), the ``B_`` version, the ``Bd_`` version and the
    ``BB_`` version.
    """

    # Site tensors; ``Ad`` is the partner of ``A`` used in :attr:`all_Ad`.
    A: TList
    Ad: TList
    # Boundary tensors computed by CTM: lists of TLists, one per index
    # (the properties C1..C4 / T1..T4 expose indices 0..3).
    Cs: List[TList] = field(default_factory=list)
    Ts: List[TList] = field(default_factory=list)
    # B / Bd tensors; when left as None, __post_init__ replaces them with an
    # empty TList shaped like ``A`` (entries are EmptyT).
    B: TList = field(default=None, metadata={'init_tlist': True})
    Bd: TList = field(default=None, metadata={'init_tlist': True})
    # Boundary variants; when left as None, __post_init__ replaces each with
    # a list of four empty TLists shaped like ``A``.
    B_Cs: List[TList] = field(default=None, metadata={'init_tlists': True})
    B_Ts: List[TList] = field(default=None, metadata={'init_tlists': True})
    Bd_Cs: List[TList] = field(default=None, metadata={'init_tlists': True})
    Bd_Ts: List[TList] = field(default=None, metadata={'init_tlists': True})
    BB_Cs: List[TList] = field(default=None, metadata={'init_tlists': True})
    BB_Ts: List[TList] = field(default=None, metadata={'init_tlists': True})
    # Projector tensors; None until assigned (presumably by the CTM code,
    # which is not visible here).
    Pl: TList = None
    Pr: TList = None
    Pt: TList = None
    Pb: TList = None
    Plb: TList = None
    Prb: TList = None
    Ptb: TList = None
    Pbb: TList = None
    observables: List = field(default_factory=list)

    def _get_field_item(self, fieldname=None, ix=None):
        """Return ``getattr(self, fieldname)[ix]`` (getter behind ``C1``, ``T2``, ...)."""
        return getattr(self, fieldname)[ix]

    def _get_field_nested_item(self, fieldname=None, ix=None):
        """Call the method ``fieldname`` with ``ix`` (e.g. ``self.all_Cs(ix)``)."""
        return getattr(self, fieldname)(ix)

    def _set_field_nested_item(self, fieldname=None, ix=None, value=None):
        """Store ``value`` at index ``ix`` of the nested field ``fieldname``.

        A method named ``<fieldname>_set`` is used if one exists. Otherwise,
        for ``all_<name>`` fields (``all_Cs`` / ``all_Ts``), the value is
        written through :meth:`update`. The original implementation required
        an ``all_Cs_set`` method that this class never defined, so every
        ``all_Ci`` setter raised ``AttributeError``.

        Raises:
            AttributeError: if no setter can be resolved for ``fieldname``.
        """
        setter = getattr(self, f"{fieldname}_set", None)
        if setter is not None:
            return setter(ix, value)
        prefix = 'all_'
        if isinstance(fieldname, str) and fieldname.startswith(prefix):
            return self.update(fieldname[len(prefix):], ix, value)
        raise AttributeError(
            f"'{type(self).__name__}' object has no setter for '{fieldname}'")

    def hold(self, *names):
        """Pass the named tensor fields to :func:`hold_write` and return its result.

        A name of the form ``all_<f>`` expands to the four variants
        ``<f>``, ``B_<f>``, ``Bd_<f>`` and ``BB_<f>``; other names are used as-is.
        (What ``hold_write`` does with them is defined in ``.tlist``, not here.)
        """
        prefix = 'all_'

        def _expand(name):
            # all_X -> the four parallel variant fields of X
            if name.startswith(prefix):
                base = name[len(prefix):]
                return [base, f"B_{base}", f"Bd_{base}", f"BB_{base}"]
            return [name]

        expanded = [item for name in names for item in _expand(name)]
        tensors = tuple(getattr(self, item) for item in expanded)
        return hold_write(*tensors)

    def __post_init__(self):
        """Replace ``None`` fields flagged in their metadata with empty TLists.

        Fields with ``init_tlist`` get one empty TList shaped like ``A``.
        Fields with ``init_tlists`` get a list of four such TLists. Fields
        without these flags (e.g. the projectors) are left as given.
        """
        base_tlist = self.A
        for f in fields(self):
            if getattr(self, f.name) is not None:
                continue
            if f.metadata.get('init_tlist'):
                setattr(self, f.name,
                        TList.empty_like(base_tlist, empty_obj=EmptyT()))
            elif f.metadata.get('init_tlists'):
                setattr(self, f.name,
                        [TList.empty_like(base_tlist, empty_obj=EmptyT())
                         for _ in range(4)])

    def _nested_boundary(self, name, ix):
        """Combine the four variants of boundary list ``name`` at index ``ix``.

        Returns a TList shaped like ``A`` whose entry ``i`` is
        ``Nested([name[ix][i], B_name[ix][i], Bd_name[ix][i], BB_name[ix][i]])``.
        """
        res = TList.empty_like(self.A, empty_obj=EmptyT())
        variants = [getattr(self, f"{p}{name}")[ix]
                    for p in ('', 'B_', 'Bd_', 'BB_')]
        for i in range(len(res._data)):
            res._data[i] = Nested([v[i] for v in variants])
        return res

    def all_Cs(self, ix):
        """Return a TList of :class:`Nested` corner variants at index ``ix``."""
        return self._nested_boundary('Cs', ix)

    def all_Ts(self, ix):
        """Return a TList of :class:`Nested` edge variants at index ``ix``."""
        return self._nested_boundary('Ts', ix)

    def update(self, fieldnames, ixs, values):
        """Write Nested values back into the four variant fields.

        Args:
            fieldnames: a base field name (e.g. ``'Cs'``) or a sequence of them.
            ixs: the index (or sequence of indices, parallel to ``fieldnames``).
            values: the :class:`Nested` value (or sequence of them); element
                0/1/2/3 goes to ``f`` / ``B_f`` / ``Bd_f`` / ``BB_f``.
        """
        if isinstance(fieldnames, str):
            fieldnames = (fieldnames,)
            values = (values,)
            ixs = (ixs,)
        for i, f in enumerate(fieldnames):
            value = values[i]
            ix = ixs[i]
            assert isinstance(value, Nested), "Use the all_Ci setter only with Nested tensors"
            getattr(self, f)[ix] = value[0]
            getattr(self, f"B_{f}")[ix] = value[1]
            getattr(self, f"Bd_{f}")[ix] = value[2]
            getattr(self, f"BB_{f}")[ix] = value[3]

    @property
    def all_A(self):
        """TList of ``Nested([A, B, EmptyT(), EmptyT()])`` per site."""
        res = TList.empty_like(self.A, empty_obj=EmptyT())
        for i in range(len(res._data)):
            res._data[i] = Nested([self.A._data[i], self.B._data[i],
                                   EmptyT(), EmptyT()])
        return res

    @property
    def all_Ad(self):
        """TList of ``Nested([Ad, EmptyT(), Bd, EmptyT()])`` per site."""
        res = TList.empty_like(self.A, empty_obj=EmptyT())
        for i in range(len(res._data)):
            res._data[i] = Nested([self.Ad._data[i], EmptyT(),
                                   self.Bd._data[i], EmptyT()])
        return res

    def stop_gradient(self, only_boundaries=True):
        """Replace tensors in place by their ``.stop_gradient()`` versions.

        All boundary lists (plain, ``B_``, ``Bd_``, ``BB_`` variants of ``Cs``
        and ``Ts``) are always processed over their full length. The empty
        default of ``Cs`` / ``Ts`` is therefore fine, where a hard-coded
        ``range(4)`` raised ``IndexError``. If ``only_boundaries`` is False,
        ``A``, ``Ad``, ``B`` and ``Bd`` are processed too.
        """
        boundary_names = ('Cs', 'Ts', 'B_Cs', 'B_Ts',
                          'Bd_Cs', 'Bd_Ts', 'BB_Cs', 'BB_Ts')
        for name in boundary_names:
            tensors = getattr(self, name)
            # mutate the existing list so other references to it see the change
            for i in range(len(tensors)):
                tensors[i] = tensors[i].stop_gradient()
        if not only_boundaries:
            self.A = self.A.stop_gradient()
            self.Ad = self.Ad.stop_gradient()
            self.B = self.B.stop_gradient()
            self.Bd = self.Bd.stop_gradient()
def _wrap_f(self, fieldname=None, ix=None):
    """Property getter: return ``self.<fieldname>[ix]``."""
    return self._get_field_item(fieldname, ix)


def _wrap_nested_f(self, fieldname=None, ix=None):
    """Property getter: return ``self.<fieldname>(ix)`` (e.g. ``self.all_Cs(ix)``)."""
    return self._get_field_nested_item(fieldname, ix)


def _wrap_nested_f_set(self, value, fieldname=None, ix=None):
    """Property setter: forward ``value`` to ``self._set_field_nested_item``."""
    return self._set_field_nested_item(fieldname, ix, value)


def _install_indexed_properties():
    """Attach per-index convenience properties to :class:`CTMTensors`.

    For each boundary list field (``Cs``, ``Ts``, ``B_Cs``, ...) this adds
    read-only properties named by dropping the trailing ``s`` and appending
    ``1``..``4`` (``C1``, ``T4``, ``B_C2``, ...) that return element
    ``0``..``3`` of that list.

    For ``all_Cs`` / ``all_Ts`` it adds read/write properties ``all_C1`` ..
    ``all_T4``. These return the Nested combination of all boundary versions
    (ground state only, + one ``B``, + one ``Bdagger``, + one ``B`` and one
    ``Bdagger``) and store assigned values back via the instance's nested
    setter.

    This runs inside a function so that its loop variables do not become
    public module globals (the module defines no ``__all__``, so such names
    would be exported by ``import *``).
    """
    plain_fields = ['Cs', 'Ts', 'B_Cs', 'B_Ts', 'Bd_Cs', 'Bd_Ts', 'BB_Cs', 'BB_Ts']
    for name in plain_fields:
        for i in range(4):
            prop = property(partial(_wrap_f, fieldname=name, ix=i))
            # Sphinx: hide the generated accessors from the API docs
            prop.__doc__ = ":meta private:"
            setattr(CTMTensors, f"{name[:-1]}{i+1}", prop)

    nested_fields = ['all_Cs', 'all_Ts']
    for name in nested_fields:
        for i in range(4):
            prop = property(partial(_wrap_nested_f, fieldname=name, ix=i),
                            partial(_wrap_nested_f_set, fieldname=name, ix=i))
            prop.__doc__ = ":meta private:"
            setattr(CTMTensors, f"{name[:-1]}{i+1}", prop)


_install_indexed_properties()