from typing import Optional, Union

import torch
import transformers
import streamlit as st
from plotly import graph_objects as go

from utils import get_lm
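
# Streamlit text generator that wraps a causal LM and switches between a
# classifier-guided CAIF sampler and an ordinary sampler, either on a fixed
# period or adaptively, based on the entropy of the next-token distribution.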
class Generator:
    def __init__(self, lm_model_name, device, entropy=None):
        self.device = device
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
            lm_model_name
        )
        self.lm = get_lm(lm_model_name).to(device)
        self.lm.eval()
        self.lm.config.pad_token_id = self.lm.config.eos_token_id
        self.tokenizer.add_special_tokens(
            {"pad_token": self.tokenizer.decode(self.lm.config.eos_token_id)}
        )
        self.caif_sampler = None
        self.ordinary_sampler = None
        self.entropy_based_stats = {
            "skips": 0,
            "avg_entropy": 0,
            "count": 0,
        }
        self.entropy = entropy

    def set_caif_sampler(self, sampler):
        self.caif_sampler = sampler

    def set_ordinary_sampler(self, sampler):
        self.ordinary_sampler = sampler
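
    # Generate `num_samples` continuations of `input_prompt`. When a CAIF
    # sampler is set, the classifier probability of the target attribute is
    # re-plotted after every step; returns the decoded texts and raw token ids.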
    def sample_sequences(
        self,
        num_samples: int,
        input_prompt: Optional[str],
        max_length: int,
        caif_period: int,
        caif_tokens_num: Union[int, None] = None,
        entropy: Optional[float] = None,
        progress_bar=None,
        **sampler_kwargs
    ):
        self.entropy = entropy

        input_ids, past, ended_sequences = self.get_input_ids(
            input_prompt,
            num_samples,
        )
        text = st.empty()
        plot = st.empty()
        gen_history = []
        layout = go.Layout({
            "xaxis": {"title": "# Tokens"},
            "yaxis": {"title": "Desired Attribute"},
            "plot_bgcolor": "#FFFFFF",
            "template": "plotly_white",
            "hovermode": "x",
        })
        inp_len = len(input_ids[0])
        if self.caif_sampler is not None:
            current_decoded = self.tokenizer.decode(input_ids[0])
            probs = torch.exp(
                self.caif_sampler.get_classifier_log_probs(
                    current_decoded, target_cls_id=sampler_kwargs["target_cls_id"]
                )
            ).item()
            gen_history += [probs]
        for i in range(max_length):
            is_caif_step = (
                i % caif_period == 0 and self.caif_sampler is not None
            )
            input_ids, past, ended_sequences = self.generation_step(
                input_ids,
                past,
                ended_sequences,
                is_caif_step,
                caif_tokens_num=caif_tokens_num,
                **sampler_kwargs
            )
            if progress_bar is not None:
                progress_bar.progress((i + 1) / max_length)
            if ended_sequences.all():
                break
            current_decoded = self.tokenizer.decode(input_ids[0])
            if self.caif_sampler is not None:
                probs = torch.exp(
                    self.caif_sampler.get_classifier_log_probs(
                        current_decoded, target_cls_id=sampler_kwargs["target_cls_id"]
                    )
                ).item()
                gen_history += [probs]
                scatter_data = go.Scatter({
                    "x": list(range(len(gen_history))),
                    "y": gen_history,
                    "hovertext": ["[PROMPT]"]
                    + [self.tokenizer.decode(t) for t in input_ids[0][inp_len:]],
                })
                fig = go.Figure([scatter_data], layout=layout)
                plot.plotly_chart(fig, use_container_width=True)
                if i == 0:
                    with st.expander("What is it?"):
                        st.write(
                            "You can see how the probability of the desired "
                            "attribute changes at every generation step."
                        )
            text.text(current_decoded)

        return (
            [
                self.tokenizer.decode(sequence, skip_special_tokens=True)
                for sequence in input_ids
            ],
            input_ids,
        )
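
    # One decoding step for the whole batch: run the LM forward pass, pick a
    # sampler (entropy-adaptive or fixed-period CAIF), and append the sampled
    # token to every still-active sequence, padding finished ones with EOS.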
    def generation_step(
        self,
        input_ids,
        past,
        ended_sequences,
        is_caif_step: bool,
        caif_tokens_num=None,
        **sampler_kwargs
    ):
        prepared_inputs = self.lm.prepare_inputs_for_generation(
            input_ids, past, use_cache=True
        )
        outputs = self.lm(
            **prepared_inputs,
            output_attentions=False,
            output_hidden_states=False,
            return_dict=True
        )
        past = outputs.past_key_values
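        # Entropy-adaptive mode: measure the entropy of each sequence's
        # next-token distribution and route high-entropy sequences (>= the
        # threshold) to the CAIF sampler, the rest to the ordinary sampler.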
        if self.entropy is not None:
            normalized = torch.nn.functional.log_softmax(
                outputs.logits, dim=-1
            )
            output_probs = torch.exp(normalized)
            output_information = -normalized
            # Shannon entropy of the last position's distribution, per sequence.
            output_entropy = (output_probs * output_information).sum(-1)[:, -1]
            batch_size = output_entropy.shape[0]
            caif_mask = torch.ge(output_entropy, self.entropy)
            ordinary_mask = ~caif_mask
            self.entropy_based_stats["skips"] += (
                caif_mask.sum() / batch_size
            ).item()
            self.entropy_based_stats["count"] += 1
            self.entropy_based_stats["avg_entropy"] += (
                output_entropy.sum() / batch_size
            ).item()
            flatten_entropy = output_entropy.view(-1).cpu().tolist()
            if "entropy" not in self.entropy_based_stats:
                self.entropy_based_stats["entropy"] = flatten_entropy
            else:
                self.entropy_based_stats["entropy"] += flatten_entropy
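            # Three cases: the whole batch goes to the ordinary sampler, the
            # whole batch goes to the CAIF sampler, or the batch is split and
            # each part is sampled separately, then merged back by mask.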
            if caif_mask.sum() == 0:
                next_tokens_sampler = self.ordinary_sampler
                next_tokens = next_tokens_sampler(
                    input_ids,
                    outputs.logits,
                    caif_tokens_num=caif_tokens_num,
                    **sampler_kwargs
                )
                next_tokens = (
                    next_tokens * (1 - ended_sequences.long())
                    + self.lm.config.eos_token_id * ended_sequences.long()
                ).long()
            elif caif_mask.sum() == batch_size:
                next_tokens_sampler = self.caif_sampler
                next_tokens = next_tokens_sampler(
                    input_ids,
                    outputs.logits,
                    caif_tokens_num=caif_tokens_num,
                    **sampler_kwargs
                )
                next_tokens = (
                    next_tokens * (1 - ended_sequences.long())
                    + self.lm.config.eos_token_id * ended_sequences.long()
                ).long()
            else:
                next_tokens_caif = self.caif_sampler(
                    input_ids[caif_mask],
                    outputs.logits[caif_mask],
                    caif_tokens_num=caif_tokens_num,
                    **sampler_kwargs
                )
                next_tokens_ordinary = self.ordinary_sampler(
                    input_ids[ordinary_mask],
                    outputs.logits[ordinary_mask],
                    caif_tokens_num=caif_tokens_num,
                    **sampler_kwargs
                )
                next_tokens_caif = (
                    next_tokens_caif * (1 - ended_sequences[caif_mask].long())
                    + self.lm.config.eos_token_id
                    * ended_sequences[caif_mask].long()
                ).long()
                next_tokens_ordinary = (
                    next_tokens_ordinary
                    * (1 - ended_sequences[ordinary_mask].long())
                    + self.lm.config.eos_token_id
                    * ended_sequences[ordinary_mask].long()
                ).long()
                next_tokens = torch.ones(batch_size).long().to(self.device)
                next_tokens[caif_mask] = next_tokens_caif
                next_tokens[ordinary_mask] = next_tokens_ordinary
        else:
            if is_caif_step:
                next_tokens_sampler = self.caif_sampler
            else:
                next_tokens_sampler = self.ordinary_sampler
            next_tokens = next_tokens_sampler(
                input_ids,
                outputs.logits,
                caif_tokens_num=caif_tokens_num,
                **sampler_kwargs
            )
            next_tokens = (
                next_tokens * (1 - ended_sequences.long())
                + self.lm.config.eos_token_id * ended_sequences.long()
            ).long()

        input_ids = torch.cat(
            [input_ids, next_tokens[:, None].to(self.device)], dim=-1
        )
        ended_sequences |= next_tokens == self.lm.config.eos_token_id

        return input_ids, past, ended_sequences
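
    # Tokenize the prompt (falling back to the BOS token when no prompt is
    # given), tile it across the batch, and initialize the KV cache and the
    # per-sequence "ended" flags.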
    def get_input_ids(self, input_prompt, num_samples):
        if input_prompt is not None:
            input_ids = self.tokenizer(
                input_prompt, return_tensors="pt"
            ).input_ids
        else:
            # No prompt given: start generation from the BOS token.
            input_ids = torch.tensor([[self.lm.config.bos_token_id]])
        input_ids = input_ids.repeat(num_samples, 1).to(self.device)
        past = None
        ended_sequences = torch.zeros(
            input_ids.shape[0], device=self.device
        ).bool()
        return input_ids, past, ended_sequences
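
# Draw one sample per row from unnormalized probabilities and map the sampled
# column index back to its value (e.g. a token id) in the paired tensor.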
def sample(unscaled_probs, values):
    samples = torch.multinomial(unscaled_probs, 1)
    return torch.take_along_dim(values, samples, dim=1)
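
# Minimal usage sketch (illustrative only; "gpt2" as the model name and
# `my_ordinary_sampler` are assumptions, not part of this module). Samplers
# must be callables with the signature used in `generation_step`:
# sampler(input_ids, logits, caif_tokens_num=..., **kwargs) -> next token ids.
#
# device = "cuda" if torch.cuda.is_available() else "cpu"
# generator = Generator("gpt2", device)
# generator.set_ordinary_sampler(my_ordinary_sampler)  # hypothetical sampler
# texts, token_ids = generator.sample_sequences(
#     num_samples=2,
#     input_prompt="Once upon a time",
#     max_length=20,
#     caif_period=5,
# )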