# MTMT / modeling_mtmt.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from .configuration_mtmt import MTMTConfig


class MultiHeadAttention(nn.Module):
    """Multi-head self-attention with an always-on causal mask."""

    def __init__(self, emb_dim=512, num_heads=8, dropout=0.1):
        super().__init__()
        self.embed_size = emb_dim
        self.heads = num_heads
        self.head_dim = emb_dim // num_heads
        assert self.head_dim * num_heads == emb_dim, "emb_dim must be divisible by num_heads"
        self.keys = nn.Linear(emb_dim, emb_dim, bias=False)
        self.queries = nn.Linear(emb_dim, emb_dim, bias=False)
        self.values = nn.Linear(emb_dim, emb_dim, bias=False)
        self.fc_out = nn.Linear(emb_dim, emb_dim)
        # Apply the dropout argument to the attention weights (it was previously accepted but unused).
        self.attn_dropout = nn.Dropout(dropout)

    def forward(self, x, attn_mask=None):
        N, seq_length, _ = x.shape
        # Project and split into heads: (N, seq, heads, head_dim).
        values = self.values(x).view(N, seq_length, self.heads, self.head_dim)
        keys = self.keys(x).view(N, seq_length, self.heads, self.head_dim)
        queries = self.queries(x).view(N, seq_length, self.heads, self.head_dim)
        # Scaled dot-product scores: (N, heads, query_len, key_len).
        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys]) / (self.head_dim ** 0.5)
        # The causal mask is always applied so the model stays autoregressive
        # even when no padding mask is supplied.
        causal_mask = torch.tril(
            torch.ones(seq_length, seq_length, device=x.device, dtype=torch.bool)
        ).unsqueeze(0).unsqueeze(0)
        combined_mask = causal_mask
        if attn_mask is not None:
            if attn_mask.dtype != torch.bool:
                attn_mask = attn_mask.bool()
            if attn_mask.dim() == 2:  # (N, key_len) padding mask
                token_mask = attn_mask.unsqueeze(1).unsqueeze(2).expand(N, self.heads, seq_length, -1)
            elif attn_mask.dim() == 3:  # (N, query_len, key_len)
                token_mask = attn_mask.unsqueeze(1).expand(N, self.heads, -1, -1)
            elif attn_mask.dim() == 4:  # (N, heads, query_len, key_len)
                token_mask = attn_mask
            else:
                raise ValueError("attn_mask must be of dimension 2, 3, or 4.")
            combined_mask = combined_mask & token_mask
        energy = energy.masked_fill(~combined_mask, float("-1e20"))
        attention = self.attn_dropout(F.softmax(energy, dim=-1))
        # Weighted sum over values, then merge heads back to (N, seq, embed_size).
        out = torch.einsum("nhql,nlhd->nqhd", [attention, values])
        out = out.reshape(N, seq_length, self.embed_size)
        return self.fc_out(out)


class TransformerBlockWithFFD(nn.Module):
    """Post-norm transformer block: self-attention followed by a feed-forward network."""

    def __init__(self, embed_dim, num_heads, ff_dim, dropout=0.1):
        super().__init__()
        self.mha = MultiHeadAttention(embed_dim, num_heads, dropout=dropout)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, embed_dim),
        )
        self.layernorm1 = nn.LayerNorm(embed_dim)
        self.layernorm2 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, attn_mask=None):
        # Attention sub-layer with residual connection and post-layer-norm.
        attn_out = self.mha(x, attn_mask=attn_mask)
        x = self.layernorm1(x + self.dropout(attn_out))
        # Feed-forward sub-layer with residual connection and post-layer-norm.
        ff_out = self.ff(x)
        x = self.layernorm2(x + self.dropout(ff_out))
        return x


class Gate(nn.Module):
    """Soft gate that produces per-token mixture weights over the experts."""

    def __init__(self, embed_dim, num_experts=2):
        super().__init__()
        self.gate_proj = nn.Linear(embed_dim, num_experts)

    def forward(self, x):
        gate_logits = self.gate_proj(x)
        # Weights sum to 1 over the expert dimension.
        return F.softmax(gate_logits, dim=-1)


class TierBlock(nn.Module):
    """Projects the hidden size down, then applies a transformer block at the new width."""

    def __init__(self, in_dim, out_dim, num_heads, ff_dim, dropout=0.1):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim, bias=False)
        self.transformer_block = TransformerBlockWithFFD(
            embed_dim=out_dim,
            num_heads=num_heads,
            ff_dim=ff_dim,
            dropout=dropout,
        )

    def forward(self, x, attn_mask=None):
        x = self.linear(x)
        return self.transformer_block(x, attn_mask=attn_mask)


