Source code for rime.models.rnn

import numpy as np
import functools, warnings

import torch
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pack_sequence, pad_packed_sequence

from .third_party.word_language_model import RNNModel

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LearningRateMonitor
from ..util import (_LitValidated, empty_cache_on_exit, _ReduceLRLoadCkpt,
                    default_random_split, LazyDenseMatrix, matrix_reindex, export_jsondump)


class RNN:
    """GRU-based sequential recommender trained through PyTorch Lightning.

    Items are mapped to integer tokens by their position in a (possibly padded)
    item list; each user's history becomes a token sequence that is fed to a
    ``_LitRNNModel``. ``fit`` trains the model and ``transform`` scores the
    test users against the test items.
    """

    def __init__(
        self, item_df, max_item_size=int(30e3), num_hidden=128, nlayers=2, max_epochs=20,
        gpus=int(torch.cuda.is_available()), truncated_input_steps=256,
        truncated_bptt_steps=32, batch_size=64, load_from_checkpoint=None,
        auto_pad_item=True,
    ):
        """Build the vocabulary, the model and the Lightning trainer.

        Args:
            item_df: frame whose index lists the item ids; only the first
                ``max_item_size`` entries are kept in the vocabulary.
            max_item_size: cap on the number of items taken from ``item_df``.
            num_hidden: passed twice to the model constructor (presumably the
                embedding size and the hidden size -- confirm in ``RNNModel``).
            nlayers: number of recurrent layers.
            max_epochs: forwarded to ``Trainer``.
            gpus: forwarded to ``Trainer``; defaults to 1 when CUDA is available.
            truncated_input_steps: when positive, only this many trailing
                tokens per user are kept by the collate function.
            truncated_bptt_steps: forwarded to ``_LitRNNModel``.
            batch_size: batch size of the train / validation loaders in ``fit``.
            load_from_checkpoint: optional path of a checkpoint whose
                ``'state_dict'`` entry initializes the model weights.
            auto_pad_item: when true, a ``None`` placeholder is inserted at
                vocabulary position 0, ahead of the real items.
        """
        # `[None] * auto_pad_item` is `[None]` for True and `[]` for False.
        self._padded_item_list = [None] * auto_pad_item + item_df.index[:max_item_size].tolist()
        self._tokenize = {k: i for i, k in enumerate(self._padded_item_list)}
        self._truncated_input_steps = truncated_input_steps

        self.model = _LitRNNModel(
            'GRU', len(self._padded_item_list), num_hidden, num_hidden, nlayers,
            dropout=0, tie_weights=True, truncated_bptt_steps=truncated_bptt_steps)

        if load_from_checkpoint is not None:
            # Deserialize onto the CPU so a checkpoint written on a GPU machine
            # can still be restored on a CPU-only host; load_state_dict copies
            # the values into the freshly built model's own parameters.
            checkpoint = torch.load(load_from_checkpoint, map_location="cpu")
            self.model.load_state_dict(checkpoint['state_dict'])

        self.trainer = Trainer(
            max_epochs=max_epochs, gpus=gpus,
            callbacks=[self.model._checkpoint, LearningRateMonitor()])
        print("trainer log at:", self.trainer.logger.log_dir)
        self.batch_size = batch_size

    def _extract_data(self, user_df):
        """Tokenize every user's ``_hist_items``.

        Items missing from the vocabulary are dropped and token ``0`` is
        prepended to each sequence. Returns the ``.values`` array of the
        resulting per-user token lists.
        """
        return user_df['_hist_items'].apply(
            lambda x: [0] + [self._tokenize[i] for i in x if i in self._tokenize]).values

    @empty_cache_on_exit
    @torch.no_grad()
    def transform(self, D):
        """Score the users in ``D.user_in_test`` against ``D.item_in_test``.

        The model is run over each user's (truncated) history to obtain a
        hidden vector and a log-bias per user; these are combined with the
        decoder weight / bias, reindexed from the training vocabulary to
        ``D.item_in_test.index``. The result is the element-wise ``exp`` of
        ``user_hidden @ item_hidden.T + user_log_bias + item_log_bias``,
        wrapped in a ``LazyDenseMatrix`` expression.
        """
        dataset = self._extract_data(D.user_in_test)
        collate_fn = functools.partial(
            _collate_fn, truncated_input_steps=self._truncated_input_steps, training=False)
        m, n_events, sample = _get_dataset_stats(dataset, collate_fn)
        print(f"transforming {m} users with {n_events} events, "
              f"truncated@{self._truncated_input_steps} per user")
        print(f"sample[0]={sample[0].tolist()}")
        print(f"sample[1]={sample[1].tolist()}")

        if hasattr(dataset, "tolist"):  # pytorch lightning bug cannot take array input
            dataset = dataset.tolist()
        batches = self.trainer.predict(
            self.model,
            dataloaders=DataLoader(dataset, 1000, collate_fn=collate_fn))
        # Each predicted batch is a (hidden, log_bias) pair; stitch them per field.
        user_hidden, user_log_bias = [np.concatenate(x) for x in zip(*batches)]

        item_hidden = self.model.model.decoder.weight.detach().cpu().numpy()
        item_log_bias = self.model.model.decoder.bias.detach().cpu().numpy()

        def item_reindex(x, fill_value=0):
            # Align rows from the training vocabulary to the test item index.
            return matrix_reindex(
                x, self._padded_item_list, D.item_in_test.index, axis=0, fill_value=fill_value)

        # The bias is reindexed with fill_value=-inf (presumably used for test
        # items absent from the vocabulary, which then score exp(-inf) = 0).
        return (LazyDenseMatrix(user_hidden) @ item_reindex(item_hidden).T
                + user_log_bias[:, None]
                + item_reindex(item_log_bias, -np.inf)[None, :]).exp()

    @empty_cache_on_exit
    def fit(self, D):
        """Train on the users of ``D.user_df`` that have a non-empty history.

        The tokenized data is split with ``default_random_split`` into train
        and validation sets, the trainer is run, the best checkpoint is loaded
        back into the model and the logger experiment is exported. Returns
        ``self``.
        """
        dataset = self._extract_data(D.user_df[D.user_df['_hist_len'] > 0])
        collate_fn = functools.partial(
            _collate_fn, truncated_input_steps=self._truncated_input_steps, training=True)
        m, n_events, sample = _get_dataset_stats(dataset, collate_fn)
        print(f"fitting {m} users with {n_events} events, "
              f"truncated@{self._truncated_input_steps} per user")
        print(f"sample[0]={sample[0].tolist()}")
        print(f"sample[1]={sample[1].tolist()}")

        train_set, valid_set = default_random_split(dataset)
        self.trainer.fit(
            self.model,
            DataLoader(train_set, self.batch_size, collate_fn=collate_fn, shuffle=True),
            DataLoader(valid_set, self.batch_size, collate_fn=collate_fn),)
        self.model._load_best_checkpoint("best")
        export_jsondump(self.trainer.logger.experiment)

        for name, param in self.model.named_parameters():
            print(name, param.data.shape)
        return self
