import torch
from torch import nn
import torch.nn.functional as F
from typing import Union, List, Dict, Optional
from transformers import AutoModelForCausalLM, AutoTokenizer, GPTJForCausalLM
from transformers.generation_logits_process import (
    LogitsProcessorList,
    NoBadWordsLogitsProcessor,
    NoRepeatNGramLogitsProcessor,
)

from utils import (
    NEGATIVE_INF, HALF_NEGATIVE_INF,
    logits_to_entropy, mask_pad
)
from clipcap import ClipCap


class Policy(nn.Module):
    """Generation policy wrapping a ClipCap model and its tokenizer.

    ``sample`` generates responses token by token and records per-token
    log-probabilities. ``forward_pass`` re-scores a given (query, response)
    pair in a single forward call.
    """

    def __init__(self, model_name, temperature, device, clipcap_path='', fix_gpt=False,
                 use_transformer_mapper: bool = False, use_ptuning_v2: bool = False,
                 prefix_length=10, clipcap_num_layers: int = 1,
                 label_path: str = '', model_weight: str = 'None',
                 use_label_prefix: bool = False):
        """Build the ClipCap model and a tokenizer for ``model_name``.

        Args:
            model_name: name/path passed to both ``ClipCap`` and
                ``AutoTokenizer.from_pretrained``.
            temperature: default sampling temperature used by ``sample``
                when no explicit temperature is given.
            device: default device for ``sample`` / ``forward_pass``.
            Remaining arguments are forwarded unchanged to ``ClipCap``.
        """
        super().__init__()
        self.device = device
        self.model = ClipCap(model_name, device,
                             model_path=clipcap_path, fix_gpt=fix_gpt,
                             prefix_length=prefix_length,
                             num_layers=clipcap_num_layers,
                             label_path=label_path, model_weight=model_weight,
                             use_transformer_mapper=use_transformer_mapper,
                             use_ptuning_v2=use_ptuning_v2,
                             use_label_prefix=use_label_prefix)

        # Padding uses the "<|endoftext|>" token; for GPT-2-style tokenizers this
        # presumably coincides with eos_token_id — confirm for other model families.
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, pad_token="<|endoftext|>")
        self.model.gpt.config.pad_token_id = self.tokenizer.pad_token_id
        self.temperature = temperature

    def get_processor(self, no_repeat_ngram_size: int = 3):
        """Return the logits processors applied before the log-softmax in ``sample``.

        Args:
            no_repeat_ngram_size: when > 0, adds a ``NoRepeatNGramLogitsProcessor``
                with this n-gram size.
        """
        logits_processor = LogitsProcessorList()
        if no_repeat_ngram_size > 0:
            logits_processor.append(NoRepeatNGramLogitsProcessor(ngram_size=no_repeat_ngram_size))
        # NOTE: a NoBadWordsLogitsProcessor banning pad_token_id
        # ([[self.tokenizer.pad_token_id]], self.tokenizer.pad_token_id) was
        # previously disabled here; re-add it if pad tokens must be suppressed.
        return logits_processor

    def sample(self,
               input_ids: Optional[torch.Tensor] = None,
               features: Optional[torch.Tensor] = None,
               attention_mask: Optional[torch.Tensor] = None,
               labels: Optional[torch.Tensor] = None,
               max_len: int = 20,
               sample: bool = True,
               top_k: Optional[int] = None,
               top_p: Optional[float] = None,
               temperature: Optional[float] = None,
               no_repeat_ngram_size: int = 0,
               invalidate_eos: bool = True,
               device=None) -> Dict[str, Union[torch.Tensor, List[str]]]:
        """Generate up to ``max_len`` tokens after ``input_ids``.

        Args:
            input_ids: query token ids, (batch, query_len). Required.
            features: conditioning features forwarded to the ClipCap model.
            attention_mask: query mask, same shape as ``input_ids``. If omitted,
                every query position is treated as a real token.
            labels: forwarded to ``prepare_inputs_for_generation``.
            max_len: maximum number of generated tokens.
            sample: multinomial sampling if True, otherwise greedy argmax.
            top_k / top_p / temperature: passed to the HF logits warper;
                ``temperature`` defaults to ``self.temperature``.
            no_repeat_ngram_size: see ``get_processor``.
            invalidate_eos: if True, the eos logit is set to a large negative
                value, so eos is (effectively) never produced.
            device: defaults to ``self.device``.

        Returns:
            Dict with ``query/*`` (ids, decoded text, mask) and ``response/*``
            (ids, decoded text, mask, per-token log-prob, per-step eos prob).
            The log-probs come from the processed but *un-warped* distribution.

        Raises:
            ValueError: if ``input_ids`` is not given.

        Side effect: switches ``self.model`` to eval mode and leaves it there.
        """
        if input_ids is None:
            raise ValueError("sample() requires input_ids")
        if device is None:
            device = self.device
        if temperature is None:
            temperature = self.temperature

        input_ids = input_ids.to(device)
        if attention_mask is None:
            # No mask supplied: every query position is a real token.
            attention_mask = torch.ones_like(input_ids)
        attention_mask = attention_mask.to(device)

        model_kwargs = {'attention_mask': attention_mask}
        batch_size, input_seq_len = input_ids.shape

        logits_processor = self.get_processor(no_repeat_ngram_size=no_repeat_ngram_size)

        logits_warper = self.model.gpt._get_logits_warper(
            top_k=top_k, top_p=top_p, temperature=temperature, num_beams=1
        )

        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=device)
        output_logprob = torch.zeros([batch_size, 0], device=device)
        eos_logprobs = torch.zeros([batch_size, 0], device=device)
        output_mask = torch.ones([batch_size, 0], dtype=torch.long, device=device)

        # Deliberately not restored afterwards: re-enabling train mode would also
        # turn dropout on inside the wrapped GPT, changing downstream training.
        self.model.eval()
        with torch.no_grad():
            for step in range(max_len):

                # prepare model inputs
                model_inputs = self.model.prepare_inputs_for_generation(
                    input_ids, features=features, labels=labels, **model_kwargs)

                # forward pass to get next token
                outputs = self.model(
                    **model_inputs,
                    device=device
                )

                # In the first decoding step, use the last real (unmasked) query
                # position of each row; this assumes right-padded queries — TODO confirm.
                if step == 0:
                    last_non_masked_idx = torch.sum(attention_mask, dim=1) - 1
                    next_token_logits = outputs.logits[range(batch_size), last_non_masked_idx, :]
                else:
                    next_token_logits = outputs.logits[:, -1, :]

                negative_inf = HALF_NEGATIVE_INF if next_token_logits.dtype == torch.half else NEGATIVE_INF
                next_token_scores = logits_processor(input_ids, next_token_logits)
                if invalidate_eos:
                    next_token_scores[:, self.tokenizer.eos_token_id] = negative_inf  # no endoftext
                # Log-probs of the processed distribution, taken before top-k/top-p/temperature warping.
                log_prob = F.log_softmax(next_token_scores, dim=-1)
                next_token_scores = logits_warper(input_ids, next_token_scores)
                if sample:
                    # Temperature (higher temperature => more likely to sample low probability tokens)
                    probs = F.softmax(next_token_scores, dim=-1)
                    next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
                else:
                    # Greedy decoding
                    next_tokens = torch.argmax(next_token_scores, dim=-1)

                # finished sentences should have their next token be a padding token
                next_tokens = next_tokens * unfinished_sequences + self.tokenizer.pad_token_id * (1 - unfinished_sequences)

                # update output mask
                output_mask = torch.cat([output_mask, unfinished_sequences[:, None]], dim=-1)
                # update eos / token log probabilities; finished rows get negative_inf
                eos_logprob = log_prob[:, self.tokenizer.eos_token_id]
                eos_logprob = eos_logprob * unfinished_sequences + negative_inf * (1 - unfinished_sequences)
                eos_logprobs = torch.cat([eos_logprobs, eos_logprob[:, None]], dim=-1)

                token_logprob = torch.gather(log_prob, 1, next_tokens[:, None]).squeeze(1)
                token_logprob = token_logprob * unfinished_sequences + negative_inf * (1 - unfinished_sequences)
                output_logprob = torch.cat([output_logprob, token_logprob[:, None]], dim=-1)

                # update generated ids, model inputs for next step
                input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
                model_kwargs = self.model.gpt._update_model_kwargs_for_generation(
                    outputs, model_kwargs, is_encoder_decoder=self.model.gpt.config.is_encoder_decoder
                )

                # if eos_token was found in one sentence, set sentence to finished
                unfinished_sequences = unfinished_sequences.mul((next_tokens != self.tokenizer.eos_token_id).long())

                if unfinished_sequences.max() == 0:
                    break

        response_ids = input_ids[:, input_seq_len:]
        response_text = [self.tokenizer.decode(output, skip_special_tokens=True, clean_up_tokenization_spaces=True)
                         for output in response_ids]

        prompt_ids = input_ids[:, :input_seq_len]
        prompts = [self.tokenizer.decode(query, skip_special_tokens=True, clean_up_tokenization_spaces=True)
                   for query in prompt_ids]
        eos_probs = eos_logprobs.exp()

        return {
            'query/input_ids': prompt_ids,
            'query/text': prompts,
            'query/mask': attention_mask,
            'response/input_ids': response_ids,
            'response/text': response_text,
            'response/mask': output_mask,
            'response/log_prob': output_logprob,
            'response/eos_prob': eos_probs,
        }

    def forward_pass(self,
                     query_input_ids: torch.Tensor,
                     query_mask: torch.Tensor,
                     response_input_ids: torch.Tensor,
                     response_mask: torch.Tensor,
                     features: torch.Tensor,
                     labels: Optional[torch.Tensor] = None,
                     invalidate_eos: bool = True,
                     device=None):
        """Score ``response_input_ids`` given the query in one forward call.

        The logit for response token 0 is taken at the last real query position.
        The logits for tokens 1..R-1 are taken at response positions 0..R-2.
        This assumes ``outputs.logits`` covers exactly query + response positions
        (no extra prefix positions) and that queries are right-padded — TODO confirm.

        Returns:
            Dict of per-response-token ``log_prob``, ``eos_prob``, ``entropy`` and
            ``pos_logit`` (all passed through ``mask_pad`` with ``response_mask``),
            plus the raw (eos-invalidated) ``logits``.
        """
        if device is None:
            device = self.device

        batch_size, query_seq_len = query_input_ids.shape
        input_ids = torch.cat([query_input_ids, response_input_ids], dim=-1)
        attention_mask = torch.cat([query_mask, response_mask], dim=-1)

        # forward pass to get next token
        outputs = self.model(
            input_ids,
            features,
            attention_mask,
            labels,
            device=device
        )
        # get the first logit
        query_logits = outputs.logits[:, :query_seq_len, :]
        last_non_masked_idx = torch.sum(query_mask, dim=1) - 1
        first_logits = query_logits[range(batch_size), last_non_masked_idx, :]
        # get the second to last logit
        response_logits = outputs.logits[:, query_seq_len:-1, :]
        logits = torch.cat([first_logits[:, None], response_logits], dim=1)

        negative_inf = HALF_NEGATIVE_INF if logits.dtype == torch.half else NEGATIVE_INF
        if invalidate_eos:
            logits[:, :, self.tokenizer.eos_token_id] = negative_inf  # no endoftext

        log_prob = F.log_softmax(logits, dim=-1)
        output_logprob = torch.gather(log_prob, 2, response_input_ids[:, :, None]).squeeze(2)
        output_entropy = logits_to_entropy(logits)
        eos_prob = F.softmax(logits, dim=-1)[:, :, self.tokenizer.eos_token_id]

        pos_logit = torch.gather(logits, 2, response_input_ids[:, :, None]).squeeze(2)

        return {
            'response/log_prob': mask_pad(output_logprob, response_mask),
            'response/eos_prob': mask_pad(eos_prob, response_mask),
            'response/entropy': mask_pad(output_entropy, response_mask),
            'response/pos_logit': mask_pad(pos_logit, response_mask),
            'response/logits': logits,
        }