class MTMT(PreTrainedModel):
    config_class = MTMTConfig

    def __init__(self, config):
        super().__init__(config)
        # Note: config.embed_dim must be divisible by 4, and the half/quarter widths
        # must remain divisible by config.num_heads for the tiers below to be valid.
        self.embedding = nn.Embedding(config.vocab_size, config.embed_dim)
        self.initial_block = TransformerBlockWithFFD(
            embed_dim=config.embed_dim,
            num_heads=config.num_heads,
            ff_dim=config.ff_dim,
            dropout=config.dropout,
        )
        # Tier 1: gate over two half-width expert branches.
        self.gate1 = Gate(config.embed_dim, num_experts=2)
        self.tier1_left = TierBlock(config.embed_dim, config.embed_dim // 2, config.num_heads, config.ff_dim // 2, config.dropout)
        self.tier1_right = TierBlock(config.embed_dim, config.embed_dim // 2, config.num_heads, config.ff_dim // 2, config.dropout)
        # Tier 2: each half-width branch gates over two quarter-width experts.
        self.gate2_left = Gate(config.embed_dim // 2, num_experts=2)
        self.tier2_left1 = TierBlock(config.embed_dim // 2, config.embed_dim // 4, config.num_heads, config.ff_dim // 4, config.dropout)
        self.tier2_left2 = TierBlock(config.embed_dim // 2, config.embed_dim // 4, config.num_heads, config.ff_dim // 4, config.dropout)
        self.gate2_right = Gate(config.embed_dim // 2, num_experts=2)
        self.tier2_right1 = TierBlock(config.embed_dim // 2, config.embed_dim // 4, config.num_heads, config.ff_dim // 4, config.dropout)
        self.tier2_right2 = TierBlock(config.embed_dim // 2, config.embed_dim // 4, config.num_heads, config.ff_dim // 4, config.dropout)
        # The four quarter-width outputs are concatenated back to embed_dim before the LM head.
        self.concat_mlp = nn.Sequential(
            nn.Linear(config.embed_dim, config.embed_dim),
            nn.ReLU(),
        )
        self.output_projection = nn.Linear(config.embed_dim, config.vocab_size, bias=False)
        # Initialize weights (optional; custom initialization can be added if needed).
        self.post_init()

    def forward(self, input_ids, attn_mask=None, **kwargs):
        # Also accept the standard Hugging Face keyword `attention_mask` as a fallback.
        if attn_mask is None:
            attn_mask = kwargs.get("attention_mask")
        x = self.embedding(input_ids)
        x = self.initial_block(x, attn_mask=attn_mask)
        # Tier 1: route the shared representation through two half-width experts.
        gate_w1 = self.gate1(x)
        left_out = self.tier1_left(x, attn_mask=attn_mask)
        right_out = self.tier1_right(x, attn_mask=attn_mask)
        left_weighted = left_out * gate_w1[..., 0].unsqueeze(-1)
        right_weighted = right_out * gate_w1[..., 1].unsqueeze(-1)
        # Tier 2, left branch: two quarter-width experts.
        gate_w2_left = self.gate2_left(left_weighted)
        left1_out = self.tier2_left1(left_weighted, attn_mask=attn_mask)
        left2_out = self.tier2_left2(left_weighted, attn_mask=attn_mask)
        left1_w = left1_out * gate_w2_left[..., 0].unsqueeze(-1)
        left2_w = left2_out * gate_w2_left[..., 1].unsqueeze(-1)
        left_concat = torch.cat([left1_w, left2_w], dim=-1)
        # Tier 2, right branch: two quarter-width experts.
        gate_w2_right = self.gate2_right(right_weighted)
        right1_out = self.tier2_right1(right_weighted, attn_mask=attn_mask)
        right2_out = self.tier2_right2(right_weighted, attn_mask=attn_mask)
        right1_w = right1_out * gate_w2_right[..., 0].unsqueeze(-1)
        right2_w = right2_out * gate_w2_right[..., 1].unsqueeze(-1)
        right_concat = torch.cat([right1_w, right2_w], dim=-1)
        # Merge both branches back to embed_dim and project to the vocabulary.
        x_cat = torch.cat([left_concat, right_concat], dim=-1)
        x_mlp = self.concat_mlp(x_cat)
        logits = self.output_projection(x_mlp)
        return logits

    def generate_text(self, tokenizer, prompt, max_length=50, temperature=1.0, top_k=40, device="cpu"):
        if max_length > 250:
            return "max_length must be 250 or less"
        self.eval()
        self.to(device)
        encoding = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
        input_ids = encoding["input_ids"].to(device)
        for _ in range(max_length):
            with torch.no_grad():
                # Causal mask for the current sequence length (the attention layer also
                # enforces causality internally, so this mainly documents intent).
                seq_len = input_ids.size(1)
                mask = torch.tril(torch.ones((seq_len, seq_len), device=device)).unsqueeze(0)
                outputs = self(input_ids, attn_mask=mask)
                # Sample the next token from the temperature-scaled top-k distribution.
                logits = outputs[:, -1, :] / temperature
                top_k_values, top_k_indices = torch.topk(logits, top_k)
                probabilities = F.softmax(top_k_values, dim=-1)
                sample_idx = torch.multinomial(probabilities, num_samples=1)
                next_token = top_k_indices.gather(dim=-1, index=sample_idx).squeeze(-1)
                input_ids = torch.cat([input_ids, next_token.unsqueeze(1)], dim=1)
                if tokenizer.eos_token_id is not None and next_token.item() == tokenizer.eos_token_id:
                    break
        return tokenizer.decode(input_ids[0], skip_special_tokens=True)

# Register the model so it can be loaded through the Auto classes (with trust_remote_code=True).
MTMT.register_for_auto_class("AutoModel")
MTMT.register_for_auto_class("AutoModelForCausalLM")