""" Contains utility class that contains all iPEPS tensors """
from dataclasses import dataclass, field, fields, replace, asdict
from functools import partial
from typing import List
from .empty_tensor import EmptyT
from .nested import Nested
from .tlist import TList, hold_write
@dataclass
class CTMTensors:
    """
    This is a utility class that contains all tensors related to an iPEPS.

    In effect, this forms a representation of the full state, including
    the site tensors with variational parameters, the boundary tensors
    generated by CTM and the projectors.

    There are several attributes of this dataclass that are generated
    automatically, which provide convenient wrappers for the tensors.
    For example, the :attr:`CTMTensors.A` field returns the site tensors
    (contained in a :class:`adpeps.utils.tlist.TList`).

    Fields declared with ``init_tlist`` metadata (``B``, ``Bd``) or
    ``init_tlists`` metadata (``B_Cs`` ... ``BB_Ts``) that are left as
    ``None`` are replaced in :meth:`__post_init__` by empty TLists shaped
    like ``A`` -- a single TList, or a list of four TLists, respectively.
    """

    # Name prefixes of the four variants of every boundary field, in the
    # order in which they are packed into / unpacked from ``Nested`` bundles.
    # These are plain (unannotated) class attributes, so they are NOT
    # dataclass fields and do not affect the constructor.
    _VARIANT_PREFIXES = ('', 'B_', 'Bd_', 'BB_')
    _BOUNDARY_BASES = ('Cs', 'Ts')

    # Site tensors ``A`` and their counterparts ``Ad`` (presumably the
    # conjugated / "dagger" tensors -- NOTE(review): confirm with callers).
    A: TList
    Ad: TList
    # Ground-state boundary tensors; once populated, lists of four TLists
    # indexed 0-3 (presumably CTM corners ``C`` and edges ``T``).
    Cs: List[TList] = field(default_factory=list)
    Ts: List[TList] = field(default_factory=list)
    # ``B`` / ``Bd`` tensors and the boundary-tensor variants that contain one
    # ``B``, one ``Bd``, or both.  ``None`` is replaced in ``__post_init__``.
    B: TList = field(default=None, metadata={'init_tlist': True})
    Bd: TList = field(default=None, metadata={'init_tlist': True})
    B_Cs: List[TList] = field(default=None, metadata={'init_tlists': True})
    B_Ts: List[TList] = field(default=None, metadata={'init_tlists': True})
    Bd_Cs: List[TList] = field(default=None, metadata={'init_tlists': True})
    Bd_Ts: List[TList] = field(default=None, metadata={'init_tlists': True})
    BB_Cs: List[TList] = field(default=None, metadata={'init_tlists': True})
    BB_Ts: List[TList] = field(default=None, metadata={'init_tlists': True})
    # Projectors; not initialised here, they stay ``None`` until assigned.
    Pl: TList = None
    Pr: TList = None
    Pt: TList = None
    Pb: TList = None
    Plb: TList = None
    Prb: TList = None
    Ptb: TList = None
    Pbb: TList = None
    observables: List = field(default_factory=list)

    def _get_field_item(self, fieldname=None, ix=None):
        """Return ``getattr(self, fieldname)[ix]``.

        Backs the generated read-only properties such as ``C1`` (``Cs[0]``).
        """
        return getattr(self, fieldname)[ix]

    def _get_field_nested_item(self, fieldname=None, ix=None):
        """Call the method ``fieldname`` (e.g. ``all_Cs``) with ``ix``."""
        return getattr(self, fieldname)(ix)

    def _set_field_nested_item(self, fieldname=None, ix=None, value=None):
        """Store ``value`` for the nested accessor ``fieldname`` at ``ix``.

        A dedicated ``<fieldname>_set(ix, value)`` method is used when one
        exists.  Otherwise, for ``all_*`` names, the value is written through
        :meth:`update` (e.g. ``all_Cs`` -> ``update('Cs', ix, value)``).

        Raises:
            AttributeError: if neither route is available for ``fieldname``.
        """
        setter = getattr(self, f"{fieldname}_set", None)
        if setter is not None:
            return setter(ix, value)
        if isinstance(fieldname, str) and fieldname.startswith('all_'):
            # No ``all_Cs_set`` / ``all_Ts_set`` is defined, so without this
            # fallback assigning to ``all_C1`` etc. always failed.
            return self.update(fieldname[4:], ix, value)
        raise AttributeError(
            f"{type(self).__name__!r} object has no attribute '{fieldname}_set'"
        )

    def hold(self, *names):
        """Pass the named TLists to ``hold_write`` and return its result.

        A name of the form ``all_<f>`` expands to ``<f>``, ``B_<f>``,
        ``Bd_<f>`` and ``BB_<f>``; any other name is used as-is.
        See :func:`adpeps.utils.tlist.hold_write` for what is returned.
        """
        def _expand(name):
            if name.startswith('all_'):
                base = name[4:]
                return [f"{prefix}{base}" for prefix in self._VARIANT_PREFIXES]
            return [name]

        # Locals deliberately avoid the names ``fields`` / ``field``, which
        # are imported from ``dataclasses`` at module level.
        expanded = [n for name in names for n in _expand(name)]
        tensors = tuple(getattr(self, n) for n in expanded)
        return hold_write(*tensors)

    def __post_init__(self):
        """Fill flagged ``None`` fields with empty TLists shaped like ``A``.

        ``init_tlist`` fields receive one empty TList; ``init_tlists`` fields
        receive a list of four.  Entries are filled with ``EmptyT()``.
        """
        base_tlist = self.A
        for f in fields(self):
            if getattr(self, f.name) is not None:
                continue
            # ``.get`` instead of ``try/except KeyError`` so that a KeyError
            # raised inside ``TList.empty_like`` is no longer swallowed.
            if f.metadata.get('init_tlist'):
                setattr(self, f.name, TList.empty_like(base_tlist, empty_obj=EmptyT()))
            elif f.metadata.get('init_tlists'):
                setattr(
                    self,
                    f.name,
                    [TList.empty_like(base_tlist, empty_obj=EmptyT()) for _ in range(4)],
                )

    def _nested_boundary(self, base, ix):
        """Return a TList bundling the four variants of ``<base>[ix]``.

        Entry ``i`` is ``Nested([<base>[ix][i], B_<base>[ix][i],
        Bd_<base>[ix][i], BB_<base>[ix][i]])``.
        """
        res = TList.empty_like(self.A, empty_obj=EmptyT())
        sources = [getattr(self, f"{prefix}{base}") for prefix in self._VARIANT_PREFIXES]
        for i in range(len(res._data)):
            res._data[i] = Nested([src[ix][i] for src in sources])
        return res

    def all_Cs(self, ix):
        """Return the ``Nested`` bundle of all variants of ``Cs[ix]``."""
        return self._nested_boundary('Cs', ix)

    def all_Ts(self, ix):
        """Return the ``Nested`` bundle of all variants of ``Ts[ix]``."""
        return self._nested_boundary('Ts', ix)

    def update(self, fieldnames, ixs, values):
        """Unpack ``Nested`` values into the four variants of boundary fields.

        Args:
            fieldnames: base field name (e.g. ``'Cs'``) or a sequence of them.
            ixs: list index, or a sequence parallel to ``fieldnames``.
            values: a ``Nested`` value, or a sequence parallel to
                ``fieldnames``.  Elements 0-3 are written to ``<f>[ix]``,
                ``B_<f>[ix]``, ``Bd_<f>[ix]`` and ``BB_<f>[ix]``.
        """
        if isinstance(fieldnames, str):
            fieldnames = (fieldnames,)
            values = (values,)
            ixs = (ixs,)
        for n, base in enumerate(fieldnames):
            value = values[n]
            ix = ixs[n]
            assert isinstance(value, Nested), "Use the all_Ci setter only with Nested tensors"
            for k, prefix in enumerate(self._VARIANT_PREFIXES):
                getattr(self, f"{prefix}{base}")[ix] = value[k]

    @property
    def all_A(self):
        """TList of ``Nested([A, B, EmptyT(), EmptyT()])`` per site."""
        res = TList.empty_like(self.A, empty_obj=EmptyT())
        for i in range(len(res._data)):
            res._data[i] = Nested([self.A._data[i], self.B._data[i],
                                   EmptyT(), EmptyT()])
        return res

    @property
    def all_Ad(self):
        """TList of ``Nested([Ad, EmptyT(), Bd, EmptyT()])`` per site."""
        res = TList.empty_like(self.A, empty_obj=EmptyT())
        for i in range(len(res._data)):
            res._data[i] = Nested([self.Ad._data[i], EmptyT(),
                                   self.Bd._data[i], EmptyT()])
        return res

    def stop_gradient(self, only_boundaries=True):
        """Replace tensors in place with their ``stop_gradient()`` results.

        Always applies to indices 0-3 of every boundary variant (``Cs``,
        ``Ts``, ``B_Cs`` ... ``BB_Ts``).  If ``only_boundaries`` is False,
        ``A``, ``Ad``, ``B`` and ``Bd`` are processed as well.
        """
        boundary_names = [f"{prefix}{base}"
                          for prefix in self._VARIANT_PREFIXES
                          for base in self._BOUNDARY_BASES]
        for i in range(4):
            for name in boundary_names:
                tensors = getattr(self, name)
                tensors[i] = tensors[i].stop_gradient()
        if not only_boundaries:
            for name in ('A', 'Ad', 'B', 'Bd'):
                setattr(self, name, getattr(self, name).stop_gradient())
