# Source code for rime.models.transformer
import torch, functools, numpy as np
from .third_party.word_language_model import TransformerModel
from .rnn import (RNN, Trainer, _LitRNNModel, _LitValidated, _collate_fn, LearningRateMonitor)
class _LitTransformerModel(_LitRNNModel, _LitValidated):
    """Lightning module wrapping the third-party ``TransformerModel``.

    Inherits the training/validation plumbing of ``_LitRNNModel`` and
    ``_LitValidated`` but replaces the network with a transformer.

    Args:
        ntoken: vocabulary size; also the width of the model output that the
            loss is computed over.
        *args, **kw: forwarded unchanged to ``TransformerModel``.
        truncated_bptt_steps: accepted and ignored. Presumably kept so this
            class accepts the same keywords as ``_LitRNNModel`` -- TODO confirm.
        lr: learning rate, stored on ``self.lr`` (presumably read by an
            inherited ``configure_optimizers`` -- verify in ``.rnn``).
    """

    def __init__(self, ntoken, *args, truncated_bptt_steps=None, lr=0.1 / 4, **kw):
        # Start the MRO lookup *after* _LitValidated, which skips
        # _LitRNNModel.__init__ (and _LitValidated.__init__, if defined).
        # NOTE(review): presumably done so the RNN network is not built --
        # confirm against .rnn.
        super(_LitValidated, self).__init__()
        self.model = TransformerModel(ntoken, *args, **kw)
        # Token id 0 is excluded from the loss (padding slot).
        self.loss = torch.nn.NLLLoss(ignore_index=0)
        self.ntoken = ntoken
        self.lr = lr

    def training_step(self, batch, batch_idx):
        """Compute and log the next-token NLL loss for one batch.

        The maximum sequence length is set upstream by the
        ``truncated_input_steps`` argument of ``Transformer`` (default 256).

        Args:
            batch: indexable pair ``(inputs, targets)`` of token-id tensors in
                (batch, time) layout -- presumably produced by ``_collate_fn``;
                TODO confirm.
            batch_idx: index of the batch within the epoch (unused).

        Returns:
            The scalar loss tensor, also logged as ``"train_loss"``.
        """
        x, y = batch[0].T, batch[1].T  # NT -> TN layout
        # NOTE(review): the second positional argument is presumably
        # ``has_mask`` (causal attention mask) -- confirm in TransformerModel.
        out = self.model(x, True)
        # ``.T`` produces a non-contiguous view, on which ``view(-1)`` raises.
        # ``reshape`` returns a view when possible and copies only when needed.
        loss = self.loss(out.reshape(-1, self.ntoken), y.reshape(-1))
        self.log("train_loss", loss)
        return loss
class Transformer(RNN):
    """Sequential recommender using a transformer in place of the RNN backbone.

    Args:
        item_df: frame whose index holds item ids; the first
            ``max_item_size`` entries form the token vocabulary.
        max_item_size: maximum number of items kept from ``item_df.index``.
        num_hidden: passed to ``TransformerModel`` in two positional slots
            (presumably embedding size and feed-forward size -- confirm).
        nlayers: number of transformer layers passed to ``TransformerModel``.
        max_epochs: forwarded to ``Trainer``.
        nhead: number of attention heads passed to ``TransformerModel``.
        lr: learning rate forwarded to ``_LitTransformerModel``.
        gpus: forwarded to ``Trainer``; defaults to 1 if CUDA is available,
            else 0.
        truncated_input_steps: stored on ``self._truncated_input_steps``
            (presumably the max input length used by inherited methods).
        batch_size: stored on ``self.batch_size``.
        load_from_checkpoint: optional path to a checkpoint whose
            ``'state_dict'`` entry is loaded into the model.
        tie_weights: forwarded to ``TransformerModel``.
        auto_pad_item: if true, prepend a ``None`` item so that token id 0 is
            a padding slot.
    """

    def __init__(
        self, item_df, max_item_size=int(30e3),
        num_hidden=128, nlayers=2, max_epochs=20, nhead=2, lr=0.1 / 4,
        gpus=int(torch.cuda.is_available()),
        truncated_input_steps=256, batch_size=64,
        load_from_checkpoint=None, tie_weights=True, auto_pad_item=True,
    ):
        # NOTE(review): RNN.__init__ is intentionally not called here;
        # presumably it would build an RNN model -- confirm against .rnn.

        # ``[None] * bool`` yields either [] or [None]; the None entry occupies
        # token id 0, which the model's NLLLoss ignores.
        self._padded_item_list = [None] * auto_pad_item + item_df.index[:max_item_size].tolist()
        self._tokenize = {k: i for i, k in enumerate(self._padded_item_list)}
        self._truncated_input_steps = truncated_input_steps

        # Positional args presumably map to (ninp, nhead, nhid, nlayers,
        # dropout=0) of TransformerModel -- confirm in the third-party module.
        self.model = _LitTransformerModel(
            len(self._padded_item_list),
            num_hidden, nhead, num_hidden, nlayers, 0, lr=lr, tie_weights=tie_weights)

        if load_from_checkpoint is not None:
            # The freshly built model is on CPU and load_state_dict copies
            # values into it, so mapping to CPU preserves the result while
            # allowing GPU-saved checkpoints to load on CPU-only machines.
            # NOTE: torch.load unpickles the file -- only load trusted paths.
            checkpoint = torch.load(load_from_checkpoint, map_location="cpu")
            self.model.load_state_dict(checkpoint['state_dict'])

        self.trainer = Trainer(
            max_epochs=max_epochs, gpus=gpus,
            callbacks=[self.model._checkpoint, LearningRateMonitor()])
        print("trainer log at:", self.trainer.logger.log_dir)
        self.batch_size = batch_size