"""
List object with additional features, used for storing
the iPEPS tensors
Items in the list can be accessed by either a linear index
or a (i,j) double index, where i and j will be automatically
taken modulo the unit cell size (i.e. i = i % n_x)
Additionally, convenience functions that work on tensors can
be defined for the whole list, e.g. conj()
"""
import contextlib
import jax
import jax.numpy as np
import numpy as onp
from .empty_tensor import EmptyT
@contextlib.contextmanager
def cur_loc(*loc: int):
    """Shift the locations of the tensors relative to a new zero while in this context.

    Args:
        loc: either the shifts ``(x, y)`` or a single linear site index
            (a one-element ``loc``).

    Example:
        >>> l = TList([[1, 2], [3, 4]])
        >>> l[0, 0]
        1
        >>> with cur_loc(1, 0):
        ...     print(l[0, 0], l[0, 1])
        2 4
        >>> l[0, 0]
        1

    Note that this applies to ALL TList objects while inside the context,
    because the shift is stored on the class. The previous shift is restored
    on exit, even if the body of the ``with`` statement raises.
    """
    pre_patched_value = TList._loc
    TList._loc = loc
    try:
        yield TList
    finally:
        # Without this, an exception inside the context would leave every
        # TList permanently shifted.
        TList._loc = pre_patched_value
@contextlib.contextmanager
def hold_write(*lists: 'TList'):
    """Hold off on making writes visible to reads while inside the context.

    Writes are applied immediately, but reads inside the context keep
    returning the values from before the hold. On exit, the held values
    are discarded, so reads see the new values.

    Args:
        lists: one or more TList objects that should have the writing
            action delayed until the context is exited.

    Example:
        >>> l = TList([[1, 2], [3, 4]])
        >>> with hold_write(l):
        ...     l[0, 0] = 100
        ...     print(l[0, 0])
        1
        >>> l[0, 0]
        100

    The hold is released even if the body of the ``with`` statement raises.
    """
    for tlist in lists:
        tlist._hold_write = True
    try:
        yield
    finally:
        # Always leave hold mode; otherwise an exception would keep the lists
        # serving stale values indefinitely.
        for tlist in lists:
            tlist._purge_tmp()
@contextlib.contextmanager
def set_pattern(pattern):
    """Set the pattern for all new TLists that are created while the context is active.

    Args:
        pattern: default pattern given to every TList constructed without an
            explicit ``pattern`` argument inside the context. It is consumed
            by ``TList.__init__`` as a 2-D array of integer ids (rows index y,
            columns index x).

    The previous default pattern is restored on exit, even if the body of the
    ``with`` statement raises.
    """
    pre_patched_value = TList._default_pattern
    TList._default_pattern = pattern
    try:
        yield TList
    finally:
        TList._default_pattern = pre_patched_value
