import pandas as pd, numpy as np, scipy as sp
import scipy.sparse as sps
import functools, collections, time, contextlib, torch, gc, warnings, json
import dataclasses, os
from torch.utils.data import DataLoader, random_split
from pytorch_lightning import LightningModule
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
try:
    from functools import cached_property
except ImportError:
    from backports.cached_property import cached_property
from .score_array import * # noqa: F401, F403
from .plotting import plot_rec_results, plot_mtch_results
from tensorboard.backend.event_processing import event_accumulator
class timed(contextlib.ContextDecorator):
    """Print the wall-clock duration of a ``with`` block or decorated function.

    Parameters
    ----------
    name : str
        Label printed next to the timing messages.
    inline : bool
        When True the start and end messages share one output line
        (``timing <name> done . time 1.2s``); otherwise each message gets its
        own line and the end message repeats ``name``.

    The start time is stored on the instance, so nested use of the *same*
    instance overwrites the outer start time.
    """

    def __init__(self, name="", inline=True):
        self.name = name
        self.inline = inline

    def __enter__(self):
        self.tic = time.time()
        # flush: in inline mode no newline is written yet, so without an explicit
        # flush the prefix may only show up after the timed body has finished
        print("timing", self.name, end=' ' if self.inline else '\n', flush=True)
        # returning self makes ``with timed(...) as t`` usable
        return self

    def __exit__(self, *args, **kw):
        # returns None, so exceptions raised by the body keep propagating
        print("done", "." if self.inline else self.name,
              "time {:.1f}s".format(time.time() - self.tic))
def warn_nan_output(func):
    """Decorator that emits a warning whenever ``func`` returns NaN values.

    The returned object is inspected through its ``values`` attribute when it
    has one (otherwise the object itself is inspected). If the inspected object
    offers a ``has_nan()`` method that method decides; otherwise
    ``np.isnan(...).any()`` is used. The output of ``func`` is returned unchanged.
    """
    def _contains_nan(result):
        payload = getattr(result, "values", result)
        if hasattr(payload, "has_nan"):
            return payload.has_nan()
        return np.isnan(payload).any()

    @functools.wraps(func)
    def wrapped(*args, **kw):
        out = func(*args, **kw)
        if _contains_nan(out):
            # stacklevel=2 attributes the warning to the caller of the wrapped function
            warnings.warn(f"{func.__name__} output contains NaN", stacklevel=2)
        return out

    return wrapped
def _empty_cache():
    """Run Python's garbage collector, then call ``torch.cuda.empty_cache()``."""
    # collect first so that tensors that just became unreachable are dropped
    # before torch is asked to release its cached memory
    gc.collect()
    torch.cuda.empty_cache()
def _get_cuda_objs():
objs = []
for obj in gc.get_objects():
try:
flag = torch.is_tensor(obj) # or \
# (hasattr(obj, 'data') and torch.is_tensor(obj.data))
except Exception:
flag = False
try:
if flag and torch.device(obj.device) != torch.device("cpu"):
objs.append(obj)
except RuntimeError as e:
warnings.warn(str(e))
return objs
def empty_cache_on_exit(func):
    """Decorator that times ``func`` and reports non-CPU tensors it leaves behind.

    The caches are emptied before and after the call. Tensors returned by
    ``_get_cuda_objs()`` after the call that were not present before it are
    printed as ``memory leak`` lines together with their type, size and
    device. The elapsed wall-clock time is printed as well. The result of
    ``func`` is returned unchanged.
    """
    @functools.wraps(func)
    def wrapped(*args, **kw):
        _empty_cache()
        before = _get_cuda_objs()

        started = time.time()
        result = func(*args, **kw)
        print(func.__name__, "time {:.1f}s".format(time.time() - started))

        _empty_cache()
        after = _get_cuda_objs()
        for obj in set(after) - set(before):
            print(func.__name__, "memory leak",
                  type(obj), obj.size(), obj.device, flush=True)

        # drop the snapshots before the final cache sweep
        del before
        del after
        _empty_cache()
        return result

    return wrapped
