dakitari-instruct-v2-advanced / modeling_dakitari_instruct.py
import math
import torch
import torch.nn as nn
from transformers.modeling_utils import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutput
from transformers.generation import GenerationMixin # NEW: Import GenerationMixin
from .configuration_dakitari_instruct import DakitariInstructConfig


class SimpleAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.n_embd = config.n_embd
        self.in_proj = nn.Linear(config.n_embd, config.n_embd)
        self.out_proj = nn.Linear(config.n_embd, config.n_embd)

    def forward(self, x, attention_mask=None):
        B, L, D = x.size()  # batch, length, dimension
        q = k = v = self.in_proj(x)
        # Compute attention scores
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(D)
        if attention_mask is not None:
            # Expand mask to correct shape
            attention_mask = attention_mask.view(B, 1, L)
            scores = scores.masked_fill(~attention_mask, float('-inf'))
        attn = torch.softmax(scores, dim=-1)
        context = torch.matmul(attn, v)
        return self.out_proj(context)
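

# --- Illustrative sketch (not part of the model definition) ----------------------
# A quick shape check for SimpleAttention. The only assumption is a config object
# exposing `n_embd`; a SimpleNamespace stands in for DakitariInstructConfig here.
# The helper is never called at import time.
def _simple_attention_shape_demo():
    from types import SimpleNamespace
    cfg = SimpleNamespace(n_embd=64)
    attn = SimpleAttention(cfg)
    x = torch.randn(2, 10, cfg.n_embd)          # (batch, length, dim)
    mask = torch.ones(2, 10, dtype=torch.bool)  # every position is visible
    out = attn(x, attention_mask=mask)
    assert out.shape == (2, 10, cfg.n_embd)     # attention preserves the input shape
    return out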


class CustomTransformerLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.attention = SimpleAttention(config)
        self.linear1 = nn.Linear(config.n_embd, config.n_embd)  # Keep dimensions consistent
        self.linear2 = nn.Linear(config.n_embd, config.n_embd)  # Keep dimensions consistent
        self.norm1 = nn.LayerNorm(config.n_embd)
        self.norm2 = nn.LayerNorm(config.n_embd)
        self.dropout = nn.Dropout(config.resid_pdrop)
        self.activation = nn.GELU()
        # New: Adapter layers for domain-specific finetuning
        self.adapter_down = nn.Linear(config.n_embd, config.adapter_bottleneck)
        self.adapter_up = nn.Linear(config.adapter_bottleneck, config.n_embd)
        self.norm_adapter = nn.LayerNorm(config.n_embd)

    def forward(self, x, attention_mask=None):
        residual = x
        x_norm = self.norm1(x)
        x_attn = self.attention(x_norm, attention_mask)
        x = residual + self.dropout(x_attn)
        residual = x
        x_norm = self.norm2(x)
        x_ff = self.linear2(self.dropout(self.activation(self.linear1(x_norm))))
        x = residual + self.dropout(x_ff)
        # New: Adapter branch (only train adapter layers)
        adapter_input = self.norm_adapter(x)
        adapter_out = self.adapter_up(self.adapter_down(adapter_input))
        x = x + adapter_out
        return x
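

# --- Illustrative sketch (not part of the model definition) ----------------------
# The adapter branch above is intended for domain-specific finetuning, where the
# usual recipe is to freeze the base weights and update only the adapter
# parameters. How freezing is actually done is not shown in this file; the helper
# below is one assumed way to wire it up and is never called here.
def _freeze_all_but_adapters(model: nn.Module):
    for name, param in model.named_parameters():
        # Keep gradients only for parameters whose name marks them as adapter
        # weights (adapter_down, adapter_up, norm_adapter).
        param.requires_grad = "adapter" in name
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    return trainable, total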


class DakitariInstructModel(PreTrainedModel, GenerationMixin):  # Updated: Inherit from GenerationMixin
    config_class = DakitariInstructConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        # Token and position embeddings
        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.wpe = nn.Embedding(config.n_positions, config.n_embd)
        self.drop = nn.Dropout(config.embd_pdrop)
        # Build the transformer stack from n_layer identical blocks
        self.layers = nn.ModuleList([
            CustomTransformerLayer(config)
            for _ in range(config.n_layer)
        ])
        self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        # New: LM head for generation
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.apply(self._init_weights)
        # Ensure adapter layers stay trainable even if other weights are frozen
        for name, param in self.named_parameters():
            if "adapter" in name:
                param.requires_grad = True  # Explicit: make sure adapters train

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self, input_ids=None, attention_mask=None, position_ids=None, labels=None, **kwargs):
        if input_ids is None:
            raise ValueError("input_ids must be provided")
        # Ensure input_ids are within bounds
        input_ids = torch.clamp(input_ids, 0, self.config.vocab_size - 1)
        batch_size, seq_length = input_ids.shape
        # Handle position IDs correctly
        if position_ids is None:
            position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
            position_ids = position_ids.unsqueeze(0).expand(batch_size, seq_length)
        # Embeddings
        inputs_embeds = self.wte(input_ids)
        position_embeds = self.wpe(position_ids)
        hidden_states = self.drop(inputs_embeds + position_embeds)
        # Ensure attention mask is a bool tensor
        if attention_mask is not None:
            attention_mask = attention_mask.bool()
        # Process through transformer layers
        for layer in self.layers:
            hidden_states = layer(hidden_states, attention_mask)
        hidden_states = self.ln_f(hidden_states)
        logits = self.lm_head(hidden_states)
        # Standard causal LM loss: predict token t+1 from tokens up to t, only when labels are given
        loss = None
        if labels is not None:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = nn.functional.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )
        # NEW: Return a CausalLMOutput with a logits attribute so generate() works correctly
        return CausalLMOutput(loss=loss, logits=logits)

    # NEW: Override generation input preparation
    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **kwargs):
        return {"input_ids": input_ids, "attention_mask": attention_mask}
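

# --- Illustrative usage sketch (not part of the model definition) ----------------
# End-to-end smoke test: build a small config, run a forward pass with labels, and
# decode a few tokens through the inherited GenerationMixin. This assumes that
# DakitariInstructConfig accepts the keyword arguments referenced throughout this
# file (vocab_size, n_positions, n_embd, n_layer, adapter_bottleneck, the dropout
# probabilities, layer_norm_epsilon, initializer_range); see
# configuration_dakitari_instruct.py for the real signature and defaults.
if __name__ == "__main__":
    config = DakitariInstructConfig(
        vocab_size=1000,
        n_positions=128,
        n_embd=64,
        n_layer=2,
        adapter_bottleneck=16,
        resid_pdrop=0.1,
        embd_pdrop=0.1,
        layer_norm_epsilon=1e-5,
        initializer_range=0.02,
    )
    model = DakitariInstructModel(config)
    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    attention_mask = torch.ones_like(input_ids)

    # Training-style call: passing labels yields a shifted cross-entropy loss.
    outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)
    print("logits:", tuple(outputs.logits.shape), "loss:", float(outputs.loss))

    # Inference-style call: greedy decoding via GenerationMixin.generate().
    model.eval()
    generated = model.generate(input_ids, max_new_tokens=8, do_sample=False)
    print("generated:", tuple(generated.shape))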