Source code for rime.metrics.dual

import pandas as pd, numpy as np
import torch
from torch.utils.data import DataLoader
from pytorch_lightning import LightningModule, Trainer, loggers
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from ..util import empty_cache_on_exit, get_batch_size, score_op, auto_cast_lazy_score
from ..util.dual_bisect import dual_solve_u


class Dual:
    """Driver that trains a ``_LitDual`` model on a score matrix.

    The constructor records the min/max of the score matrix, builds the
    ``_LitDual`` module and a pytorch-lightning ``Trainer`` that logs to
    TensorBoard under ``logs/``.  ``fit`` trains the module and ``transform``
    turns scores into sigmoid-squashed values using the learned dual variables.
    """

    def __init__(self, score_mat, alpha_lb=-1, alpha_ub=2, beta_lb=-1, beta_ub=2,
                 device='cpu', max_epochs=100, min_epsilon=1e-10,
                 gpus=int(torch.cuda.is_available()), prefix='Dual'):
        """Build the dual model and its trainer.

        Args:
            score_mat: score matrix; passed through ``auto_cast_lazy_score``.
                Its ``shape`` is unpacked into the model's leading arguments
                (presumably ``(n_users, n_items)`` -- confirm with callers).
            alpha_lb, alpha_ub: bounds forwarded to ``_LitDual`` for the
                per-row solve.
            beta_lb, beta_ub: bounds forwarded to ``_LitDual`` for the
                per-column solve.
            device: device handed to ``score_op`` for the min/max reduction.
            max_epochs: epoch budget for both the trainer and the model's
                epsilon schedule.
            min_epsilon: forwarded to ``_LitDual``.
            gpus: forwarded to ``Trainer(gpus=...)``; defaults to 1 when CUDA
                is available at import time, else 0.
            prefix: label used in the console message and the logger run name.
        """
        score_mat = auto_cast_lazy_score(score_mat)
        self.score_max = float(score_op(score_mat, "max", device))
        self.score_min = float(score_op(score_mat, "min", device))
        print(f"entering {prefix} Dual score (min={self.score_min}, max={self.score_max})")

        shape = score_mat.shape
        gtol = 0.1 / max(shape)
        self.model = _LitDual(
            *shape, alpha_lb, alpha_ub, beta_lb, beta_ub,
            gtol, max_epochs, min_epsilon,
        )

        # Run name: the prefix followed by the four mean bounds as percentages.
        bound_tags = [
            f"{np.mean(bound):.1%}"
            for bound in (alpha_lb, alpha_ub, beta_lb, beta_ub)
        ]
        run_name = "+".join([f"{prefix}", *bound_tags])
        tb_logger = loggers.TensorBoardLogger("logs/", name=run_name)

        self.trainer = Trainer(
            max_epochs=max_epochs,
            gpus=gpus,
            logger=tb_logger,
            log_every_n_steps=1,
            # change default save path from . to logger path
            callbacks=[ModelCheckpoint()],
        )
        print("trainer log at:", self.trainer.logger.log_dir)

    @empty_cache_on_exit
    @torch.no_grad()
    def transform(self, score_mat):
        """Return ``sigmoid((score - u - v) / epsilon)`` for ``score_mat``.

        Scores are first divided by the ``score_max`` recorded at construction.
        ``u`` is obtained by running ``trainer.predict`` over the scaled scores
        and concatenating the per-batch outputs; ``v`` and ``epsilon`` are read
        from the trained model.
        """
        scaled = auto_cast_lazy_score(score_mat) / self.score_max
        loader = DataLoader(
            scaled,
            batch_size=self.model.batch_size,
            collate_fn=scaled[0].collate_fn,
        )
        u_batches = self.trainer.predict(self.model, loader)

        row_dual = np.hstack(u_batches)[:, None]
        col_dual = self.model.v[None, :]
        logits = (scaled - row_dual - col_dual) / self.model.epsilon
        return logits.sigmoid()

    @empty_cache_on_exit
    def fit(self, score_mat):
        """Train the dual model on ``score_mat`` and return ``self``.

        Scores are divided by ``score_max`` and fed to the trainer through a
        shuffled ``DataLoader``.  Summary statistics of the learned ``v`` are
        printed after training.
        """
        scaled = auto_cast_lazy_score(score_mat) / self.score_max
        loader = DataLoader(
            scaled,
            batch_size=self.model.batch_size,
            shuffle=True,
            collate_fn=scaled[0].collate_fn,
        )
        self.trainer.fit(self.model, loader)

        learned_v = self.model.v.detach().cpu().numpy()
        print('v', pd.Series(learned_v.ravel()).describe().to_dict())
        return self
