|
import torch |
|
import torch.nn as nn |
|
from transformers import PreTrainedModel |
|
from transformers import LogitsProcessorList, StoppingCriteriaList, MaxLengthCriteria, MinLengthLogitsProcessor |
|
from transformers.modeling_outputs import CausalLMOutputWithPast |
|
|
|
from .configure import RecombinationTransformerConfig |
|
|
|
class MaskedSelfAttentionLayer(nn.Module):
    """Multi-head attention that hands back only the attended values.

    Wraps an ``nn.MultiheadAttention`` created with its default options and
    discards the attention-weight tensor that module also returns.
    """

    def __init__(self, embed_dim, num_heads):
        """Create the wrapped attention module.

        Args:
            embed_dim: Embedding width given to ``nn.MultiheadAttention``.
            num_heads: Number of attention heads.
        """
        super().__init__()
        self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)

    def forward(self, q, k, v, attn_mask=None):
        """Attend ``q`` over ``k``/``v`` and return the attention output.

        Args:
            q: Query tensor.
            k: Key tensor.
            v: Value tensor.
            attn_mask: Optional mask forwarded untouched to
                ``nn.MultiheadAttention``.

        Returns:
            The first element of the wrapped module's result (the attended
            values); the attention weights are dropped.
        """
        result = self.multihead_attn(q, k, v, attn_mask=attn_mask)
        return result[0]
|
|
|
class FcLayer(nn.Module):
    """A single fully connected (affine) projection."""

    def __init__(self, input_dim, output_dim):
        """Create the projection.

        Args:
            input_dim: Size of the last dimension of the input.
            output_dim: Size of the last dimension of the output.
        """
        super().__init__()
        self.fc = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        """Return ``x`` passed through the linear layer."""
        projected = self.fc(x)
        return projected
|
|
|
class SwishGLU(nn.Module):
    """Swish-gated linear unit: ``swish(fc1(x)) * fc2(x)``.

    Both projections keep the input width. The gate branch uses the swish
    (SiLU) activation ``g * sigmoid(g)``; gating with a bare sigmoid would
    make this an ordinary GLU rather than the Swish variant the class is
    named after.
    """

    def __init__(self, input_dim):
        """Create the gate and value projections.

        Args:
            input_dim: Size of the last dimension of the input and output.
        """
        super().__init__()
        self.fc1 = nn.Linear(input_dim, input_dim)  # gate branch
        self.fc2 = nn.Linear(input_dim, input_dim)  # value branch

    def forward(self, x):
        """Return the swish-gated product of the two projections of ``x``."""
        gate = self.fc1(x)
        # swish(g) = g * sigmoid(g)
        return gate * torch.sigmoid(gate) * self.fc2(x)
|
|
|
class SpecialLayerF(nn.Module):
    """Combine two tensors through an element-wise product and a gated projection.

    The two inputs are multiplied element-wise; that product is sent through
    a linear projection and through a ``SwishGLU`` gate, and the two results
    are multiplied together.
    """

    def __init__(self, input_dim):
        """Create the projection and the gate, both of width ``input_dim``."""
        super().__init__()
        self.proj_up = nn.Linear(input_dim, input_dim)
        self.proj_gate = SwishGLU(input_dim)

    def forward(self, o2, o3):
        """Return ``proj_up(o2 * o3) * proj_gate(o2 * o3)``.

        Args:
            o2: First tensor to combine.
            o3: Second tensor to combine; multiplied element-wise with ``o2``.
        """
        mixed = o2 * o3  # element-wise product, shared by both branches
        projected = self.proj_up(mixed)
        gated = self.proj_gate(mixed)
        return projected * gated
