import torch
import torch.nn as nn
from transformers import PreTrainedModel
from transformers import LogitsProcessorList, MinLengthLogitsProcessor
from transformers.modeling_outputs import CausalLMOutputWithPast
from .configure import RecombinationTransformerConfig
class MaskedSelfAttentionLayer(nn.Module):
    """Thin wrapper around nn.MultiheadAttention that returns only the attention output."""

    def __init__(self, embed_dim, num_heads):
        super(MaskedSelfAttentionLayer, self).__init__()
        self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)

    def forward(self, q, k, v, attn_mask=None):
        # q, k and v are expected in (seq_length, batch_size, embed_dim) layout,
        # the default for nn.MultiheadAttention with batch_first=False.
        attn_output, _ = self.multihead_attn(q, k, v, attn_mask=attn_mask)
        return attn_output

class FcLayer(nn.Module):
    """Single fully connected projection layer."""

    def __init__(self, input_dim, output_dim):
        super(FcLayer, self).__init__()
        self.fc = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        return self.fc(x)

class SwishGLU(nn.Module):
    """Gated linear unit: gates fc2(x) with sigmoid(fc1(x))."""

    def __init__(self, input_dim):
        super(SwishGLU, self).__init__()
        self.fc1 = nn.Linear(input_dim, input_dim)
        self.fc2 = nn.Linear(input_dim, input_dim)

    def forward(self, x):
        return torch.sigmoid(self.fc1(x)) * self.fc2(x)

class SpecialLayerF(nn.Module):
    """Fuses two attention outputs via an element-wise product, then combines a
    linear projection with a SwishGLU-gated projection of that product."""

    def __init__(self, input_dim):
        super(SpecialLayerF, self).__init__()
        self.proj_up = nn.Linear(input_dim, input_dim)
        self.proj_gate = SwishGLU(input_dim)

    def forward(self, o2, o3):
        fused = o2 * o3  # element-wise product of the two attention outputs
        proj_up_output = self.proj_up(fused)
        proj_gate_output = self.proj_gate(fused)
        return proj_up_output * proj_gate_output

class RMSNorm(nn.Module):
    """Root-mean-square style normalization with a learned per-dimension scale.

    Note: this implementation divides by the L2 norm of the last dimension rather
    than by sqrt(mean(x**2)), so it differs from standard RMSNorm by a constant
    factor of sqrt(embed_dim) that the learned scale can absorb."""

    def __init__(self, embed_dim, eps=1e-8):
        super(RMSNorm, self).__init__()
        self.embed_dim = embed_dim
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(embed_dim))

    def forward(self, x):
        norm = x.norm(2, dim=-1, keepdim=True)
        normed = x / (norm + self.eps)
        return self.scale * normed

class MLP(nn.Module):
    """Gated feed-forward block: down_proj(SwishGLU(gate_proj(x)) * up_proj(x))."""

    def __init__(self, input_dim, hidden_dim):
        super(MLP, self).__init__()
        self.up_proj = nn.Linear(input_dim, hidden_dim)
        self.gate_proj = nn.Linear(input_dim, hidden_dim)
        self.act = SwishGLU(hidden_dim)
        self.down_proj = nn.Linear(hidden_dim, input_dim)

    def forward(self, x):
        return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))

class RecombinationTransformerLayer(nn.Module):
    """Transformer block with three interacting self-attention paths whose outputs
    are recombined by SpecialLayerF, followed by a gated MLP, each with a residual
    connection and RMS normalization."""

    def __init__(self, embed_dim, num_heads):
        super(RecombinationTransformerLayer, self).__init__()
        self.num_heads = num_heads
        # First self-attention layer
        self.self_attention_1 = MaskedSelfAttentionLayer(embed_dim, num_heads)
        self.fc_q = FcLayer(embed_dim, embed_dim)
        self.fc_k = FcLayer(embed_dim, embed_dim)
        self.fc_v = FcLayer(embed_dim, embed_dim)
        # Second self-attention layer
        self.self_attention_2 = MaskedSelfAttentionLayer(embed_dim, num_heads)
        self.fc_qc = FcLayer(embed_dim, embed_dim)
        self.fc_kb = FcLayer(embed_dim, embed_dim)
        self.fc_vb = FcLayer(embed_dim, embed_dim)
        # Third self-attention layer
        self.self_attention_3 = MaskedSelfAttentionLayer(embed_dim, num_heads)
        # Special layer F
        self.special_layer_f = SpecialLayerF(embed_dim)
        # MLP layer
        self.mlp = MLP(embed_dim, embed_dim * 4)
        self.rms_norm1 = RMSNorm(embed_dim)
        self.rms_norm2 = RMSNorm(embed_dim)
    def forward(self, x, attn_mask=None):
        # x: (batch_size, seq_length, embed_dim)
        batch_size, seq_length, _ = x.size()
        if attn_mask is not None:
            # Expand a (batch_size, seq_length, seq_length) mask to
            # (batch_size * num_heads, seq_length, seq_length), the 3D layout
            # expected by nn.MultiheadAttention.
            attn_mask = attn_mask.unsqueeze(1).repeat(1, self.num_heads, 1, 1).view(
                batch_size * self.num_heads, seq_length, seq_length
            )
        # First self-attention block: projections of the layer input.
        q1 = self.fc_q(x).transpose(0, 1)
        k1 = self.fc_k(x).transpose(0, 1)
        v1 = self.fc_v(x).transpose(0, 1)
        o1 = self.self_attention_1(q1, k1, v1, attn_mask=attn_mask).transpose(0, 1)
        # Second self-attention block: reuse q1, project keys/values from o1.
        q2 = q1
        k2 = self.fc_kb(o1).transpose(0, 1)
        v2 = self.fc_vb(o1).transpose(0, 1)
        o2 = self.self_attention_2(q2, k2, v2, attn_mask=attn_mask).transpose(0, 1)
        # Third self-attention block: project queries from o1, reuse k1/v1.
        q3 = self.fc_qc(o1).transpose(0, 1)
        k3 = k1
        v3 = v1
        o3 = self.self_attention_3(q3, k3, v3, attn_mask=attn_mask).transpose(0, 1)
        # Recombine o2 and o3 through special layer F, gated by o1.
        output_f = self.special_layer_f(o2, o3) * o1
        # Add & Norm
        x = x + output_f
        x = self.rms_norm1(x)
        # MLP block
        mlp_output = self.mlp(x)
        # Add & Norm
        x = x + mlp_output
        x = self.rms_norm2(x)
        return x

class RecombinationTransformerForCausalLM(PreTrainedModel):
    config_class = RecombinationTransformerConfig

    def __init__(self, config):
        super().__init__(config)
        self.embed_tokens = nn.Embedding(config.vocab_size, config.embed_dim)
        self.layers = nn.ModuleList([
            RecombinationTransformerLayer(config.embed_dim, config.num_heads)
            for _ in range(config.num_layers)
        ])
        self.final_rms_norm = RMSNorm(config.embed_dim)
        self.lm_head = nn.Linear(config.embed_dim, config.vocab_size, bias=False)
    def forward(self, input_ids, attention_mask=None, past_key_values=None, return_dict=None, **kwargs):
        if attention_mask is None:
            attention_mask = torch.ones(input_ids.shape, device=input_ids.device)
        batch_size, seq_length = input_ids.size()
        # Boolean causal mask: True marks positions that must NOT be attended to
        # (strictly upper triangle, i.e. future tokens), as expected by
        # nn.MultiheadAttention. The padding `attention_mask` is not folded into
        # this mask; only causal masking is applied inside the layers.
        causal_mask = torch.triu(
            torch.ones((seq_length, seq_length), dtype=torch.bool, device=input_ids.device),
            diagonal=1,
        ).unsqueeze(0).expand(batch_size, -1, -1)
        if past_key_values is None:
            past_key_values = [None] * len(self.layers)
        # Embedding
        x = self.embed_tokens(input_ids)
        new_past_key_values = []
        for i, layer in enumerate(self.layers):
            # No real key/value cache is implemented: past_key_values[i] is unused,
            # the per-layer hidden state is stored only as a placeholder, and the
            # full sequence is re-encoded on every call.
            past_key_value = past_key_values[i]
            x = layer(x, attn_mask=causal_mask)
            new_past_key_values.append(x)
        # Final normalization
        x = self.final_rms_norm(x)
        # LM head
        logits = self.lm_head(x)
        if not return_dict:
            return (logits, new_past_key_values)
        return CausalLMOutputWithPast(logits=logits, past_key_values=new_past_key_values)
    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **kwargs):
        # Since no key/value cache is used, the full sequence is fed to the model at
        # every generation step rather than only the most recent token.
        if attention_mask is None:
            attention_mask = torch.ones(input_ids.shape, device=input_ids.device)
        return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": None}
    def generate(self, input_ids, attention_mask=None, max_length=512, min_length=None, num_return_sequences=1):
        # Thin wrapper over GenerationMixin.generate that optionally enforces a
        # minimum length via MinLengthLogitsProcessor.
        logits_processor = LogitsProcessorList()
        if min_length is not None:
            logits_processor.append(
                MinLengthLogitsProcessor(min_length, eos_token_id=self.config.eos_token_id)
            )
        outputs = super().generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=max_length,
            num_return_sequences=num_return_sequences,
            logits_processor=logits_processor,
        )
        return outputs
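

# -----------------------------------------------------------------------------
# Minimal usage sketch (not part of the model). The config field names used here
# (vocab_size, embed_dim, num_heads, num_layers, eos_token_id) are the ones this
# file reads from RecombinationTransformerConfig; the concrete values, and the
# assumption that the config accepts them as keyword arguments, are hypothetical.
# Because of the relative import above, run this as a module inside its package
# (e.g. `python -m <package>.modeling`), not as a standalone script.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    config = RecombinationTransformerConfig(
        vocab_size=1000,
        embed_dim=64,
        num_heads=4,
        num_layers=2,
        eos_token_id=2,
    )
    model = RecombinationTransformerForCausalLM(config)

    # Forward pass on a random batch of token ids: logits should come back with
    # shape (batch_size, seq_length, vocab_size).
    input_ids = torch.randint(0, config.vocab_size, (2, 10))
    logits, _ = model(input_ids)
    print(logits.shape)

    # Generation from a short prompt; the full sequence is re-encoded each step
    # since no key/value cache is implemented.
    generated = model.generate(input_ids[:, :4], max_length=16)
    print(generated.shape)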