class TList:
    """List of tensors addressable by a linear index or by (x, y) coordinates.

    Without a pattern, ``_data`` holds one element per unit-cell site, and
    site ``(x, y)`` is stored at linear index ``x + y * n_x`` (Fortran order
    over ``size = (n_x, n_y)``). With a pattern (a 2-D array of integer ids
    whose rows index y and whose columns index x), site ``(x, y)`` maps to
    ``pattern[y][x]``, and ``_data`` holds one element per unique id.

    Coordinate indices are shifted by the class-wide ``_loc`` and then taken
    modulo the unit-cell size. Plain integer indices address ``_data``
    directly.
    """

    # Class-wide index shift: either (shift_x, shift_y) or a 1-tuple holding
    # a linear site index. Shared by every TList instance.
    _loc = (0, 0)
    # Pattern used by new lists created without an explicit ``pattern``.
    _default_pattern = None
    # Per-element "written since last reset" flags; stays None while the list
    # has no data.
    _changed = None

    def __init__(self, data=None, shape=None, pattern=None, empty_obj=None):
        """Create a TList.

        Args:
            data: element(s) to store.
                - A list/tuple of equally long, non-empty lists/tuples is a
                  2-D grid given row by row (rows are y, entries are x).
                - Any other non-empty list/tuple is a single row.
                - Anything else (a number, a tensor, an empty sequence) is
                  stored as the only element of a 1x1 list.
                Tensors are never split apart.
            shape: ``(n_x, n_y)`` of an empty list. Used when ``data`` is
                None and there is no pattern.
            pattern: 2-D array of integer ids; sites sharing an id share one
                element. Defaults to the class-wide ``_default_pattern``.
            empty_obj: list that is repeated to fill empty slots
                (``n * empty_obj``). Defaults to ``[[]]``, i.e. each empty
                slot holds ``[]``.
        """
        if empty_obj is None:
            # Fresh default per call instead of a shared mutable default.
            empty_obj = [[]]
        self._tmpdata = None
        self._hold_write = False
        self.empty_obj = empty_obj
        self.pattern = pattern if pattern is not None else self._default_pattern
        if self.pattern is None:
            self._init_dense(data, shape, empty_obj)
        else:
            self._init_patterned(data, empty_obj)
        self.reset_changed()

    @staticmethod
    def _split_grid(data):
        """Split list/tuple ``data`` into rows of elements.

        Returns ``(rows, is_2d)``, where ``rows`` is a list of equally long
        lists (rows index y, entries index x). A 1-D input becomes a single
        row. Returns None when ``data`` should be stored as a single element.
        Only ``list``/``tuple`` nesting is unpacked, so tensors and other
        objects are kept whole. This replaces an object-dtype array
        conversion, which jax.numpy does not support.
        """
        if not isinstance(data, (list, tuple)) or len(data) == 0:
            return None
        first = data[0]
        if isinstance(first, (list, tuple)) and len(first) > 0 and all(
                isinstance(row, (list, tuple)) and len(row) == len(first)
                for row in data):
            return [list(row) for row in data], True
        return [list(data)], False

    def _init_dense(self, data, shape, empty_obj):
        """Set ``_data`` and ``size`` for a list without a pattern."""
        if data is not None:
            grid = self._split_grid(data)
            if grid is None:
                self._data = [data]
                self.size = (1, 1)
            else:
                rows, _ = grid
                # Row-major flattening: site (x, y) lands at x + y * n_x.
                self._data = [item for row in rows for item in row]
                self.size = (len(rows[0]), len(rows))
        elif shape is not None:
            self._data = (shape[0] * shape[1]) * empty_obj
            self.size = shape
        else:
            self._data = None
            self.size = ()

    def _init_patterned(self, data, empty_obj):
        """Set ``_data`` and ``size`` for a list with a pattern."""
        self.pattern = np.array(self.pattern)
        n_rows, n_cols = self.pattern.shape[0], self.pattern.shape[1]
        self.size = (n_cols, n_rows)
        n_unique = np.unique(self.pattern).size
        grid = None if data is None else self._split_grid(data)
        if data is None:
            self._data = n_unique * empty_obj
        elif grid is None:
            self._data = [data]
            self.size = (1, 1)
        else:
            rows, is_2d = grid
            flat = [item for row in rows for item in row]
            if len(flat) == n_unique:
                # Already one element per unique id.
                self._data = flat
            elif is_2d and len(rows) >= n_rows and len(rows[0]) >= n_cols:
                # One entry per site: store each under the id of its site.
                self._data = n_unique * empty_obj
                for j in range(n_cols):
                    for i in range(n_rows):
                        self._data[self.pattern[i, j]] = rows[i][j]
            else:
                self._data = [data]
                self.size = (1, 1)
        assert len(self._data) == n_unique, \
            "Data must contain one element for each unique identifier in pattern"

    def _with_data(self, new_data):
        """Return a new TList with this list's layout and ``new_data`` as elements."""
        new_list = TList(shape=self.size, pattern=self.pattern,
                         empty_obj=self.empty_obj)
        new_list._data = new_data
        return new_list

    def x_major(self):
        """Return a generator of the linear indices of all sites, x varying fastest."""
        return (self._conv_ix((x, y))
                for y in range(self.size[1]) for x in range(self.size[0]))

    def y_major(self):
        """Return a generator of the linear indices of all sites, y varying fastest."""
        return (self._conv_ix((x, y))
                for x in range(self.size[0]) for y in range(self.size[1]))

    def __len__(self):
        """Number of stored elements (one per unique id when a pattern is set)."""
        return len(self._data)

    def mean(self):
        """Average of the elements accepted by ``isfinite``.

        Falls back to averaging every element if the filtered mean fails,
        e.g. when no element passes the filter (division by zero) or the
        filter itself raises.
        """
        try:
            finite_elems = [x for x in self._data if isfinite(x)]
            return sum(finite_elems) / len(finite_elems)
        except Exception:
            return sum(self._data) / len(self)

    def sum(self):
        """Sum of the elements accepted by ``isfinite``; falls back to summing all."""
        try:
            return sum(x for x in self._data if isfinite(x))
        except Exception:
            return sum(self._data)

    def normalize(self):
        """Return a new TList with each element divided by its max absolute value."""
        return self._with_data([a / np.max(np.abs(a)) for a in self._data])

    def conj(self):
        """Return a new TList of the complex conjugates of the elements."""
        return self._with_data([a.conj() for a in self._data])

    def items(self):
        """Return ``a.item()`` for every element, as a list."""
        return [a.item() for a in self._data]

    def pack_data(self):
        """Concatenate all elements, each flattened, into one 1-D array."""
        return np.concatenate([np.reshape(a, (-1,)) for a in self._data])

    def reset_changed(self):
        """Clear all change flags (when data exists) and return self."""
        if self._data is not None:
            self._changed = len(self._data) * [False]
        return self

    def mark_changed(self, linear_ix):
        """Flag the element at ``linear_ix`` as changed (no-op without data)."""
        if self._changed is not None:
            self._changed[linear_ix] = True

    def is_changed(self, *ix):
        """Return whether the element at ``ix`` was written since the last reset.

        ``ix`` is either coordinates ``(x, y)`` or a single linear index.
        """
        if self._changed is None:
            return False
        # A single argument is a linear index (or an (x, y) tuple); unwrap it
        # so _conv_ix does not treat the 1-tuple as coordinates.
        linear_ix = self._conv_ix(ix[0] if len(ix) == 1 else ix)
        return self._changed[linear_ix]

    def fill(self, data, d=None, D=None):
        """Return a new TList whose elements are consecutive slices of flat ``data``.

        Each element of this list supplies the size and shape of the slice
        read from ``data``, in the same order ``pack_data`` writes them.
        ``d`` and ``D`` are unused and kept for backward compatibility.
        """
        new_data = []
        offset = 0
        for a in self:
            new_data.append(np.reshape(data[offset:offset + a.size], a.shape))
            offset += a.size
        return self._with_data(new_data)

    def tot_numel(self):
        """Total number of scalar entries over all elements."""
        return sum(a.size for a in self._data)

    def stop_gradient(self):
        """Return a new TList with gradients stopped on every non-empty element."""
        return self._with_data(
            [jax.lax.stop_gradient(a) if len(a) > 0 else a for a in self._data])

    def _conv_ix(self, ix):
        """Convert an index into a position in ``_data``.

        A tuple/list ``(x, y)`` is shifted by ``_loc``, wrapped modulo the
        unit-cell size, and mapped through ``_linear_ix``. Any other index is
        returned unchanged.
        """
        if not isinstance(ix, (tuple, list)):
            return ix
        n_x, n_y = self.size[0], self.size[1]
        if len(self._loc) == 1:
            # A single shift is a linear site index in the same x-fastest
            # order as _linear_ix, i.e. loc = shift_x + shift_y * n_x.
            shift_j, shift_i = divmod(self._loc[0] % (n_x * n_y), n_x)
        else:
            shift_i, shift_j = self._loc
        i = (ix[0] + shift_i) % n_x
        j = (ix[1] + shift_j) % n_y
        return self._linear_ix(i, j)

    def _linear_ix(self, i, j):
        """Map in-range site coordinates to a ``_data`` position."""
        if self.pattern is not None:
            return self.pattern[j][i]
        return np.ravel_multi_index((i, j), self.size, order='F')

    def _purge_tmp(self):
        """Drop held pre-write values and leave hold mode."""
        self._tmpdata = None
        self._hold_write = False

    def __eq__(self, other):
        """Lists are equal if data, size and pattern (or lack of one) all match."""
        if not isinstance(other, TList):
            return NotImplemented
        if self._data != other._data:
            return False
        if tuple(self.size) != tuple(other.size):
            return False
        # Checked both ways so that a == b and b == a always agree.
        if (self.pattern is None) != (other.pattern is None):
            return False
        if self.pattern is not None and not np.array_equal(self.pattern, other.pattern):
            return False
        return True

    def __getitem__(self, ix):
        """Return the element at ``ix``; during a hold, the pre-hold value."""
        linear_ix = self._conv_ix(ix)
        if self._tmpdata is not None and self._tmpdata[linear_ix] is not None:
            return self._tmpdata[linear_ix]
        return self._data[linear_ix]

    def __setitem__(self, ix, value):
        """Store ``value`` at ``ix``, remembering the old value while holding."""
        linear_ix = self._conv_ix(ix)
        if self._hold_write:
            if self._tmpdata is None:
                self._tmpdata = [None] * len(self)
            # Keep only the value from before the hold began, so repeated
            # writes to one slot do not replace it with an intermediate value.
            if self._tmpdata[linear_ix] is None:
                self._tmpdata[linear_ix] = self._data[linear_ix]
        self._data[linear_ix] = value
        self.mark_changed(linear_ix)

    def __repr__(self):
        """Show loc, pattern, size and each site's shape (or repr if shapeless)."""
        if self._data is None:
            return "TList{}[]"
        header = "TList{"
        if self._loc is not None:
            header += "Loc=" + repr(self._loc)
        if self.pattern is not None:
            header += ",Pat=" + repr(self.pattern)
        header += ",Size=" + repr(self.size) + "}["
        rows = []
        for j in range(self.size[1]):
            cells = []
            for i in range(self.size[0]):
                elem = self[i, j]
                try:
                    cells.append(f"{elem.shape}")
                except Exception:
                    cells.append(repr(elem))
            rows.append("[" + ", ".join(cells) + "]")
        return header + ", ".join(rows) + "]"

    @staticmethod
    def empty_like(T, empty_obj=None):
        """Return an empty TList with the same size and pattern as ``T``."""
        if empty_obj is None:
            empty_obj = T.empty_obj
        return TList(shape=T.size, pattern=T.pattern, empty_obj=empty_obj)
def isfinite(x):
    """Decide whether element ``x`` should count in TList.mean / TList.sum.

    Values with a usable ``len()`` pass when non-empty, so empty placeholders
    such as ``[]`` are skipped; their contents are not inspected. Values
    without a usable ``len()`` (e.g. plain numbers) are tested with
    ``np.isfinite``.
    """
    try:
        n_items = len(x)
    except Exception:
        # No length: treat x as a scalar and test its value.
        return np.isfinite(np.array(x))
    return n_items > 0