# caif/generator.py
from typing import Optional
import torch
import transformers
import streamlit as st
from plotly import graph_objects as go
from utils import get_lm


class Generator:
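    """Autoregressive text generator that mixes ordinary language-model sampling
    with CAIF (classifier-guided) sampling steps, switching between the two
    either on a fixed period or by next-token entropy."""
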
    def __init__(self, lm_model_name: str, device, entropy: Optional[float] = None):
self.device = device
self.tokenizer = transformers.AutoTokenizer.from_pretrained(
lm_model_name
)
self.lm = get_lm(lm_model_name).to(device)
self.lm.eval()
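        # The backing LM may ship without a pad token, so reuse EOS for padding.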
self.lm.config.pad_token_id = self.lm.config.eos_token_id
self.tokenizer.add_special_tokens(
{"pad_token": self.tokenizer.decode(self.lm.config.eos_token_id)}
)
self.caif_sampler = None
self.ordinary_sampler = None
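        # Running statistics for entropy-based switching (see generation_step).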
self.entropy_based_stats = {
"skips": 0,
"avg_entropy": 0,
"count": 0,
}
self.entropy = entropy

    def set_caif_sampler(self, sampler):
        self.caif_sampler = sampler

    def set_ordinary_sampler(self, sampler):
        self.ordinary_sampler = sampler

    def sample_sequences(
        self,
        num_samples: int,
        input_prompt: Optional[str],
        max_length: int,
        caif_period: int,
        caif_tokens_num: Optional[int] = None,
        entropy: Optional[float] = None,
        progress_bar=None,
        **sampler_kwargs
    ):
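        """Sample `num_samples` continuations of `input_prompt`, generating up
        to `max_length` new tokens, taking a CAIF step every `caif_period`
        steps (or whenever next-token entropy exceeds `entropy`), and stream
        the text and the attribute-probability plot to Streamlit."""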
self.entropy = entropy
input_ids, past, ended_sequences = self.get_input_ids(
input_prompt,
num_samples,
)
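        # Streamlit placeholders, updated in place on every generation step.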
text = st.empty()
plot = st.empty()
gen_history = []
layout = go.Layout({
"xaxis": {
"title": "# Tokens"
},
"yaxis": {
"title": "Desired Attribute"
},
"plot_bgcolor": '#FFFFFF',
"template": "plotly_white",
"hovermode": "x",
})
inp_len = len(input_ids[0])
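        # The plot starts with the attribute probability of the bare prompt.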
if self.caif_sampler is not None:
current_decoded = self.tokenizer.decode(input_ids[0])
probs = torch.exp(
self.caif_sampler.get_classifier_log_probs(
current_decoded, target_cls_id=sampler_kwargs["target_cls_id"]
)
).item()
gen_history += [probs]
for i in range(max_length):
is_caif_step = (
i % caif_period == 0 and self.caif_sampler is not None
)
input_ids, past, ended_sequences = self.generation_step(
input_ids,
past,
ended_sequences,
is_caif_step,
caif_tokens_num=caif_tokens_num,
**sampler_kwargs
)
            if progress_bar is not None:
                progress_bar.progress((i + 1) / max_length)
if ended_sequences.all():
break
current_decoded = self.tokenizer.decode(input_ids[0])
if self.caif_sampler is not None:
probs = torch.exp(
self.caif_sampler.get_classifier_log_probs(
current_decoded, target_cls_id=sampler_kwargs["target_cls_id"]
)
).item()
gen_history += [probs]
scatter_data = go.Scatter({
"x": list(range(len(gen_history))),
"y": gen_history,
"hovertext": ["[PROMPT]"] + [self.tokenizer.decode(t) for t in input_ids[0][inp_len:]]
})
fig = go.Figure([scatter_data], layout=layout)
plot.plotly_chart(fig, use_container_width=True)
if i == 0:
with st.expander("What is it?"):
st.write("You can see how the probability of the desired attribute varies for every generation step.")
text.text(current_decoded)
return (
[
self.tokenizer.decode(sequence, skip_special_tokens=True)
for sequence in input_ids
],
input_ids,
)

    def generation_step(
self,
input_ids,
past,
ended_sequences,
is_caif_step: bool,
caif_tokens_num=None,
**sampler_kwargs
):
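        """Advance every sequence in the batch by one token and return the
        updated (input_ids, past, ended_sequences) triple."""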
prepared_inputs = self.lm.prepare_inputs_for_generation(
input_ids, past, use_cache=True
)
outputs = self.lm(
**prepared_inputs,
output_attentions=False,
output_hidden_states=False,
return_dict=True
)
past = outputs.past_key_values
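        # Entropy-gated switching: with a threshold set, sequences where the
        # model is uncertain (high next-token entropy) take a CAIF step and
        # the rest take an ordinary step; without one, the caller's fixed
        # caif_period schedule (is_caif_step) decides for the whole batch.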
        if self.entropy is not None:
            # Shannon entropy H = -sum(p * log p) of the next-token
            # distribution at the last position of every sequence.
            log_probs = torch.nn.functional.log_softmax(
                outputs.logits, dim=-1
            )
            output_probs = torch.exp(log_probs)
            output_entropy = (-output_probs * log_probs).sum(-1)[:, -1]
            batch_size = output_entropy.shape[0]
            caif_mask = output_entropy >= self.entropy
ordinary_mask = ~caif_mask
self.entropy_based_stats["skips"] += caif_mask.sum() / batch_size
self.entropy_based_stats["count"] += 1
self.entropy_based_stats["avg_entropy"] += (
output_entropy.sum() / batch_size
)
flatten_entropy = output_entropy.view(-1).cpu().tolist()
if "entropy" not in self.entropy_based_stats.keys():
self.entropy_based_stats["entropy"] = flatten_entropy
else:
self.entropy_based_stats["entropy"] += flatten_entropy
            if caif_mask.all() or not caif_mask.any():
                # Uniform batch: every sequence takes the same kind of step.
                next_tokens_sampler = (
                    self.caif_sampler if caif_mask.all() else self.ordinary_sampler
                )
                next_tokens = next_tokens_sampler(
                    input_ids,
                    outputs.logits,
                    caif_tokens_num=caif_tokens_num,
                    **sampler_kwargs
                )
                # Already-finished sequences keep emitting EOS.
                next_tokens = (
                    next_tokens * (1 - ended_sequences.long())
                    + self.lm.config.eos_token_id * ended_sequences.long()
                ).long()
else:
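                # Mixed batch: run each subset through its own sampler, force
                # EOS on finished sequences, then scatter the results back.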
next_tokens_caif = self.caif_sampler(
input_ids[caif_mask],
outputs.logits[caif_mask],
caif_tokens_num=caif_tokens_num,
**sampler_kwargs
)
next_tokens_ordinary = self.ordinary_sampler(
input_ids[ordinary_mask],
outputs.logits[ordinary_mask],
caif_tokens_num=caif_tokens_num,
**sampler_kwargs
)
next_tokens_caif = (
next_tokens_caif * (1 - ended_sequences[caif_mask].long())
+ self.lm.config.eos_token_id
* ended_sequences[caif_mask].long()
).long()
next_tokens_ordinary = (
next_tokens_ordinary
* (1 - ended_sequences[ordinary_mask].long())
+ self.lm.config.eos_token_id
* ended_sequences[ordinary_mask].long()
).long()
                next_tokens = torch.empty(
                    batch_size, dtype=torch.long, device=self.device
                )
next_tokens[caif_mask] = next_tokens_caif
next_tokens[ordinary_mask] = next_tokens_ordinary
else:
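            # No entropy threshold: alternate samplers on the fixed period.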
if is_caif_step:
next_tokens_sampler = self.caif_sampler
else:
next_tokens_sampler = self.ordinary_sampler
next_tokens = next_tokens_sampler(
input_ids,
outputs.logits,
caif_tokens_num=caif_tokens_num,
**sampler_kwargs
)
next_tokens = (
next_tokens * (1 - ended_sequences.long())
+ self.lm.config.eos_token_id * ended_sequences.long()
).long()
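        # Append the sampled tokens and mark sequences that just emitted EOS.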
input_ids = torch.cat(
[input_ids, next_tokens[:, None].to(self.device)], dim=-1
)
        ended_sequences |= (
            next_tokens == self.lm.config.eos_token_id
        ).to(self.device)
return input_ids, past, ended_sequences

    def get_input_ids(self, input_prompt, num_samples):
        if input_prompt is not None:
            input_ids = self.tokenizer(
                input_prompt, return_tensors="pt"
            ).input_ids
        else:
            # No prompt: start generation from a single BOS token.
            input_ids = torch.tensor([[self.lm.config.bos_token_id]])
        input_ids = input_ids.repeat(num_samples, 1).to(self.device)
past = None
ended_sequences = torch.zeros(
input_ids.shape[0], device=self.device
).bool()
return input_ids, past, ended_sequences

    @staticmethod
def sample(unscaled_probs, values):
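        # Draw one index per row in proportion to `unscaled_probs`, then pick
        # the matching entries from `values`.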
samples = torch.multinomial(unscaled_probs, 1)
return torch.take_along_dim(values, samples, dim=1)
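

# A minimal usage sketch, assuming a GPT-2-style checkpoint name; `caif_sampler`
# and `ordinary_sampler` are hypothetical stand-ins for the sampler objects
# defined elsewhere in this repo, callable as
# sampler(input_ids, logits, caif_tokens_num=..., **kwargs):
#
#     generator = Generator(lm_model_name="gpt2", device="cpu")
#     generator.set_caif_sampler(caif_sampler)
#     generator.set_ordinary_sampler(ordinary_sampler)
#     texts, ids = generator.sample_sequences(
#         num_samples=2,
#         input_prompt="The movie was",
#         max_length=20,
#         caif_period=5,
#         target_cls_id=0,
#     )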