import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel

from .configuration_mtmt import MTMTConfig


class MultiHeadAttention(nn.Module):
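    """Multi-head self-attention with causal masking.

    Projects the input into per-head queries, keys, and values, combines an
    optional padding mask with a lower-triangular causal mask, and returns the
    re-projected attention output of shape (batch, seq_len, emb_dim).
    """
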
    def __init__(self, emb_dim=512, num_heads=8, dropout=0.1):
        super().__init__()
        self.embed_size = emb_dim
        self.heads = num_heads
        self.head_dim = emb_dim // num_heads

        assert self.head_dim * num_heads == emb_dim, "Embed size must be divisible by heads"

        self.keys = nn.Linear(emb_dim, emb_dim, bias=False)
        self.queries = nn.Linear(emb_dim, emb_dim, bias=False)
        self.values = nn.Linear(emb_dim, emb_dim, bias=False)
        self.fc_out = nn.Linear(emb_dim, emb_dim)
        self.dropout = nn.Dropout(dropout)  # applied to the attention weights

    def forward(self, x, attn_mask=None):
        N, seq_length, _ = x.shape

        values = self.values(x)
        keys = self.keys(x)
        queries = self.queries(x)

        # Split the embedding dimension into (num_heads, head_dim).
        values = values.view(N, seq_length, self.heads, self.head_dim)
        keys = keys.view(N, seq_length, self.heads, self.head_dim)
        queries = queries.view(N, seq_length, self.heads, self.head_dim)

        # Raw attention scores: (N, heads, query_len, key_len).
        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])

        # Causal mask: position q may only attend to positions k <= q.
        causal_mask = torch.tril(torch.ones(seq_length, seq_length, device=x.device, dtype=torch.bool))
        combined_mask = causal_mask.unsqueeze(0).unsqueeze(0)

        if attn_mask is not None:
            attn_mask = attn_mask.bool() if attn_mask.dtype != torch.bool else attn_mask
            if attn_mask.dim() == 2:
                # (N, key_len) padding mask -> broadcast over heads and query positions.
                token_mask = attn_mask.unsqueeze(1).unsqueeze(2).expand(N, self.heads, seq_length, -1)
            elif attn_mask.dim() == 3:
                # (N, query_len, key_len) -> broadcast over heads.
                token_mask = attn_mask.unsqueeze(1).expand(N, self.heads, -1, -1)
            elif attn_mask.dim() == 4:
                token_mask = attn_mask
            else:
                raise ValueError("attn_mask must be of dimension 2, 3, or 4.")
            combined_mask = combined_mask & token_mask

        energy = energy.masked_fill(~combined_mask, float("-1e20"))

        attention = F.softmax(energy / (self.head_dim ** 0.5), dim=-1)
        attention = self.dropout(attention)
        out = torch.einsum("nhql,nlhd->nqhd", [attention, values])
        out = out.reshape(N, seq_length, self.embed_size)
        out = self.fc_out(out)
        return out


class TransformerBlockWithFFD(nn.Module):
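    """Post-norm transformer block: self-attention and a feed-forward network, each followed by a residual connection and LayerNorm."""
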
    def __init__(self, embed_dim, num_heads, ff_dim, dropout=0.1):
        super().__init__()
        self.mha = MultiHeadAttention(embed_dim, num_heads, dropout=dropout)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, embed_dim),
        )
        self.layernorm1 = nn.LayerNorm(embed_dim)
        self.layernorm2 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, attn_mask=None):
        attn_out = self.mha(x, attn_mask=attn_mask)
        x = self.layernorm1(x + self.dropout(attn_out))
        ff_out = self.ff(x)
        x = self.layernorm2(x + self.dropout(ff_out))
        return x


class Gate(nn.Module):
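    """Token-wise softmax gate that produces mixing weights over `num_experts` branches."""
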
    def __init__(self, embed_dim, num_experts=2):
        super().__init__()
        self.gate_proj = nn.Linear(embed_dim, num_experts)

    def forward(self, x):
        gate_logits = self.gate_proj(x)
        gate_weights = F.softmax(gate_logits, dim=-1)
        return gate_weights


class TierBlock(nn.Module):
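    """Projects the hidden state from `in_dim` to `out_dim`, then runs a transformer block at the reduced width."""
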
    def __init__(self, in_dim, out_dim, num_heads, ff_dim, dropout=0.1):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim, bias=False)
        self.transformer_block = TransformerBlockWithFFD(
            embed_dim=out_dim,
            num_heads=num_heads,
            ff_dim=ff_dim,
            dropout=dropout,
        )

    def forward(self, x, attn_mask=None):
        x = self.linear(x)
        x = self.transformer_block(x, attn_mask=attn_mask)
        return x


class MTMT(PreTrainedModel):
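    """Gated multi-tier causal language model.

    A shared full-width transformer block feeds two tiers of progressively
    narrower TierBlocks. At each split a Gate produces soft mixing weights,
    the branch outputs are scaled by those weights and concatenated, and the
    recombined representation is projected to vocabulary logits.
    """
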
    config_class = MTMTConfig

    def __init__(self, config):
        super().__init__(config)
        self.embedding = nn.Embedding(config.vocab_size, config.embed_dim)

        # Shared trunk at full width.
        self.initial_block = TransformerBlockWithFFD(
            embed_dim=config.embed_dim,
            num_heads=config.num_heads,
            ff_dim=config.ff_dim,
            dropout=config.dropout,
        )

        # Tier 1: split into two branches at half width.
        self.gate1 = Gate(config.embed_dim, num_experts=2)
        self.tier1_left = TierBlock(config.embed_dim, config.embed_dim // 2, config.num_heads, config.ff_dim // 2, config.dropout)
        self.tier1_right = TierBlock(config.embed_dim, config.embed_dim // 2, config.num_heads, config.ff_dim // 2, config.dropout)

        # Tier 2: each branch splits again at quarter width.
        self.gate2_left = Gate(config.embed_dim // 2, num_experts=2)
        self.tier2_left1 = TierBlock(config.embed_dim // 2, config.embed_dim // 4, config.num_heads, config.ff_dim // 4, config.dropout)
        self.tier2_left2 = TierBlock(config.embed_dim // 2, config.embed_dim // 4, config.num_heads, config.ff_dim // 4, config.dropout)
        self.gate2_right = Gate(config.embed_dim // 2, num_experts=2)
        self.tier2_right1 = TierBlock(config.embed_dim // 2, config.embed_dim // 4, config.num_heads, config.ff_dim // 4, config.dropout)
        self.tier2_right2 = TierBlock(config.embed_dim // 2, config.embed_dim // 4, config.num_heads, config.ff_dim // 4, config.dropout)

        # The four quarter-width outputs concatenate back to embed_dim before the vocabulary projection.
        self.concat_mlp = nn.Sequential(
            nn.Linear(config.embed_dim, config.embed_dim),
            nn.ReLU(),
        )
        self.output_projection = nn.Linear(config.embed_dim, config.vocab_size, bias=False)

        self.post_init()

    def forward(self, input_ids, attn_mask=None, **kwargs):
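        """Run the tiered forward pass and return next-token logits of shape (batch, seq_len, vocab_size)."""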
        # Accept the standard Hugging Face keyword as an alias for attn_mask.
        if attn_mask is None:
            attn_mask = kwargs.get("attention_mask")

        x = self.embedding(input_ids)
        x = self.initial_block(x, attn_mask=attn_mask)

        # Tier 1: mix two half-width branches with soft gate weights.
        gate_w1 = self.gate1(x)
        left_out = self.tier1_left(x, attn_mask=attn_mask)
        right_out = self.tier1_right(x, attn_mask=attn_mask)
        left_weighted = left_out * gate_w1[..., 0].unsqueeze(-1)
        right_weighted = right_out * gate_w1[..., 1].unsqueeze(-1)

        # Tier 2, left branch.
        gate_w2_left = self.gate2_left(left_weighted)
        left1_out = self.tier2_left1(left_weighted, attn_mask=attn_mask)
        left2_out = self.tier2_left2(left_weighted, attn_mask=attn_mask)
        left1_w = left1_out * gate_w2_left[..., 0].unsqueeze(-1)
        left2_w = left2_out * gate_w2_left[..., 1].unsqueeze(-1)
        left_concat = torch.cat([left1_w, left2_w], dim=-1)

        # Tier 2, right branch.
        gate_w2_right = self.gate2_right(right_weighted)
        right1_out = self.tier2_right1(right_weighted, attn_mask=attn_mask)
        right2_out = self.tier2_right2(right_weighted, attn_mask=attn_mask)
        right1_w = right1_out * gate_w2_right[..., 0].unsqueeze(-1)
        right2_w = right2_out * gate_w2_right[..., 1].unsqueeze(-1)
        right_concat = torch.cat([right1_w, right2_w], dim=-1)

        # Recombine both branches and project to vocabulary logits.
        x_cat = torch.cat([left_concat, right_concat], dim=-1)
        x_mlp = self.concat_mlp(x_cat)
        logits = self.output_projection(x_mlp)
        return logits

    def generate_text(self, tokenizer, prompt, max_length=50, temperature=1.0, top_k=40, device="cpu"):
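        """Autoregressively sample up to `max_length` tokens from `prompt` using temperature-scaled top-k sampling."""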
        if max_length > 250:
            raise ValueError("max_length should be at most 250")

        self.eval()
        self.to(device)
        encoding = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
        input_ids = encoding["input_ids"].to(device)

        with torch.no_grad():
            for _ in range(max_length):
                seq_len = input_ids.size(1)
                # Explicit causal mask for the current sequence length.
                mask = torch.tril(torch.ones((seq_len, seq_len), device=device)).unsqueeze(0)
                logits = self(input_ids, mask)[:, -1, :] / temperature

                # Top-k sampling over the temperature-scaled logits.
                top_k_values, top_k_indices = torch.topk(logits, top_k)
                probabilities = F.softmax(top_k_values, dim=-1)
                sample_idx = torch.multinomial(probabilities, num_samples=1)
                next_token = top_k_indices.gather(dim=-1, index=sample_idx).squeeze(-1)

                input_ids = torch.cat([input_ids, next_token.unsqueeze(1)], dim=1)
                if next_token.item() == tokenizer.eos_token_id:
                    break

        return tokenizer.decode(input_ids[0], skip_special_tokens=True)


MTMT.register_for_auto_class("AutoModel") |
|
MTMT.register_for_auto_class("AutoModelForCausalLM") |
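

# Usage sketch (not part of the module API): a minimal smoke test, assuming this file
# lives in a package next to configuration_mtmt.py and that MTMTConfig accepts
# vocab_size, embed_dim, num_heads, ff_dim, and dropout as keyword arguments.
# The "gpt2" tokenizer is only a placeholder; use whichever tokenizer the model
# was actually trained with.
#
#   from transformers import AutoTokenizer
#
#   tokenizer = AutoTokenizer.from_pretrained("gpt2")      # placeholder tokenizer
#   tokenizer.pad_token = tokenizer.eos_token              # generate_text requests padding
#   config = MTMTConfig(vocab_size=len(tokenizer), embed_dim=512,
#                       num_heads=8, ff_dim=2048, dropout=0.1)
#   model = MTMT(config)
#
#   input_ids = torch.randint(0, config.vocab_size, (1, 16))
#   logits = model(input_ids)                              # (1, 16, vocab_size)
#   print(model.generate_text(tokenizer, "Hello", max_length=20))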