@contextlib.contextmanager
def _to_cuda(model):
if torch.cuda.is_available():
orig_device = model.device
print("running model on cuda device")
yield model.to("cuda")
model.to(orig_device)
else:
yield model
def perplexity(x):
    """Return the perplexity (exponential of the entropy) of the distribution in ``x``.

    ``x`` may be any array-like of non-negative weights (list, tuple, ndarray,
    pandas object); it is flattened and normalized by its sum. Zero entries
    contribute nothing to the entropy. The result is a Python float.
    An all-zero input is not guarded against and yields NaN.
    """
    # plain lists/tuples have no ``.sum()``; convert once so any array-like works
    x = np.asarray(x)
    x = np.ravel(x) / x.sum()
    # substitute a tiny positive value before the log to avoid log(0); those
    # terms are multiplied by x == 0 and therefore vanish
    return float(np.exp(- x @ np.log(np.where(x > 0, x, 1e-10))))
@empty_cache_on_exit
def _assign_topk(S, k, tie_breaker=1e-10, device="cpu", batch_size=None):
    """ Return a sparse matrix where each row contains k non-zero values.
    Used for both ItemRec (if S is user-by-item) and UserRec (if S is transposed).
    Expect the shape to align with (len(D.user_in_test), len(D.item_in_test))
    or its transpose.

    S : dense array-like, or an object exposing ``collate_fn`` which is then
        sliced row-wise into batches; pieces exposing ``as_tensor`` are
        converted through it, everything else through ``torch.as_tensor``.
    k : number of entries kept per row.
    tie_breaker : scale of the uniform noise added before ``topk``; a falsy
        value disables the noise.
    device : torch device on which the top-k is computed.
    batch_size : rows per batch for batched inputs; defaults to ``S.batch_size``
        and, when that is missing or falsy, to a single batch over all rows.

    Returns a ``scipy.sparse.csr_matrix`` of shape ``S.shape`` whose stored
    values are all 1.
    """
    if hasattr(S, "collate_fn"):
        if batch_size is None:
            # a missing batch size used to crash ``range``; process one batch instead
            batch_size = getattr(S, "batch_size", None) or max(len(S), 1)
        batches = (S[start:min(len(S), start + batch_size)]
                   for start in range(0, len(S), batch_size))
    else:
        batches = [S]

    indices = []
    for s in batches:
        if hasattr(s, "as_tensor"):
            s = s.as_tensor(device)
        else:
            # as_tensor (as in _argsort) avoids the copy and the copy-construct
            # warning that torch.tensor emits for tensor inputs
            s = torch.as_tensor(s, device=device)
        if tie_breaker:
            s = s + torch.rand(*s.shape, device=device) * tie_breaker
        indices.append(s.topk(k).indices.cpu().numpy())
    indices = np.vstack(indices)

    n_rows, n_cols = indices.shape
    return sp.sparse.csr_matrix((
        np.ones(indices.size),
        np.ravel(indices),
        # every row holds exactly n_cols entries; this form also works for k == 0
        np.arange(n_rows + 1) * n_cols,
    ), shape=S.shape)


assign_topk = _assign_topk
@empty_cache_on_exit
def _argsort(S, tie_breaker=1e-10, device="cpu"):
    """Return the indices of all entries of S ordered by descending score.

    The scores are (optionally) perturbed by uniform noise scaled by
    ``tie_breaker``, negated, flattened and argsorted; the flat positions are
    mapped back with ``np.unravel_index``, so the result is a tuple of index
    arrays, one per dimension of S.

    The sort runs in torch on ``device`` unless ``device`` ends up as None, in
    which case numpy is used. That happens automatically (with a warning) when
    S exposes a ``batch_size`` smaller than its number of rows.
    """
    print(f"_argsort {S.size:,} scores on device {device}; ", end="")
    if hasattr(S, "batch_size") and S.batch_size < S.shape[0]:
        warnings.warn(f"switching numpy.argsort due to {S.batch_size}<{S.shape[0]}")
        device = None
    if hasattr(S, "as_tensor"):
        # NOTE(review): as_tensor(None) presumably yields a numpy-compatible
        # array for the branch below -- confirm against the score-array class
        S = S.as_tensor(device)
    shape = S.shape
    if device is None:
        if tie_breaker > 0:
            S = S + np.random.rand(*S.shape) * tie_breaker
        # S is rebound on purpose so that intermediates become unreachable
        # before the cache is emptied
        S = -S.reshape(-1)
        _empty_cache()
        argsort_ind = np.argsort(S)
    else:
        S = torch.as_tensor(S, device=device)
        if tie_breaker > 0:
            S = S + torch.rand(*S.shape, device=device) * tie_breaker
        S = -S.reshape(-1)
        _empty_cache()
        argsort_ind = torch.argsort(S).cpu().numpy()
    return np.unravel_index(argsort_ind, shape)
