import torch
from torch import nn
import torch.nn.functional as F
from transformers import LlamaForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast
from dataclasses import dataclass


@dataclass
class SelfCorrectiveLlamaOutput(CausalLMOutputWithPast):
    # Per-token logits from the hallucination detector head (class 0 = no
    # hallucination, classes 1..num_new_tokens = deletion tokens).
    hallucination_logits: torch.FloatTensor = None


class SelfCorrectiveLlama(LlamaForCausalLM):
    def __init__(self, config):
        super().__init__(config)

        self.num_new_tokens = 3
        self.original_vocab_size = config.vocab_size

        # Dedicated embeddings for the new special (deletion) tokens, kept outside
        # the base embedding table.
        self.new_token_embeddings = nn.Embedding(self.num_new_tokens, config.hidden_size)

        # Initialize the new embeddings to the mean of the pretrained embeddings so
        # they start in-distribution.
        with torch.no_grad():
            original_embeddings = self.model.embed_tokens.weight
            mean_embeddings = original_embeddings.mean(dim=0)
            self.new_token_embeddings.weight.data.copy_(
                mean_embeddings.unsqueeze(0).expand(self.num_new_tokens, -1)
            )

        # Hallucination detector: a SwiGLU-style MLP followed by a linear classifier
        # over (no hallucination + num_new_tokens deletion classes).
        intermediate_size = config.intermediate_size
        self.hallucination_gate_proj = nn.Linear(config.hidden_size, intermediate_size, bias=False)
        self.hallucination_up_proj = nn.Linear(config.hidden_size, intermediate_size, bias=False)
        self.hallucination_down_proj = nn.Linear(intermediate_size, config.hidden_size, bias=False)
        self.hallucination_detector = nn.Linear(config.hidden_size, self.num_new_tokens + 1)

    def forward(
        self,
        input_ids,
        attention_mask=None,
        labels=None,
        hallucination_labels=None,
        **kwargs
    ):
        # Embed with the base table; ids of the new tokens are clamped into range
        # here and patched below.
        clamped_input_ids = torch.clamp(input_ids, max=self.original_vocab_size - 1)
        inputs_embeds = self.model.embed_tokens(clamped_input_ids)

        # Overwrite positions holding the new special tokens with their dedicated embeddings.
        special_token_mask = input_ids >= self.original_vocab_size
        if special_token_mask.any():
            special_ids = input_ids[special_token_mask] - self.original_vocab_size
            special_embeds = self.new_token_embeddings(special_ids)
            inputs_embeds[special_token_mask] = special_embeds

        # Run the base transformer on the patched embeddings.
        kwargs["inputs_embeds"] = inputs_embeds
        transformer_outputs = self.model(
            attention_mask=attention_mask,
            **kwargs
        )
        last_hidden = transformer_outputs.last_hidden_state

        # Logits over the original vocabulary.
        main_logits = self.lm_head(last_hidden)

        # Logits for the new tokens, using their embeddings as a tied output projection.
        new_logits = F.linear(last_hidden, self.new_token_embeddings.weight)

        # Extended vocabulary: original tokens followed by the new deletion tokens.
        logits = torch.cat([main_logits, new_logits], dim=-1)

        # SwiGLU-style MLP for the hallucination detector.
        gate_output = self.hallucination_gate_proj(last_hidden)
        up_output = self.hallucination_up_proj(last_hidden)
        gated_hidden = F.silu(gate_output) * up_output
        detector_hidden = self.hallucination_down_proj(gated_hidden)

        # Per-token detector logits: class 0 = no hallucination, the remaining
        # classes map to the deletion tokens.
        all_hallucination_logits = self.hallucination_detector(detector_hidden)

        # Non-negative boost for the deletion-token logits, derived from the detector.
        deletion_logits = all_hallucination_logits[..., 1:]
        deletion_tokens_boost = F.softplus(deletion_logits)

        if hallucination_labels is not None and labels is not None:
            # Training: positions labeled as containing no hallucination.
            mask_no_hallucination = (hallucination_labels == 0)

            # Positions whose target is one of the deletion tokens (they occupy the
            # last num_new_tokens slots of the extended vocabulary).
            vocab_size = logits.shape[-1]
            mask_is_deletion_token = (labels >= (vocab_size - self.num_new_tokens)) & (labels < vocab_size)

            # Apply the boost only where one of the two conditions holds.
            combined_mask = (mask_no_hallucination | mask_is_deletion_token).unsqueeze(-1)
            to_add = torch.where(
                combined_mask,
                deletion_tokens_boost,
                torch.zeros_like(deletion_tokens_boost)
            )
        else:
            # Inference: let the detector decide per position.
            hallucination_decision = torch.argmax(all_hallucination_logits, dim=-1)

            hallucination_present_mask = (hallucination_decision != 0).unsqueeze(-1)

            # Boost the deletion tokens where a hallucination is predicted; push them
            # towards -inf everywhere else.
            to_add = torch.where(
                hallucination_present_mask,
                deletion_tokens_boost,
                torch.full_like(deletion_tokens_boost, torch.finfo(deletion_tokens_boost.dtype).min)
            )
        # Apply the adjustment to the deletion-token slice of the extended logits.
        logits = torch.cat(
            [logits[..., :-self.num_new_tokens], logits[..., -self.num_new_tokens:] + to_add],
            dim=-1
        )

        return SelfCorrectiveLlamaOutput(
            loss=None,
            logits=logits,
            hallucination_logits=all_hallucination_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=None,
            attentions=transformer_outputs.attentions
        )
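

# Illustrative usage sketch (not part of the model definition). The checkpoint name
# and the three special-token strings below are placeholders, and the sketch assumes
# the new tokens are appended directly after config.vocab_size so that their ids hit
# the `input_ids >= original_vocab_size` branch in forward().
if __name__ == "__main__":
    from transformers import AutoTokenizer

    checkpoint = "meta-llama/Llama-3.2-1B"  # assumed checkpoint name
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    tokenizer.add_special_tokens(
        {"additional_special_tokens": ["<del_1>", "<del_2>", "<del_3>"]}  # hypothetical deletion tokens
    )
    model = SelfCorrectiveLlama.from_pretrained(checkpoint)

    inputs = tokenizer("The capital of France is", return_tensors="pt")
    with torch.no_grad():
        outputs = model(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
        )

    # Extended-vocabulary logits (original vocab + 3 deletion tokens) and the
    # per-token hallucination logits (4 classes).
    print(outputs.logits.shape)
    print(outputs.hallucination_logits.shape)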