import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiHeadAttention(nn.Module):
    def __init__(self, n_heads, d_model, dropout=0.1):
        super(MultiHeadAttention, self).__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.n_heads = n_heads
        self.d_model = d_model
        self.d_k = d_model // n_heads
        self.dropout = nn.Dropout(p=dropout)
        self.linear_q = nn.Linear(d_model, d_model)
        self.linear_k = nn.Linear(d_model, d_model)
        self.linear_v = nn.Linear(d_model, d_model)
        self.norm = nn.LayerNorm(d_model)
        self.fc = nn.Linear(d_model, d_model)

    def forward(self, query, key, value):
        batch_size = query.size(0)
        # Linear projections of query, key and value
        q = self.linear_q(query)
        k = self.linear_k(key)
        v = self.linear_v(value)
        # Split into heads: (batch, seq_len, d_model) -> (batch, n_heads, seq_len, d_k)
        q = q.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        k = k.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        v = v.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        # Scaled dot-product attention scores
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        # Softmax over the key positions, with dropout on the attention weights
        weights = self.dropout(F.softmax(scores, dim=-1))
        # Weighted sum of the values
        outputs = torch.matmul(weights, v)
        # Merge the heads back: (batch, seq_len, d_model)
        outputs = outputs.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        # Final output projection
        outputs = self.fc(outputs)
        # Layer normalization
        return self.norm(outputs)


class SimplifiedTransformerDecoderLayer(nn.Module):
    def __init__(self, d_model, output_dim, vocab_size, nhead, dim_feedforward, dropout=0.1):
        super(SimplifiedTransformerDecoderLayer, self).__init__()
        # A single attention module, reused for both the self-attention and the
        # cross-attention step of this simplified decoder layer.
        self.self_attn = nn.MultiheadAttention(
            embed_dim=d_model, num_heads=nhead, kdim=d_model, vdim=d_model,
            dropout=dropout, batch_first=True,
        )
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.relu = nn.ReLU()

    def forward(self, tgt, tar):
        # Self-attention over the decoder state `tar`, with residual + layer norm
        tar1 = self.self_attn(query=tar, key=tar, value=tar)[0]
        tar = self.norm1(tar + tar1)
        # Cross-attention: the decoder state attends to the encoder output `tgt`
        tgt2 = self.self_attn(query=tar, key=tgt, value=tgt)[0]
        tgt3 = self.norm1(tar + self.dropout(tgt2))
        # Feed-forward block: linear -> ReLU -> dropout -> linear, added back onto `tar`
        tgt2 = self.linear2(self.dropout(self.relu(self.linear1(tgt3))))
        tar = tar + self.dropout(tgt2)
        return tar


class BirdModel_Attention_lstm(nn.Module):
    def __init__(self, model, key):
        super(BirdModel_Attention_lstm, self).__init__()
        self.model = model  # pretrained backbone; its `.bert` encoder is used in forward()
        self.decoder = SimplifiedTransformerDecoderLayer(
            d_model=768, output_dim=1, vocab_size=1, nhead=8, dim_feedforward=1024
        )
        self.relu = nn.ReLU()
        self.lstm = nn.LSTM(768, 128, 5, bidirectional=True, batch_first=True, dropout=0.1)
        self.fc = nn.Linear(128 * 2 + 768, 2)
        if key:
            # Freeze the whole backbone and train only the new layers
            print("begin-------Training from scratch !!!")
            for param in self.model.parameters():
                param.requires_grad = False
        else:
            # Fine-tune the backbone, freezing only its MLM ("cls") head
            print("begin-------fine train !!!")
            for name, param in model.named_parameters():
                if "cls" in name:
                    param.requires_grad = False

    def forward(self, input_ids, masks):
        # Encoder: contextual token embeddings from the BERT backbone, (batch, seq_len, 768)
        outputs = self.model.bert(input_ids, attention_mask=masks).last_hidden_state
        # Zero-initialized decoder query of fixed length 256, created on the same device/dtype
        tar = torch.zeros(outputs.size(0), 256, 768, device=outputs.device, dtype=outputs.dtype)
        # Apply the (weight-shared) decoder layer four times
        for _ in range(4):
            tar = self.decoder(outputs, tar)
        # BiLSTM over the decoder output, concatenated with it: (batch, 256, 128*2 + 768)
        out, _ = self.lstm(tar)
        out = torch.cat((tar, out), dim=2)
        out = self.relu(out)
        # Max-pool over the sequence dimension, then classify: (batch, 2)
        out = out.permute(0, 2, 1)
        out = F.max_pool1d(out, kernel_size=out.size(-1)).squeeze(-1)
        out = self.fc(out)
        return out
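

# --- Usage sketch (not part of the original code) ---
# A minimal example of building and running the model. The checkpoint name
# "bert-base-uncased", the max_length of 256 and the example sentences are
# illustrative assumptions; the model only requires a backbone that exposes
# a `.bert` encoder producing 768-dim hidden states.
if __name__ == "__main__":
    from transformers import AutoModelForMaskedLM, AutoTokenizer

    backbone = AutoModelForMaskedLM.from_pretrained("bert-base-uncased")  # assumed checkpoint
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

    # key=True freezes the backbone so only the decoder/LSTM/classifier train
    model = BirdModel_Attention_lstm(backbone, key=True)

    batch = tokenizer(
        ["a short example sentence", "another example"],
        padding="max_length",
        truncation=True,
        max_length=256,
        return_tensors="pt",
    )
    logits = model(batch["input_ids"], batch["attention_mask"])
    print(logits.shape)  # torch.Size([2, 2]) -- two logits (classes) per example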