# Source captured from a Hugging Face Space page (page status banner: "Runtime error").
import torch | |
from torch import nn | |
import torch.nn.functional as F | |
from typing import Union, List, Dict, Optional | |
from transformers import AutoModelForCausalLM, AutoTokenizer, GPTJForCausalLM | |
from transformers.generation_logits_process import ( | |
LogitsProcessorList, | |
NoBadWordsLogitsProcessor, | |
NoRepeatNGramLogitsProcessor, | |
) | |
from utils import ( | |
NEGATIVE_INF, HALF_NEGATIVE_INF, | |
logits_to_entropy, mask_pad | |
) | |
from clipcap import ClipCap | |
class Policy(nn.Module):
    """Generation / scoring policy wrapping a ``ClipCap`` model and its tokenizer.

    * :meth:`sample` decodes continuations autoregressively under ``torch.no_grad``
      and records per-token log-probabilities and per-step EOS probabilities.
    * :meth:`forward_pass` re-scores given query/response pairs with a single
      (differentiable) forward pass.
    """

    def __init__(self, model_name, temperature, device, clipcap_path='', fix_gpt=False,
                 use_transformer_mapper: bool = False, use_ptuning_v2: bool = False,
                 prefix_length=10, clipcap_num_layers: int = 1,
                 label_path: str = '', model_weight: str = 'None', use_label_prefix: bool = False):
        """Build the ClipCap model and the matching tokenizer.

        Args:
            model_name: model id/path used both by ``ClipCap`` and by
                ``AutoTokenizer.from_pretrained``.
            temperature: default sampling temperature used by :meth:`sample`.
            device: default device used by :meth:`sample` and :meth:`forward_pass`.
            clipcap_path, fix_gpt, use_transformer_mapper, use_ptuning_v2,
            prefix_length, clipcap_num_layers, label_path, model_weight,
            use_label_prefix: forwarded unchanged to ``ClipCap``
                (``clipcap_num_layers`` is passed as ``num_layers``).
        """
        super().__init__()
        self.device = device
        self.model = ClipCap(model_name, device,
                             model_path=clipcap_path, fix_gpt=fix_gpt,
                             prefix_length=prefix_length,
                             num_layers=clipcap_num_layers,
                             label_path=label_path, model_weight=model_weight,
                             use_transformer_mapper=use_transformer_mapper,
                             use_ptuning_v2=use_ptuning_v2,
                             use_label_prefix=use_label_prefix)
        # "<|endoftext|>" is used as the pad token. NOTE(review): for GPT-2-family
        # vocabularies this is presumably also the EOS token (pad_token_id == eos_token_id).
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, pad_token="<|endoftext|>")
        self.model.gpt.config.pad_token_id = self.tokenizer.pad_token_id
        self.temperature = temperature

    def get_processor(self, no_repeat_ngram_size: int = 3):
        """Return the logits processors applied to raw scores before sampling.

        Only an n-gram repetition blocker is configured, and only when
        ``no_repeat_ngram_size > 0``; otherwise the returned list is empty.
        """
        logits_processor = LogitsProcessorList()
        if no_repeat_ngram_size > 0:
            logits_processor.append(NoRepeatNGramLogitsProcessor(ngram_size=no_repeat_ngram_size))
        # Banning the pad token with NoBadWordsLogitsProcessor was considered and
        # deliberately left disabled.
        return logits_processor

    @staticmethod
    def _mask_finished(values: torch.Tensor, unfinished: torch.Tensor,
                       fill_value: float) -> torch.Tensor:
        """Return ``values`` with entries of finished rows (``unfinished == 0``) set to ``fill_value``.

        ``masked_fill`` is used instead of ``values * unfinished + fill * (1 - unfinished)``
        because the arithmetic form yields NaN whenever a finished row's value is
        ``-inf`` (``-inf * 0`` is NaN), e.g. a pad token banned by the n-gram processor.
        """
        return values.masked_fill(unfinished == 0, fill_value)

    def sample(self,
               input_ids: Optional[torch.Tensor] = None,
               features: Optional[torch.Tensor] = None,
               attention_mask: Optional[torch.Tensor] = None,
               labels: Optional[torch.Tensor] = None,
               max_len: int = 20,
               sample: bool = True,
               top_k: Optional[int] = None,
               top_p: Optional[float] = None,
               temperature: Optional[float] = None,
               no_repeat_ngram_size: int = 0,
               invalidate_eos: bool = True,
               device=None) -> Dict[str, Union[torch.Tensor, List[str]]]:
        """Decode up to ``max_len`` tokens after ``input_ids``.

        Args:
            input_ids: query token ids of shape (batch, query_len); required.
            features: passed through to ``prepare_inputs_for_generation``.
            attention_mask: query mask, same shape as ``input_ids``; required.
            labels: optional, passed through to ``prepare_inputs_for_generation``.
            max_len: maximum number of decoding steps.
            sample: multinomial sampling if True, greedy argmax otherwise.
            top_k, top_p, temperature: passed to the model's logits warper;
                ``temperature`` defaults to ``self.temperature``.
            no_repeat_ngram_size: n-gram blocking size (0 disables it).
            invalidate_eos: if True, the EOS score is pushed to ``negative_inf``
                before sampling, so generation effectively runs for ``max_len`` steps.
            device: defaults to ``self.device``.

        Returns:
            dict with ``query/input_ids``, ``query/text``, ``query/mask``,
            ``response/input_ids``, ``response/text``, ``response/mask`` (1 while
            the row was still unfinished at that step), ``response/log_prob`` and
            ``response/eos_prob``. The two probability entries are (batch, steps)
            and hold ``negative_inf`` / its exp for finished rows.

        NOTE(review): this calls ``self.model.eval()`` and does not restore the previous
        train/eval mode.
        """
        if device is None:
            device = self.device
        if temperature is None:
            temperature = self.temperature

        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        model_kwargs = {'attention_mask': attention_mask}
        batch_size, input_seq_len = input_ids.shape

        logits_processor = self.get_processor(no_repeat_ngram_size=no_repeat_ngram_size)
        logits_warper = self.model.gpt._get_logits_warper(
            top_k=top_k, top_p=top_p, temperature=temperature, num_beams=1
        )
        eos_token_id = self.tokenizer.eos_token_id
        pad_token_id = self.tokenizer.pad_token_id

        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=device)
        output_logprob = torch.zeros([batch_size, 0], device=device)
        eos_logprobs = torch.zeros([batch_size, 0], device=device)
        output_mask = torch.ones([batch_size, 0], dtype=torch.long, device=device)

        self.model.eval()
        with torch.no_grad():
            for step in range(max_len):
                model_inputs = self.model.prepare_inputs_for_generation(input_ids,
                                                                        features=features,
                                                                        labels=labels,
                                                                        **model_kwargs)
                outputs = self.model(**model_inputs, device=device)

                if step == 0:
                    # First step: use each row's last non-masked query position.
                    # NOTE(review): assumes right-padded queries — confirm with callers.
                    last_non_masked_idx = torch.sum(attention_mask, dim=1) - 1
                    next_token_logits = outputs.logits[range(batch_size), last_non_masked_idx, :]
                else:
                    next_token_logits = outputs.logits[:, -1, :]

                negative_inf = HALF_NEGATIVE_INF if next_token_logits.dtype == torch.half else NEGATIVE_INF
                next_token_scores = logits_processor(input_ids, next_token_logits)
                if invalidate_eos:
                    next_token_scores[:, eos_token_id] = negative_inf  # no endoftext
                # Log-probs under the processed (n-gram / EOS-masked) distribution. They are
                # taken BEFORE the warper, so temperature/top-k/top-p only affect which
                # token is drawn, not the recorded log-probs.
                log_prob = F.log_softmax(next_token_scores, dim=-1)
                next_token_scores = logits_warper(input_ids, next_token_scores)

                if sample:
                    probs = F.softmax(next_token_scores, dim=-1)
                    next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
                else:
                    next_tokens = torch.argmax(next_token_scores, dim=-1)

                # Finished rows keep emitting the pad token.
                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

                output_mask = torch.cat([output_mask, unfinished_sequences[:, None]], dim=-1)

                eos_logprob = self._mask_finished(log_prob[:, eos_token_id],
                                                  unfinished_sequences, negative_inf)
                eos_logprobs = torch.cat([eos_logprobs, eos_logprob[:, None]], dim=-1)

                token_logprob = torch.gather(log_prob, 1, next_tokens[:, None]).squeeze(1)
                token_logprob = self._mask_finished(token_logprob, unfinished_sequences, negative_inf)
                output_logprob = torch.cat([output_logprob, token_logprob[:, None]], dim=-1)

                input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
                model_kwargs = self.model.gpt._update_model_kwargs_for_generation(
                    outputs, model_kwargs, is_encoder_decoder=self.model.gpt.config.is_encoder_decoder
                )

                # A row is finished once it has emitted EOS.
                unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())
                if unfinished_sequences.max() == 0:
                    break

        response_ids = input_ids[:, input_seq_len:]
        response_text = [self.tokenizer.decode(output, skip_special_tokens=True, clean_up_tokenization_spaces=True)
                         for output in response_ids]
        prompt_ids = input_ids[:, :input_seq_len]
        prompts = [self.tokenizer.decode(query, skip_special_tokens=True, clean_up_tokenization_spaces=True)
                   for query in prompt_ids]
        eos_probs = eos_logprobs.exp()

        return {
            'query/input_ids': prompt_ids,
            'query/text': prompts,
            'query/mask': attention_mask,
            'response/input_ids': response_ids,
            'response/text': response_text,
            'response/mask': output_mask,
            'response/log_prob': output_logprob,
            'response/eos_prob': eos_probs,
        }

    def forward_pass(self,
                     query_input_ids: torch.Tensor,
                     query_mask: torch.Tensor,
                     response_input_ids: torch.Tensor,
                     response_mask: torch.Tensor,
                     features: torch.Tensor,
                     labels: Optional[torch.Tensor] = None,
                     invalidate_eos: bool = True,
                     device=None):
        """Score ``response_input_ids`` given ``query_input_ids`` with one forward pass.

        Query and response are concatenated along the sequence axis. The logit
        predicting response token ``t`` is taken from the position before it: for
        ``t == 0`` that is the last non-masked query position, otherwise the
        previous response position.

        Returns:
            dict with ``response/log_prob``, ``response/eos_prob``,
            ``response/entropy`` and ``response/pos_logit`` (each passed through
            ``mask_pad`` with ``response_mask``), plus the raw ``response/logits``
            of shape (batch, response_len, vocab). EOS logits are set to
            ``negative_inf`` when ``invalidate_eos``.
        """
        if device is None:
            device = self.device
        # Move token/mask tensors to the target device, consistent with ``sample``
        # (a no-op when they are already there).
        query_input_ids = query_input_ids.to(device)
        query_mask = query_mask.to(device)
        response_input_ids = response_input_ids.to(device)
        response_mask = response_mask.to(device)

        batch_size, query_seq_len = query_input_ids.shape
        input_ids = torch.cat([query_input_ids, response_input_ids], dim=-1)
        attention_mask = torch.cat([query_mask, response_mask], dim=-1)

        outputs = self.model(
            input_ids,
            features,
            attention_mask,
            labels,
            device=device
        )

        # Logit predicting the first response token: last real query position.
        query_logits = outputs.logits[:, :query_seq_len, :]
        last_non_masked_idx = torch.sum(query_mask, dim=1) - 1
        first_logits = query_logits[range(batch_size), last_non_masked_idx, :]
        # Logits predicting the remaining response tokens: all response positions but the last.
        response_logits = outputs.logits[:, query_seq_len:-1, :]
        logits = torch.cat([first_logits[:, None], response_logits], dim=1)

        negative_inf = HALF_NEGATIVE_INF if logits.dtype == torch.half else NEGATIVE_INF
        if invalidate_eos:
            logits[:, :, self.tokenizer.eos_token_id] = negative_inf  # no endoftext

        log_prob = F.log_softmax(logits, dim=-1)
        output_logprob = torch.gather(log_prob, 2, response_input_ids[:, :, None]).squeeze(2)
        output_entropy = logits_to_entropy(logits)
        eos_prob = F.softmax(logits, dim=-1)[:, :, self.tokenizer.eos_token_id]
        pos_logit = torch.gather(logits, 2, response_input_ids[:, :, None]).squeeze(2)

        return {
            'response/log_prob': mask_pad(output_logprob, response_mask),
            'response/eos_prob': mask_pad(eos_prob, response_mask),
            'response/entropy': mask_pad(output_entropy, response_mask),
            'response/pos_logit': mask_pad(pos_logit, response_mask),
            'response/logits': logits,
        }