def groupby_unexplode(series, index=None, return_type='series'):
    """
    Collapse an exploded series back into one list per block of equal index labels.
    The input is assumed to have block-wise indices, i.e. equal labels are adjacent.
    If `index` is given, its keys must appear in the same order as the blocks of
    the series; keys without a block receive an empty list.
    With return_type='splits' only the split positions are returned.
    >>> groupby_unexplode(pd.Series([1,2,3,4,5], index=[1,1,2,3,3])).to_dict()
    {1: [1, 2], 2: [3], 3: [4, 5]}
    >>> groupby_unexplode(pd.Series([1,2,3,4,5], index=[1,1,2,3,3]), index=[0,1,-1,2,3,4]).to_dict()
    {0: [], 1: [1, 2], -1: [], 2: [3], 3: [4, 5], 4: []}
    """
    if index is None:
        labels = np.array(series.index.values)
        # a new block starts wherever a label differs from its predecessor
        splits = np.where(labels[1:] != labels[:-1])[0] + 1
        index = series.index.values[np.concatenate(([0], splits))]
    elif len(series) == 0:
        splits = [0] * (len(index) - 1)
    else:
        # one forward scan over the labels, like an unordered searchsorted
        labels = list(series.index)
        total = len(labels)
        cursor = 0
        starts = []
        for key in index:
            starts.append(cursor)
            while cursor < total and key == labels[cursor]:
                cursor += 1  # move past the current chunk
        assert cursor == total, f"mismatch between series and index. {series.index} vs {index}"
        # the first start is always 0 and is not a split point
        splits = starts[1:]

    if return_type == 'splits':
        return splits
    chunks = np.split(series.values, splits)
    return pd.Series([chunk.tolist() for chunk in chunks], index=index)
def indices2csr(indices, shape1, data=None):
    """Build a CSR matrix from per-row lists of column indices.

    indices : sequence with one entry per row, each a sequence of column indices.
    shape1 : number of columns of the result.
    data : optional sequence of arrays whose concatenation supplies the stored
        values, in the same order as the concatenated indices; defaults to
        all ones.

    Returns a ``scipy.sparse.csr_matrix`` of shape ``(len(indices), shape1)``.
    Zero rows and rows without any entries are supported.
    """
    indptr = np.cumsum([0] + [len(x) for x in indices])
    nnz = int(indptr[-1])

    if nnz:
        # empty rows make hstack promote to float; column indices must be integers
        flat_indices = np.hstack(indices).astype(np.intp, copy=False)
    else:
        # np.hstack refuses an empty sequence, so handle "no entries" explicitly
        flat_indices = np.zeros(0, dtype=np.intp)

    if data is None:
        flat_data = np.ones(nnz)
    elif len(data):
        flat_data = np.hstack(data)
    else:
        flat_data = np.zeros(0)

    return sps.csr_matrix(
        (flat_data, flat_indices, indptr), shape=(len(indices), shape1))
def fill_factory_inplace(df, isna, kv):
    """Overwrite flagged entries of ``df`` in place with freshly created values.

    df : DataFrame (column keys) or Series (use the key ``None``).
    isna : iterable of booleans, one per row; True marks entries to replace.
    kv : mapping from column name (or ``None`` for a series) to a zero-argument
        factory. The factory is called once per flagged entry.

    A key that is not a column of ``df`` only triggers a warning.
    """
    def _patched(current, factory):
        return [factory() if missing else old
                for missing, old in zip(isna, current)]

    for key, factory in kv.items():
        if key is None:  # series
            df[:] = _patched(df.values, factory)
            continue
        if key not in df.columns:
            warnings.warn(f"fill_factory_inplace missing {key}")
            continue
        df[key] = _patched(df[key], factory)
def sample_groupA(user_df, frac=0.5, seed=888):
    """Return a boolean mask over ``user_df``'s rows marking a random sample.

    A fraction ``frac`` of the rows is drawn with ``random_state=seed``; the
    mask is True for every row whose index label belongs to the drawn sample.
    """
    chosen = user_df.sample(frac=frac, random_state=seed)
    return user_df.index.isin(chosen.index)