def _collate_fn(batch, truncated_input_steps, training):
    """Turn a list of token sequences into padded tensors.

    Args:
        batch: list of per-user token-id sequences.
        truncated_input_steps: when positive, keep only this many trailing
            tokens per sequence (one extra in training mode, since inputs and
            targets are the same sequence shifted by one step). Non-positive
            values disable truncation.
        training: selects the output format.

    Returns:
        Training: ``(inputs, targets)``, both transposed to batch-first (NT)
        layout, where targets are the inputs shifted by one time step.
        Otherwise: ``(padded, lengths)`` with ``padded`` in time-first (TN)
        layout and ``lengths`` the unpadded sequence lengths.
    """
    if truncated_input_steps > 0:
        # `training` is a bool, so this keeps one extra token when training.
        batch = [seq[-(truncated_input_steps + training):] for seq in batch]
    batch = [torch.tensor(seq, dtype=torch.int64) for seq in batch]
    # Pack then unpack to get a zero-padded TN tensor plus per-sequence lengths.
    batch, lengths = pad_packed_sequence(pack_sequence(batch, enforce_sorted=False))
    if training:
        return (batch[:-1].transpose(0, 1), batch[1:].transpose(0, 1))  # TBPTT assumes NT layout
    else:
        return (batch, lengths)  # RNN default TN layout


