Source code for rime.models.zero_shot.item_knn

from transformers import AutoModelForCausalLM, AutoTokenizer, BertForMaskedLM
import torch, pandas as pd, numpy as np, functools, json, warnings
from rime.util import (matrix_reindex, LazyDenseMatrix, _to_cuda, empty_cache_on_exit,
                       explode_user_titles)
from tqdm import tqdm


[docs]class ItemKNN: """ softplus(x'y / temperature + log_p_y) for numerical stability """
[docs] def __init__(self, item_df, batch_size=100, item_pop_power=1, item_pop_pseudo=0.01, model_name='bert-base-uncased', # gpt2 pooling=None, # cls or mean temperature=None, gamma=0.5): assert "TITLE" in item_df or "embedding" in item_df, "require TITLE or embedding" self.item_index = item_df.index self.item_biases = item_pop_power * np.log(item_df['_hist_len'].values + item_pop_pseudo) self.batch_size = batch_size if temperature is None: temperature = { 'bert-base-uncased': 10, 'gpt2': 100, }[model_name] if pooling is None: pooling = 'cls' if 'bert' in model_name else 'mean' self.pooling = pooling self.temperature = temperature self.gamma = gamma if "embedding" in item_df: self.item_embeddings = np.vstack(item_df["embedding"].values) else: self.tokenizer = AutoTokenizer.from_pretrained(model_name) assert self.tokenizer.padding_side == 'right', "expect right padding" if model_name == 'gpt2': self.tokenizer.pad_token = self.tokenizer.eos_token if model_name.startswith('bert'): self.model = BertForMaskedLM.from_pretrained(model_name) else: # gpt2 self.model = AutoModelForCausalLM.from_pretrained(model_name) self.model.eval() # eval mode self.item_embeddings = self._compute_embeddings(item_df["TITLE"].values)
@empty_cache_on_exit @torch.no_grad() def _compute_embeddings(self, titles): """ find embedding of a batch of sequences """ with _to_cuda(self.model) as model: embeddings = [] for batch in tqdm(np.split(titles, range(0, len(titles), self.batch_size)[1:])): inputs = self.tokenizer(batch.tolist(), padding=True, return_tensors='pt') if hasattr(self.model, 'bert'): for key in inputs.keys(): # 'input_ids', 'attention_mask', 'token_type_ids' inputs[key] = inputs[key][:, :512] offset = 1 # [cls] seq [sep] hidden_states = self.model.bert(**inputs.to(model.device))[0] else: offset = 0 hidden_states = self.model.transformer(**inputs.to(model.device))[0] if self.pooling == 'mean': segments = [slice(offset, n - offset) for n in inputs['attention_mask'].sum(1).tolist()] elif self.pooling == 'cls': segments = [slice(0, 1) for _ in inputs['attention_mask']] else: raise NotImplementedError hidden_states = torch.vstack([ # mean-pooling on causal lm states x[slc].mean(0, keepdims=True) for x, slc in zip(hidden_states, segments)]) embeddings.append(hidden_states.double().cpu().numpy()) return np.vstack(embeddings) @empty_cache_on_exit def transform(self, D): explode_embeddings, splits, weights = explode_user_titles( D.user_in_test['_hist_items'], pd.Series(self.item_embeddings.tolist(), self.item_index), self.gamma, pad_title=np.zeros_like(self.item_embeddings[0]).tolist()) user_embeddings = np.vstack([w @ x for w, x in zip( np.split(weights, splits), np.split(np.vstack(explode_embeddings), splits) )]) / self.temperature # not very useful due to linearity in softplus item_reindex = lambda x, fill_value=0: matrix_reindex( x, self.item_index, D.item_in_test.index, axis=0, fill_value=fill_value) return (LazyDenseMatrix(user_embeddings) @ item_reindex(self.item_embeddings).T + item_reindex(self.item_biases, fill_value=-np.inf)).softplus()