def filter_min_len(event_df, min_user_len, min_item_len):
    """ Keep the events whose user has at least `min_user_len` events and whose
    item has at least `min_item_len` events; both counts are taken on the
    unfiltered input.
    CAVEAT: use in conjunction with dataclass filter to avoid future-leaking bias """
    user_counts = event_df.groupby('USER_ID').size()
    item_counts = event_df.groupby('ITEM_ID').size()

    frequent_users = user_counts[user_counts >= min_user_len].index
    frequent_items = item_counts[item_counts >= min_item_len].index

    keep = (event_df['USER_ID'].isin(frequent_users)
            & event_df['ITEM_ID'].isin(frequent_items))
    return event_df[keep]
def explode_user_titles(user_hist, item_titles, gamma=0.5, min_gamma=0.1, pad_title='???'):
    """ explode last few user events and match with item titles;
    return splits and discount weights; empty user_hist will be turned into a single pad_title.

    Returns a tuple (titles, splits, weights): the flat array of titles, the
    positions at which a new user's titles start (suitable for np.split), and
    per-title weights that sum to one within each user. """
    # number of trailing events kept per user: int(log(min_gamma) / log(gamma)) + 1,
    # with gamma clipped away from 0 and 1 so that the logarithm stays finite and non-zero
    keep_last = int(np.log(min_gamma) / np.log(np.clip(gamma, 1e-10, 1 - 1e-10))) + 1  # default=4
    # the intermediate series gets a fresh positional index (user_hist's own index
    # is not passed on); explode repeats that position for every kept event, and
    # the join looks up each ITEM_ID in item_titles
    explode_titles = pd.Series([x[-keep_last:] for x in user_hist.values]).explode() \
        .to_frame('ITEM_ID').join(item_titles.to_frame('TITLE'), on='ITEM_ID')['TITLE']
    # missing titles (empty histories or unmatched items) are replaced by pad_title
    explode_titles = pd.Series(
        [x if not na else pad_title for x, na in
         zip(explode_titles.tolist(), explode_titles.isna().tolist())],
        index=explode_titles.index)
    # a new user starts wherever the repeated positional index changes
    splits = np.where(
        np.array(explode_titles.index.values[1:]) != np.array(explode_titles.index.values[:-1])
    )[0] + 1
    # NOTE(review): for 0 < gamma < 1 the non-positive exponents below give the
    # largest weight to the first kept entry of each user -- confirm that this
    # matches the intended ordering of user_hist
    weights = np.hstack([gamma ** (np.cumsum(x) - np.sum(x))  # -2, -1, 0
                         for x in np.split(np.ones(len(explode_titles)), splits)])
    # normalize so that each user's weights sum to one
    weights = np.hstack([x / x.sum() for x in np.split(weights, splits)])
    return explode_titles.values, splits, weights
def _get_loss_value(loss):
if isinstance(loss, dict):
return loss['loss']
else:
return loss
class _LitValidated(LightningModule):
    """LightningModule base that shares one step function between training and
    validation and keeps track of the best checkpoint by ``val_loss``.

    Subclasses either define ``training_and_validation_step(batch, batch_idx)``,
    returning a loss value or a dict with a ``'loss'`` entry, or override
    ``training_step`` themselves.
    """

    def training_step(self, batch, batch_idx):
        """Run the shared step and log its loss under ``'loss'``."""
        if not hasattr(self, 'training_and_validation_step'):
            raise NotImplementedError("alternatively, define training_step in subclass")
        loss = self.training_and_validation_step(batch, batch_idx)
        self.log('loss', _get_loss_value(loss), on_step=True, on_epoch=True, prog_bar=True)
        return loss  # value or dictionary

    def validation_step(self, batch, batch_idx):
        """Run the shared step (or ``training_step`` as a fallback), log the
        epoch-level ``'val_loss'`` and return the plain loss value."""
        step_fn = getattr(self, 'training_and_validation_step', self.training_step)
        loss = step_fn(batch, batch_idx)
        self.log('val_loss', _get_loss_value(loss), on_step=False, on_epoch=True, prog_bar=True)
        return _get_loss_value(loss)

    @cached_property
    def _checkpoint(self):
        """Checkpoint callback monitoring ``val_loss``; created once per instance."""
        return ModelCheckpoint(monitor="val_loss", save_weights_only=True)

    def _load_best_checkpoint(self, msg="loading"):
        """Load the weights of the best checkpoint seen so far into this module.

        Warns instead of failing when no checkpoint has been scored yet.
        """
        best_model_path = self._checkpoint.best_model_path
        best_model_score = self._checkpoint.best_model_score
        if best_model_score is None:
            warnings.warn("No checkpoints found!")
            return
        print(f"{msg} checkpoint {best_model_path} with score {best_model_score}")
        # map to CPU first: a checkpoint written from a CUDA run then loads on a
        # CPU-only machine as well, and no second copy of the weights is placed
        # on the GPU; load_state_dict copies into the existing parameters
        checkpoint = torch.load(best_model_path, map_location="cpu")
        self.load_state_dict(checkpoint['state_dict'])