|
|
|
class RMSNorm(nn.Module):
    """Root-mean-square layer normalisation over the last dimension.

    Each vector along the last axis is divided by
    ``sqrt(mean(x ** 2) + eps)`` and then multiplied by a learnable
    per-feature ``scale``. Dividing by the plain L2 norm instead would shrink
    the output by an extra factor of ``sqrt(embed_dim)``, which is not RMS
    normalisation.
    """

    def __init__(self, embed_dim, eps=1e-8):
        """Create the normalisation layer.

        Args:
            embed_dim: Size of the normalised (last) dimension.
            eps: Small constant added to the mean square for numerical
                stability.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(embed_dim))

    def forward(self, x):
        """Return ``x`` rescaled to unit root mean square, times ``scale``."""
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        # eps sits inside the square root so an all-zero vector stays finite.
        normalised = x * torch.rsqrt(mean_square + self.eps)
        return self.scale * normalised
|
|
|
class MLP(nn.Module):
    """Gated feed-forward block: widen, gate, multiply, project back."""

    def __init__(self, input_dim, hidden_dim):
        """Create the projections.

        Args:
            input_dim: Width of the block's input and output.
            hidden_dim: Width of the intermediate representation.
        """
        super().__init__()
        self.up_proj = nn.Linear(input_dim, hidden_dim)
        self.gate_proj = nn.Linear(input_dim, hidden_dim)
        self.act = SwishGLU(hidden_dim)
        self.down_proj = nn.Linear(hidden_dim, input_dim)

    def forward(self, x):
        """Return ``down_proj(act(gate_proj(x)) * up_proj(x))``."""
        gate_branch = self.act(self.gate_proj(x))
        value_branch = self.up_proj(x)
        return self.down_proj(gate_branch * value_branch)
|
|
|
class RecombinationTransformerLayer(nn.Module):
    """One transformer block built from three chained attention passes.

    The first pass attends over projections of the input. The second pass
    reuses the first pass's queries with keys/values projected from the first
    pass's output. The third pass uses queries projected from the first
    pass's output with the first pass's keys/values. The three outputs are
    merged by ``SpecialLayerF`` and an element-wise product, added back to
    the input, normalised, and followed by an MLP with its own residual
    connection and normalisation.
    """

    def __init__(self, embed_dim, num_heads):
        """Create the sub-modules.

        Args:
            embed_dim: Width of the hidden representation.
            num_heads: Number of heads used by every attention pass.
        """
        super().__init__()
        self.num_heads = num_heads

        # First attention pass and its input projections.
        self.self_attention_1 = MaskedSelfAttentionLayer(embed_dim, num_heads)
        self.fc_q = FcLayer(embed_dim, embed_dim)
        self.fc_k = FcLayer(embed_dim, embed_dim)
        self.fc_v = FcLayer(embed_dim, embed_dim)

        # Second attention pass; fc_qc is consumed by the third pass.
        self.self_attention_2 = MaskedSelfAttentionLayer(embed_dim, num_heads)
        self.fc_qc = FcLayer(embed_dim, embed_dim)
        self.fc_kb = FcLayer(embed_dim, embed_dim)
        self.fc_vb = FcLayer(embed_dim, embed_dim)

        # Third attention pass.
        self.self_attention_3 = MaskedSelfAttentionLayer(embed_dim, num_heads)

        # Merges the second and third attention outputs.
        self.special_layer_f = SpecialLayerF(embed_dim)

        # Feed-forward part and the two normalisations.
        self.mlp = MLP(embed_dim, embed_dim * 4)
        self.rms_norm1 = RMSNorm(embed_dim)
        self.rms_norm2 = RMSNorm(embed_dim)

    def forward(self, x, attn_mask=None):
        """Run the block on ``x``.

        Args:
            x: Three-dimensional input whose first two sizes are read as
                batch size and sequence length.
            attn_mask: Optional mask with one ``(seq, seq)`` slice per batch
                element. It is replicated once per head and handed to every
                attention pass, so it follows ``nn.MultiheadAttention``'s
                convention (per the PyTorch documentation: boolean ``True``
                entries are excluded from attention, floating-point masks
                are added to the attention scores).

        Returns:
            A tensor laid out like ``x``.
        """
        batch_size, seq_length, _ = x.size()

        if attn_mask is not None:
            # One copy of each batch element's mask per attention head.
            per_head = attn_mask.unsqueeze(1).repeat(1, self.num_heads, 1, 1)
            attn_mask = per_head.view(
                batch_size * self.num_heads, seq_length, seq_length
            )

        # The attention modules are fed with the first two axes swapped, so
        # every projection is transposed on the way in and every attention
        # output is transposed back.

        # Pass 1: attention over projections of the input.
        query_in = self.fc_q(x).transpose(0, 1)
        key_in = self.fc_k(x).transpose(0, 1)
        value_in = self.fc_v(x).transpose(0, 1)
        first = self.self_attention_1(
            query_in, key_in, value_in, attn_mask=attn_mask
        ).transpose(0, 1)

        # Pass 2: original queries against keys/values derived from pass 1.
        key_b = self.fc_kb(first).transpose(0, 1)
        value_b = self.fc_vb(first).transpose(0, 1)
        second = self.self_attention_2(
            query_in, key_b, value_b, attn_mask=attn_mask
        ).transpose(0, 1)

        # Pass 3: queries derived from pass 1 against the original keys/values.
        query_c = self.fc_qc(first).transpose(0, 1)
        third = self.self_attention_3(
            query_c, key_in, value_in, attn_mask=attn_mask
        ).transpose(0, 1)

        # Merge the three outputs, then residual connection + normalisation.
        recombined = self.special_layer_f(second, third) * first
        x = self.rms_norm1(x + recombined)

        # Feed-forward part with its own residual connection + normalisation.
        x = self.rms_norm2(x + self.mlp(x))

        return x
|
|
|
|
|
|
|
class RecombinationTransformerForCausalLM(PreTrainedModel):
    """Causal language model made of stacked ``RecombinationTransformerLayer`` blocks.

    Token ids are embedded, passed through every layer under a causal
    (and padding-aware) attention mask, normalised, and projected to
    vocabulary logits.

    The model keeps no key/value cache: every forward pass recomputes the
    whole sequence, and ``past_key_values`` is always reported as ``None``.
    """

    config_class = RecombinationTransformerConfig

    def __init__(self, config):
        """Build the embedding, the layer stack, the final norm and the LM head.

        Args:
            config: Configuration providing ``vocab_size``, ``embed_dim``,
                ``num_heads`` and ``num_layers``.
        """
        super().__init__(config)
        self.embed_tokens = nn.Embedding(config.vocab_size, config.embed_dim)
        self.layers = nn.ModuleList([
            RecombinationTransformerLayer(config.embed_dim, config.num_heads)
            for _ in range(config.num_layers)
        ])
        self.final_rms_norm = RMSNorm(config.embed_dim)
        self.lm_head = nn.Linear(config.embed_dim, config.vocab_size, bias=False)

    def forward(self, input_ids, attention_mask=None, past_key_values=None, return_dict=None, **kwargs):
        """Compute next-token logits for every position of ``input_ids``.

        Args:
            input_ids: Two-dimensional tensor of token ids (batch, sequence).
            attention_mask: Optional tensor shaped like ``input_ids`` where
                zero marks padding. Padded positions are hidden from every
                other position. ``None`` means no padding.
            past_key_values: Accepted for interface compatibility and
                ignored; the layers have no cache to resume from.
            return_dict: When truthy a ``CausalLMOutputWithPast`` is
                returned, otherwise a ``(logits, past_key_values)`` tuple.
            **kwargs: Ignored.

        Returns:
            Logits for every position, plus ``None`` in the
            ``past_key_values`` slot.
        """
        batch_size, seq_length = input_ids.size()
        device = input_ids.device

        # nn.MultiheadAttention adds floating-point masks to the scores, so a
        # float lower-triangular matrix of ones masks nothing. Use a boolean
        # mask instead, where True means "must not attend": entry (i, j) is
        # True when key j lies in the future of query i.
        blocked = torch.triu(
            torch.ones(seq_length, seq_length, dtype=torch.bool, device=device),
            diagonal=1,
        ).unsqueeze(0).expand(batch_size, -1, -1)

        if attention_mask is not None:
            padded_keys = attention_mask.to(device) == 0
            blocked = blocked | padded_keys.unsqueeze(1)
            # Always let a position see itself; a fully blocked row would
            # turn the softmax into NaN (e.g. for left-padded batches).
            diagonal = torch.eye(seq_length, dtype=torch.bool, device=device)
            blocked = blocked & ~diagonal

        hidden = self.embed_tokens(input_ids)
        for layer in self.layers:
            hidden = layer(hidden, attn_mask=blocked)

        hidden = self.final_rms_norm(hidden)
        logits = self.lm_head(hidden)

        # No cache exists, so none is advertised. Returning hidden states
        # here would invite the generation loop to treat them as one.
        if not return_dict:
            return (logits, None)

        return CausalLMOutputWithPast(logits=logits, past_key_values=None)

    def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **kwargs):
        """Assemble the keyword arguments for one generation step.

        The complete ``input_ids`` are always passed on. ``forward`` has no
        cache, so trimming the input to its last token would make the model
        predict without any context.

        Args:
            input_ids: All token ids generated so far.
            past: Ignored; accepted for compatibility with the generation loop.
            attention_mask: Optional padding mask; defaults to all ones.
            **kwargs: Ignored.

        Returns:
            A dict with ``input_ids``, ``attention_mask`` and
            ``past_key_values`` (always ``None``).
        """
        if attention_mask is None:
            attention_mask = torch.ones(input_ids.shape, device=input_ids.device)

        return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": None}

    def generate(self, input_ids, attention_mask=None, max_length=512, min_length=None, num_return_sequences=1):
        """Generate sequences through the parent class's ``generate``.

        Args:
            input_ids: Prompt token ids.
            attention_mask: Optional padding mask for the prompt.
            max_length: Maximum total length passed to the parent ``generate``.
            min_length: When given, a ``MinLengthLogitsProcessor`` built with
                ``config.eos_token_id`` is added to the logits processors.
            num_return_sequences: Number of sequences requested per prompt.

        Returns:
            Whatever the parent ``generate`` returns.
        """
        logits_processor = LogitsProcessorList()

        if min_length is not None:
            logits_processor.append(
                MinLengthLogitsProcessor(min_length, eos_token_id=self.config.eos_token_id)
            )

        return super().generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=max_length,
            num_return_sequences=num_return_sequences,
            logits_processor=logits_processor,
        )
|
|