import torch
from torch import nn
import torch.nn.functional as F
from transformers import LlamaForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast
from dataclasses import dataclass


@dataclass
class SelfCorrectiveLlamaOutput(CausalLMOutputWithPast):
    hallucination_logits: torch.FloatTensor = None


class SelfCorrectiveLlama(LlamaForCausalLM):
    def __init__(self, config):
        super().__init__(config)
        self.num_new_tokens = 3
        self.original_vocab_size = config.vocab_size

        # Create a new, small embedding layer for only the special tokens
        self.new_token_embeddings = nn.Embedding(self.num_new_tokens, config.hidden_size)

        # Initialize new embeddings with the mean of the original ones
        with torch.no_grad():
            original_embeddings = self.model.embed_tokens.weight
            mean_embeddings = original_embeddings.mean(dim=0)
            self.new_token_embeddings.weight.data.copy_(
                mean_embeddings.unsqueeze(0).expand(self.num_new_tokens, -1)
            )

        intermediate_size = config.intermediate_size
        self.hallucination_gate_proj = nn.Linear(config.hidden_size, intermediate_size, bias=False)
        self.hallucination_up_proj = nn.Linear(config.hidden_size, intermediate_size, bias=False)
        self.hallucination_down_proj = nn.Linear(intermediate_size, config.hidden_size, bias=False)
        self.hallucination_detector = nn.Linear(config.hidden_size, self.num_new_tokens + 1)
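    # Reference notes (added for clarity; derived from the code below):
    #   * Token ids in [original_vocab_size, original_vocab_size + num_new_tokens)
    #     are embedded by `new_token_embeddings` in forward(); the base embedding
    #     matrix and lm_head keep their original size.
    #   * The detector head computes down_proj(silu(gate_proj(x)) * up_proj(x))
    #     and classifies every position into num_new_tokens + 1 classes, where
    #     class 0 means "no hallucination".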
    def forward(
        self,
        input_ids,
        attention_mask=None,
        labels=None,
        hallucination_labels=None,
        **kwargs
    ):
        # 1. Manually construct the input embeddings.
        # Clamp ids so the special tokens do not index past the original embedding matrix.
        clamped_input_ids = torch.clamp(input_ids, max=self.original_vocab_size - 1)
        inputs_embeds = self.model.embed_tokens(clamped_input_ids)

        # Overwrite the embeddings for our new special tokens
        special_token_mask = input_ids >= self.original_vocab_size
        if special_token_mask.any():
            special_ids = input_ids[special_token_mask] - self.original_vocab_size
            special_embeds = self.new_token_embeddings(special_ids)
            inputs_embeds[special_token_mask] = special_embeds
        # 2. Pass the constructed embeddings through the base transformer model
        kwargs["inputs_embeds"] = inputs_embeds
        transformer_outputs = self.model(
            attention_mask=attention_mask,
            **kwargs
        )
        last_hidden = transformer_outputs.last_hidden_state

        # 3. Calculate token logits by combining outputs from both heads
        # Main logits from the original, frozen lm_head
        main_logits = self.lm_head(last_hidden)
        # New-token logits from the small, trainable embedding layer (its weight doubles as an output head)
        new_logits = F.linear(last_hidden, self.new_token_embeddings.weight)
        # Concatenate to get logits over the full, expanded vocabulary
        logits = torch.cat([main_logits, new_logits], dim=-1)
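        # `logits` now spans the expanded vocabulary with shape
        # (batch, seq_len, original_vocab_size + num_new_tokens); the last
        # num_new_tokens columns belong to the special deletion tokens, which
        # step 5 below biases with the detector's deletion logits.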
        # 4. SwiGLU-based hallucination detector
        gate_output = self.hallucination_gate_proj(last_hidden)
        up_output = self.hallucination_up_proj(last_hidden)
        gated_hidden = F.silu(gate_output) * up_output
        detector_hidden = self.hallucination_down_proj(gated_hidden)

        # Hallucination logits: one "no hallucination" class plus one class per deletion token
        all_hallucination_logits = self.hallucination_detector(detector_hidden)
        # 5. Conditionally modify the token logits.
        deletion_logits = all_hallucination_logits[..., 1:]  # skip the first class (no hallucination)

        if hallucination_labels is not None and labels is not None:
            # Training case: only add the deletion logits where
            #   1) the hallucination label is 0 (no hallucination), or
            #   2) the next token is one of the deletion tokens, i.e. its id falls in the
            #      last `num_new_tokens` slots of the expanded vocabulary.
            mask_no_hallucination = (hallucination_labels == 0)
            vocab_size = logits.shape[-1]
            mask_is_deletion_token = (labels >= (vocab_size - self.num_new_tokens)) & (labels < vocab_size)

            # Combine the masks and build the tensor to add.
            combined_mask = (mask_no_hallucination | mask_is_deletion_token).unsqueeze(-1)
            to_add = torch.where(
                combined_mask,
                deletion_logits,
                torch.zeros_like(deletion_logits)
            )
            logits[:, :, -self.num_new_tokens:].add_(to_add)
        else:
            # Inference case: always add the deletion logits to the special-token logits
            logits[:, :, -self.num_new_tokens:].add_(deletion_logits)
        # 6. Return the custom output object
        return SelfCorrectiveLlamaOutput(
            loss=None,  # loss calculation is handled by the Trainer
            logits=logits,
            hallucination_logits=all_hallucination_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=None,
            attentions=transformer_outputs.attentions
        )
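# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module). The checkpoint
# name and prompt are assumptions; any Llama-family checkpoint should work the
# same way. The three special correction tokens would normally be added to the
# tokenizer so their ids fall in [original_vocab_size, original_vocab_size + 3).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoTokenizer

    base_model = "meta-llama/Llama-3.2-1B"  # assumed base checkpoint
    tokenizer = AutoTokenizer.from_pretrained(base_model)
    model = SelfCorrectiveLlama.from_pretrained(base_model)
    model.eval()

    inputs = tokenizer("The capital of France is", return_tensors="pt")
    with torch.no_grad():
        # No labels / hallucination_labels: the inference branch of forward()
        # adds the deletion logits to the special-token logits unconditionally.
        outputs = model(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
        )

    # Expected shapes: (batch, seq_len, original_vocab_size + 3) and (batch, seq_len, 4)
    print(outputs.logits.shape)
    print(outputs.hallucination_logits.shape)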