class _ReduceLRLoadCkpt(torch.optim.lr_scheduler.ReduceLROnPlateau):
    """ReduceLROnPlateau that also reloads the model's best checkpoint whenever
    the learning rate is reduced.

    Takes the same arguments as ``ReduceLROnPlateau`` plus the keyword-only
    ``model``, an object providing ``_load_best_checkpoint()``.
    """
    def __init__(self, *args, model, **kw):
        super().__init__(*args, **kw)
        self.model = model

    def _reduce_lr(self, epoch):
        # extends the parent's private hook: lower the learning rate first,
        # then reload the best checkpoint
        super()._reduce_lr(epoch)
        self.model._load_best_checkpoint()
def default_random_split(dataset):
    """Split ``dataset`` randomly into a train part (4/5, rounded down) and a
    validation part (the remainder).

    Datasets with fewer than 5 items are not split: a warning is issued and the
    same dataset is returned for both parts.
    """
    total = len(dataset)
    if total < 5:
        warnings.warn(f"short dataset len={total}; "
                      "setting valid_set identical to train_set")
        return dataset, dataset

    n_train = total * 4 // 5
    return random_split(dataset, [n_train, total - n_train])
@dataclasses.dataclass
class MissingModel:
    """Placeholder recording that a model could not be provided and why.

    Creating an instance emits a warning naming the model and the error,
    unless ``verbose`` is False.
    """
    name: str  # label of the missing model, used in the warning text
    err: Exception  # the error reported as the cause
    verbose: bool = True  # set to False to suppress the warning

    def __post_init__(self):
        if not self.verbose:
            return
        warnings.warn(f"Model {self.name} is missing due to {self.err}")
def export_jsondump(writer):
    """Dump the scalars of every TensorBoard event file under the writer's log
    directory into JSON files.

    The log directory is taken from ``writer.log_dir`` or, failing that, from
    ``writer._get_file_writer().get_logdir()``; if neither exists a warning is
    issued and nothing is written.

    Every file found by walking the directory is read with an EventAccumulator.
    For each file that contains scalars and no tag starting with ``hparam/``,
    a ``{tag: [steps, values]}`` mapping is written next to it: the first file
    visited goes to ``log_events.json``, later ones to
    ``<name of the containing folder>.json``.
    """
    # minor modifications from https://discuss.pytorch.org/t/tensorboard-json-dump-of-all-scalars/69334
    if hasattr(writer, "log_dir"):
        log_dir = writer.log_dir
    elif hasattr(writer, "_get_file_writer"):
        log_dir = writer._get_file_writer().get_logdir()
    else:
        warnings.warn("export jsondump failure")
        return

    # every file below log_dir, recursively
    tf_files = [os.path.join(root, file)
                for root, _dirs, files in os.walk(log_dir)
                for file in files]

    for file_id, file in enumerate(tf_files):
        # os.path instead of splitting on '/' keeps this working with any separator
        path = os.path.dirname(file)  # folder in which the file lies
        # separate the file created by add_scalar from those created by add_scalars
        name = os.path.basename(path) if file_id > 0 else 'log_events'

        event_acc = event_accumulator.EventAccumulator(file)
        event_acc.Reload()

        data = {}
        hparam_file = False  # hparam files use tags of the form 'hparam/xyz_metric'
        for tag in sorted(event_acc.Tags()["scalars"]):
            if tag.split('/')[0] == 'hparam':
                hparam_file = True
            events = list(event_acc.Scalars(tag))
            data[tag] = ([e.step for e in events], [e.value for e in events])

        # only dump non-hparam files that actually contain scalars
        if not hparam_file and data:
            with open(os.path.join(path, f'{name}.json'), "w") as f:
                json.dump(data, f)