"""rime.models.third_party.word_language_model: RNN and Transformer language models with hidden-state helpers."""

import numpy as np
from .model import RNNModel as _RNNModel, TransformerModel as _TransformerModel


class _ForwardLastPrediction:
    def forward_last_prediction(self, input_time_by_batch, lengths):
        all_hidden, *_ = self.forward_all_hidden(input_time_by_batch)
        hidden = all_hidden[lengths - 1, np.arange(len(lengths))]
        logsumexp = self.decoder(hidden).logsumexp(1)
        return hidden, -logsumexp


class RNNModel(_RNNModel, _ForwardLastPrediction):
    """Recurrent language model that also exposes its per-step outputs."""

    def forward_all_hidden(self, input, hidden=None):
        """Run embedding and recurrent layers, returning every step's output.

        Args:
            input: model input passed to ``self.encoder``. Presumably token
                indices laid out as (time, batch) -- TODO confirm; the code
                only relies on ``input.shape[1]`` when creating a state.
            hidden: initial recurrent state. When ``None``, a fresh one is
                built with ``self.init_hidden(input.shape[1])``.

        Returns:
            ``(output, hidden)``. ``output`` holds the ``self.rnn`` outputs
            for all steps, passed through ``self.drop``. ``hidden`` is the
            state returned by ``self.rnn``.
        """
        state = self.init_hidden(input.shape[1]) if hidden is None else hidden
        embedded = self.drop(self.encoder(input))
        rnn_out, state = self.rnn(embedded, state)
        return self.drop(rnn_out), state


class TransformerModel(_TransformerModel, _ForwardLastPrediction):
    """Transformer language model with optional weight tying and a forward
    pass that returns the encoder's per-step outputs."""

    def __init__(self, *args, tie_weights=False, **kwargs):
        """Build the base model and optionally share encoder/decoder weights.

        Args:
            *args, **kwargs: forwarded unchanged to the base model.
            tie_weights: when true, ``self.decoder.weight`` is replaced by
                ``self.encoder.weight``. Both weights must have equal shapes.

        Raises:
            ValueError: if ``tie_weights`` is set and the weight shapes differ.
        """
        super().__init__(*args, **kwargs)
        if tie_weights:
            sizes = (self.encoder.weight.shape, self.decoder.weight.shape)
            # Explicit check instead of ``assert``: asserts are stripped under
            # ``python -O``, which would let a mismatched tie through silently.
            if sizes[0] != sizes[1]:
                raise ValueError(f"tie weights size conflict {sizes}")
            self.decoder.weight = self.encoder.weight

    def forward_all_hidden(self, src, has_mask="true and ignored"):
        """Encode ``src`` and return the transformer outputs for all steps.

        A mask from ``self._generate_square_subsequent_mask(len(src))`` is
        always applied (presumably a causal mask -- confirm in the base
        model). ``has_mask`` is accepted only for signature compatibility and
        is ignored.

        Returns:
            A 1-tuple ``(output,)`` holding the ``self.transformer_encoder``
            output. It is a tuple so that callers can unpack it the same way
            as a ``(output, hidden)`` pair.
        """
        mask = self._generate_square_subsequent_mask(len(src)).to(src.device)
        # Embeddings are scaled by sqrt(ninp) before positional encoding.
        src = self.encoder(src) * np.sqrt(self.ninp)
        src = self.pos_encoder(src)
        output = self.transformer_encoder(src, mask)
        return (output,)