class _LitDual(LightningModule):
    """Lightning module holding the learnable vector ``v``.

    Each training step solves ``u`` for the batch given the current ``v``,
    then solves a target ``v`` given that ``u`` (both via ``dual_solve_u``),
    and regresses the parameter ``v`` toward the target with a squared loss.
    ``epsilon`` is multiplied by a constant factor at the start of every epoch.
    """

    def __init__(self, n_users, n_items, alpha_lb, alpha_ub, beta_lb, beta_ub,
                 gtol, max_epochs=100, min_epsilon=1e-10, v=None, epsilon=1):
        """Store the bounds, derive batch size / learning rate, initialise ``v``.

        Args:
            n_users, n_items: dimensions passed to ``get_batch_size`` and used
                for the learning rate; ``v`` has ``n_items`` entries when it is
                initialised here.
            alpha_lb, alpha_ub: bounds passed to ``dual_solve_u`` in ``_solve_u``.
            beta_lb, beta_ub: bounds passed to ``dual_solve_u`` in ``_solve_v``;
                also select the initialisation range of ``v``.
            gtol: tolerance forwarded to ``dual_solve_u``.
            max_epochs: number of epochs over which ``epsilon`` decays from
                its initial value to ``min_epsilon``.
            min_epsilon: value ``epsilon`` reaches after ``max_epochs`` decays.
            v: optional initial tensor for the parameter; random when ``None``.
            epsilon: initial epsilon.
        """
        super().__init__()
        self.alpha_lb, self.alpha_ub = alpha_lb, alpha_ub
        self.beta_lb, self.beta_ub = beta_lb, beta_ub
        self.gtol = gtol

        self.batch_size = get_batch_size((n_users, n_items))
        # lr = n_items / n_batches, with n_batches = n_users / batch_size.
        self.lr = n_items / (n_users / self.batch_size)

        self.epsilon = epsilon
        # Per-epoch factor so that epsilon * gamma ** max_epochs == min_epsilon.
        self.epsilon_gamma = (min_epsilon / epsilon) ** (1 / max_epochs)

        if v is None:
            noise = torch.rand(n_items)
            if beta_lb <= 0:  # ub-only
                v = noise
            elif beta_ub >= 1:  # lb-only
                v = -noise
            else:  # range or eq
                v = noise * 2 - 1
        self.v = torch.nn.Parameter(v)

    def configure_optimizers(self):
        """Use plain SGD over the module parameters with the derived rate."""
        return torch.optim.SGD(self.parameters(), lr=self.lr)

    def on_train_epoch_start(self):
        """Decay ``epsilon`` by ``epsilon_gamma`` and log the new value."""
        self.epsilon *= self.epsilon_gamma
        self.log("epsilon", self.epsilon, prog_bar=True)

    @torch.no_grad()
    def _solve_u(self, batch):
        """Solve ``u`` for ``batch`` given the current ``v``.

        Returns ``(u, lb_iters, ub_iters)``, where ``u`` is the non-positive
        part of the ``alpha_lb`` solution plus the non-negative part of the
        ``alpha_ub`` solution.
        """
        residual = batch - self.v[None, :]

        def solve(alpha):
            return dual_solve_u(residual, alpha, self.epsilon,
                                gtol=self.gtol, s_guess=-self.v.max())

        u_lower, lower_iters = solve(self.alpha_lb)
        u_upper, upper_iters = solve(self.alpha_ub)
        combined = u_lower.clip(None, 0) + u_upper.clip(0, None)
        return combined, lower_iters, upper_iters

    @torch.no_grad()
    def _solve_v(self, batch, u):
        """Solve ``v`` for the transposed ``batch`` given ``u``.

        Mirrors ``_solve_u`` with the ``beta`` bounds; returns
        ``(v, lb_iters, ub_iters)``.
        """
        residual = batch.T - u[None, :]

        def solve(beta):
            return dual_solve_u(residual, beta, self.epsilon,
                                gtol=self.gtol, s_guess=-u.max())

        v_lower, lower_iters = solve(self.beta_lb)
        v_upper, upper_iters = solve(self.beta_ub)
        combined = v_lower.clip(None, 0) + v_upper.clip(0, None)
        return combined, lower_iters, upper_iters

    def training_step(self, batch, batch_idx):
        """Return half the mean squared gap between ``self.v`` and the solved ``v``.

        Solver iteration counts are logged to the progress bar; the loss is
        logged as ``train_loss``.
        """
        batch = batch.as_tensor(self.device)

        u, u_neg_iters, u_pos_iters = self._solve_u(batch)
        for tag, iters in (("u_neg_iters", u_neg_iters),
                           ("u_pos_iters", u_pos_iters)):
            self.log(tag, float(iters), prog_bar=True)

        v_target, v_neg_iters, v_pos_iters = self._solve_v(batch, u)
        for tag, iters in (("v_neg_iters", v_neg_iters),
                           ("v_pos_iters", v_pos_iters)):
            self.log(tag, float(iters), prog_bar=True)

        gap = self.v - v_target
        loss = (gap ** 2).mean() / 2
        self.log("train_loss", loss)
        return loss

    @torch.no_grad()
    def forward(self, batch):
        """Return the solved ``u`` for ``batch`` as a NumPy array on the CPU."""
        batch = batch.as_tensor(self.device)
        u = self._solve_u(batch)[0]
        return u.cpu().numpy()