def _wrap_f(self, fieldname=None, ix=None):
    """Property getter: forward to ``self._get_field_item(fieldname, ix)``.

    ``fieldname`` and ``ix`` are bound with ``functools.partial`` when the
    property is created; ``self`` is the instance the property is read on.
    """
    getter = self._get_field_item
    return getter(fieldname, ix)
def _wrap_nested_f(self, fieldname=None, ix=None):
    """Property getter: forward to ``self._get_field_nested_item(fieldname, ix)``.

    Used for the ``all_*`` properties; ``fieldname``/``ix`` are pre-bound
    through ``functools.partial``.
    """
    nested_getter = self._get_field_nested_item
    return nested_getter(fieldname, ix)
def _wrap_nested_f_set(self, value, fieldname=None, ix=None):
    """Property setter: forward ``value`` to ``self._set_field_nested_item``.

    ``property`` invokes the setter as ``fset(instance, value)``, which is
    why ``value`` comes before the ``partial``-bound ``fieldname``/``ix``.
    Note that the forwarded argument order is ``(fieldname, ix, value)``.
    """
    nested_setter = self._set_field_nested_item
    return nested_setter(fieldname, ix, value)
# Register read-only shortcut properties on CTMTensors for every entry of the
# boundary-tensor lists: ``C1``..``C4`` read ``Cs[0]``..``Cs[3]``, ``T1``..``T4``
# read ``Ts[0]``..``Ts[3]``, and likewise ``B_C1``, ``Bd_T2``, ``BB_C4`` etc.
# They are hidden from the Sphinx docs via ``:meta private:``.
attrs = ['Cs', 'Ts', 'B_Cs', 'B_Ts', 'Bd_Cs', 'Bd_Ts', 'BB_Cs', 'BB_Ts']
for attr in attrs:
    for i in range(4):
        new_attr = property(
            partial(_wrap_f, fieldname=attr, ix=i),
            doc=":meta private:",
        )
        # Drop the trailing plural "s" and append the 1-based index.
        setattr(CTMTensors, attr[:-1] + str(i + 1), new_attr)
# Register read/write properties ``all_C1``..``all_C4`` and ``all_T1``..``all_T4``.
# Reading ``obj.all_C1`` returns ``obj.all_Cs(0)``: a TList bundling all
# versions of that boundary tensor (ground state only, ground state + one
# ``B`` tensor, ground state + one ``Bdagger`` tensor, ground state + one
# ``B`` and one ``Bdagger`` tensor).  Assignment is routed through
# ``_wrap_nested_f_set`` to ``CTMTensors._set_field_nested_item``.
attrs = ['all_Cs', 'all_Ts']
for attr in attrs:
    for i in range(4):
        new_attr = property(
            partial(_wrap_nested_f, fieldname=attr, ix=i),
            partial(_wrap_nested_f_set, fieldname=attr, ix=i),
            doc=":meta private:",
        )
        # Drop the trailing plural "s" and append the 1-based index.
        setattr(CTMTensors, attr[:-1] + str(i + 1), new_attr)