from transformers import AutoModelForCausalLM, AutoTokenizer, BertForMaskedLM
import torch, pandas as pd, numpy as np, warnings, functools
from rime.util import matrix_reindex, _to_cuda, empty_cache_on_exit, explode_user_titles
from tqdm import tqdm
[docs]class BayesLM:
""" Score candidate items with Bayes' rule on top of a pretrained language model.

p(y|x) propto lm(x|y) * p(y), where x is the last item in user history
and y is a candidate item.

The language model and tokenizer are loaded through the Hugging Face auto classes:
https://huggingface.co/docs/transformers/v4.15.0/en/model_doc/auto
"""
def __init__(self, item_df, max_num_candidates=None, batch_size=100,
             prompt="a user will watch {y} after watching {x}",
             item_pop_power=1, item_pop_pseudo=0.01, temperature=1, gamma=0,
             candidate_selection_method=None, model_name='gpt2',  # bert-base-uncased
             text_column_name='TITLE'):
    """Store the configuration, compute the item prior and load the language model.

    Args:
        item_df: item table.  Must contain ``text_column_name`` and is also
            indexed with ``'_hist_len'`` to build the prior.  It is copied, so
            the caller's frame is not modified.
        max_num_candidates: stored on the instance.  When ``None`` a warning is
            issued and 2 is used (a testing-only default).
        batch_size: stored on the instance.
        prompt: format string with ``{y}`` and ``{x}`` placeholders, used to
            build the text fed to the language model.
        item_pop_power: multiplier applied to the log-popularity prior; also
            decides the default ``candidate_selection_method``.
        item_pop_pseudo: pseudo-count added to ``_hist_len`` before the log.
        temperature: stored on the instance.
        gamma: stored on the instance.
        candidate_selection_method: when ``None``, defaults to ``'greedy'`` if
            ``item_pop_power > 0`` and to ``'sample'`` otherwise.
        model_name: Hugging Face model identifier.  Names starting with
            ``'bert'`` are loaded as ``BertForMaskedLM``; everything else is
            loaded as a causal language model.
        text_column_name: column of ``item_df`` holding the item text.

    Raises:
        AssertionError: if ``text_column_name`` is missing from ``item_df`` or
            the tokenizer does not pad on the right.
    """
    assert text_column_name in item_df, f"require {text_column_name} as data(y)"

    self.item_df = item_df.copy()
    # log p(y): item popularity smoothed by a pseudo-count, scaled by a power.
    self.item_df['log_p_y'] = item_pop_power * np.log(item_df['_hist_len'] + item_pop_pseudo)

    if max_num_candidates is None:
        warnings.warn("please set max_num_candidates, default=2 only for testing purposes")
        max_num_candidates = 2
    self.max_num_candidates = max_num_candidates

    self.batch_size = batch_size
    self.prompt = prompt
    self.temperature = temperature
    self.gamma = gamma
    self.text_column_name = text_column_name

    if candidate_selection_method is None:
        candidate_selection_method = 'greedy' if item_pop_power > 0 else 'sample'
    self.candidate_selection_method = candidate_selection_method

    # huggingface model initialization
    self.tokenizer = AutoTokenizer.from_pretrained(model_name)
    assert self.tokenizer.padding_side == 'right', "expect right padding"

    # Batched scoring tokenizes with padding=True, which needs a pad token.
    # 'gpt2' always gets EOS as before; any other checkpoint whose tokenizer
    # ships without a pad token (e.g. other GPT-2 variants) now gets the same
    # fallback instead of failing later at tokenization time.
    if model_name == 'gpt2' or self.tokenizer.pad_token is None:
        self.tokenizer.pad_token = self.tokenizer.eos_token

    if model_name.startswith('bert'):
        self.model = BertForMaskedLM.from_pretrained(model_name)
    else:  # gpt2
        self.model = AutoModelForCausalLM.from_pretrained(model_name)
    self.model.eval()  # eval mode

    # Per-token losses are kept so they can be averaged per candidate.
    self.loss = torch.nn.CrossEntropyLoss(reduction='none')
@torch.no_grad()
def _compute_log_p_x_given_y(self, Y, x, device):
""" evaluate p_x_given_y for each y in Y; label = x.expand(...) """
batch_size = len(Y)
sequences = [self.prompt.format(y=y, x=x) for y in Y]
inputs = self.tokenizer(sequences, padding=True, return_tensors='pt').to(device)
targets = self.tokenizer(x, return_tensors='pt')['input_ids'].to(device)
targets = torch.vstack([targets for _ in range(batch_size)])
seq_len = inputs['attention_mask'].sum(1).tolist()
target_len = targets.shape[1]
if hasattr(self.model, "transformer"): # gpt causal lm
hidden_states = self.model.transformer(**inputs)[0]
hidden_states = torch.vstack([x[n - target_len - 1: n - 1]
for x, n in zip(hidden_states, seq_len)])
logits = self.model.lm_head(hidden_states)
elif hasattr(self.model, "bert"): # bert [CLS] sequence [SEP], performs similarly
targets = targets[:, 1:-1]
target_len = target_len - 2
hidden_states = self.model.bert(**inputs)[0]
hidden_states = torch.vstack([x[n - target_len - 1: n - 1] # [3-1-1 : 3-1]
for x, n in zip(hidden_states, seq_len)])
logits = self.model.cls(hidden_states)
else: # decoding non-target items can lead to 20% longer compute time
logits = self.model(**inputs).logits
logits = torch.vstack([x[n - target_len - 1: n - 1]
for x, n in zip(logits, seq_len)])
loss = self.loss(logits, targets.reshape(-1))
return (-loss).reshape(targets.shape).mean(1).tolist()