def _get_dataset_stats(dataset, collate_fn):
    """Summarize a tokenized dataset for logging.

    Args:
        dataset: sized iterable of token sequences.
        collate_fn: a ``functools.partial`` of ``_collate_fn``; its
            ``truncated_input_steps`` keyword is read to cap the per-user
            event count.

    Returns:
        ``(number of sequences, total events after truncation, one randomly
        drawn collated sample)``.
    """
    truncated_input_steps = collate_fn.keywords['truncated_input_steps']
    if truncated_input_steps > 0:
        n_events = sum(min(truncated_input_steps, len(x)) for x in dataset)
    else:
        # _collate_fn treats non-positive values as "no truncation", so the
        # whole sequence counts (min() would otherwise yield <= 0 per user).
        n_events = sum(len(x) for x in dataset)
    sample = next(iter(DataLoader(dataset, 1, collate_fn=collate_fn, shuffle=True)))
    return len(dataset), n_events, sample


class _LitRNNModel(_LitValidated):
    """Lightning wrapper around ``RNNModel`` with truncated-BPTT training."""

    def __init__(self, *args, truncated_bptt_steps, lr=0.1, **kw):
        """Build the wrapped ``RNNModel`` from ``*args`` / ``**kw``.

        Args:
            truncated_bptt_steps: stored on the module for truncated
                back-propagation through time.
            lr: learning rate used in ``configure_optimizers``.
        """
        super().__init__()
        self.model = RNNModel(*args, **kw)
        # Targets equal to 0 (the padding value of the collate step) are ignored.
        self.loss = torch.nn.NLLLoss(ignore_index=0)
        self.truncated_bptt_steps = truncated_bptt_steps
        self.lr = lr

    def forward(self, batch):
        """ output user embedding at lengths-1 positions

        ``batch`` is the ``(TN_layout, lengths)`` pair produced by the
        non-training collate; returns ``(last_hidden, log_bias)`` as numpy
        arrays.
        """
        TN_layout, lengths = batch
        last_hidden, log_bias = self.model.forward_last_prediction(TN_layout, lengths)
        return last_hidden.cpu().numpy(), log_bias.cpu().numpy()

    def configure_optimizers(self):
        """Adagrad optimizer plus a ``_ReduceLRLoadCkpt`` scheduler on ``val_loss``."""
        optimizer = torch.optim.Adagrad(self.parameters(), eps=1e-3, lr=self.lr)
        lr_scheduler = _ReduceLRLoadCkpt(
            optimizer, model=self, factor=0.25, patience=4, verbose=True)
        return {"optimizer": optimizer, "lr_scheduler": {
            "scheduler": lr_scheduler, "monitor": "val_loss"
        }}

    def training_step(self, batch, batch_idx, hiddens=None):
        """ truncated_bptt_steps pass batch[:][:, slice] and hiddens

        Returns a dict with the step ``loss`` and the new ``hiddens`` to carry
        into the next slice.
        """
        x, y = batch[0].T, batch[1].T  # transpose to TN layout
        if hiddens is None:
            hiddens = self.model.init_hidden(x.shape[1])
        else:
            # Cut the graph so gradients do not flow into earlier slices.
            hiddens = hiddens.detach()
        out, hiddens = self.model(x, hiddens)
        # reshape (not view): y is a transposed tensor and is not guaranteed
        # to be contiguous; reshape is identical when it is and safe otherwise.
        loss = self.loss(out, y.reshape(-1))
        self.log("train_loss", loss)
        return {'loss': loss, 'hiddens': hiddens}