import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForSeq2SeqLM, AutoModelForMaskedLM
from torch.nn import TransformerDecoder, TransformerDecoderLayer
import math

class MultiHeadAttention(nn.Module):
    def __init__(self, n_heads, d_model, dropout=0.1):
        super(MultiHeadAttention, self).__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.n_heads = n_heads
        self.d_model = d_model
        self.d_k = d_model // n_heads  # per-head dimension
        self.dropout = nn.Dropout(p=dropout)
        # Separate projections for query, key and value
        self.linear_q = nn.Linear(d_model, d_model)
        self.linear_k = nn.Linear(d_model, d_model)
        self.linear_v = nn.Linear(d_model, d_model)
        self.norm = nn.LayerNorm(d_model)
        self.fc = nn.Linear(d_model, d_model)
    def forward(self, query, key, value):
        batch_size = query.size(0)
        # Linear projections of query, key and value
        q = self.linear_q(query)
        k = self.linear_k(key)
        v = self.linear_v(value)
        # Split into heads: (batch, seq_len, d_model) -> (batch, n_heads, seq_len, d_k)
        q = q.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        k = k.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        v = v.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        # Scaled dot-product attention scores
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        # Softmax over the key positions, followed by dropout
        weights = self.dropout(F.softmax(scores, dim=-1))
        # Weighted sum of the value vectors
        outputs = torch.matmul(weights, v)
        # Recombine heads: (batch, n_heads, seq_len, d_k) -> (batch, seq_len, d_model)
        outputs = outputs.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        # Output projection
        outputs = self.fc(outputs)
        # Layer normalization
        return self.norm(outputs)
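
# Minimal smoke test for MultiHeadAttention (not part of the original file); the helper
# name and the tensor shapes (batch=2, seq_len=16, d_model=768) are illustrative
# assumptions only, chosen to match the 768-dim setup used further below.
def _demo_multihead_attention():
    mha = MultiHeadAttention(n_heads=8, d_model=768)
    x = torch.randn(2, 16, 768)   # (batch, seq_len, d_model)
    out = mha(x, x, x)            # self-attention: query = key = value
    assert out.shape == (2, 16, 768)
    return out
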
class SimplifiedTransformerDecoderLayer(nn.Module):
    def __init__(self, d_model, output_dim, vocab_size, nhead, dim_feedforward, dropout=0.1):
        super(SimplifiedTransformerDecoderLayer, self).__init__()
        # Shared attention module, used for both self- and cross-attention in forward()
        self.self_attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=nhead, dropout=dropout,
                                               kdim=d_model, vdim=d_model, batch_first=True)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.linear_vocab = nn.Linear(d_model, 1)  # defined but not used in forward()
        self.norm1 = nn.LayerNorm(d_model)
        self.relu = nn.ReLU()
    def forward(self, tgt, tar):
        # Self-attention over the query stream `tar`
        tar1 = self.self_attn(query=tar, key=tar, value=tar)[0]
        tar = tar + tar1
        tar = self.norm1(tar)
        # Cross-attention: the query stream attends over the encoder output `tgt`
        # (the attention module and LayerNorm are shared with the self-attention step)
        tgt2 = self.self_attn(query=tar, key=tgt, value=tgt)[0]
        tgt3 = tar + self.dropout(tgt2)
        tgt3 = self.norm1(tgt3)
        # Feed-forward block: linear + ReLU + dropout + linear, with a residual connection
        tgt2 = self.linear2(self.dropout(self.relu(self.linear1(tgt3))))
        tar = tar + self.dropout(tgt2)
        return tar
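
# Minimal smoke test for SimplifiedTransformerDecoderLayer (not part of the original
# file); the helper name and the batch size / sequence lengths are assumptions chosen
# to mirror how the layer is called from BirdModel_Attention_lstm below.
def _demo_decoder_layer():
    layer = SimplifiedTransformerDecoderLayer(d_model=768, output_dim=1, vocab_size=1,
                                              nhead=8, dim_feedforward=1024)
    memory = torch.randn(2, 512, 768)    # encoder output (batch, src_len, d_model)
    queries = torch.zeros(2, 256, 768)   # zero-initialized query stream (batch, tgt_len, d_model)
    out = layer(memory, queries)         # forward(tgt=memory, tar=queries)
    assert out.shape == (2, 256, 768)
    return out
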
class BirdModel_Attention_lstm(nn.Module):
    def __init__(self, model, key):
        super(BirdModel_Attention_lstm, self).__init__()
        self.model = model
        self.decoder = SimplifiedTransformerDecoderLayer(d_model=768, output_dim=1, vocab_size=1, nhead=8,
                                                         dim_feedforward=1024)
        self.linear_vocab = nn.Linear(1024, 2)  # defined but not used in forward()
        self.relu = nn.ReLU()
        self.lstm = nn.LSTM(768, 128, 5,
                            bidirectional=True, batch_first=True, dropout=0.1)
        self.maxpool = nn.MaxPool1d(32)  # defined but not used in forward()
        self.fc = nn.Linear(128 * 2 + 768, 2)
        if key:
            # Freeze the whole pretrained backbone; only the decoder, LSTM and head are trained
            print("begin-------Training from scratch !!!")
            for param in self.model.parameters():
                param.requires_grad = False
        else:
            # Fine-tune the backbone, but freeze its masked-LM ("cls") head
            print("begin-------fine-tune !!!")
            for name, param in reversed(list(model.named_parameters())):
                if "cls" in name:
                    param.requires_grad = False
    def forward(self, input_ids, masks):
        # Contextual token embeddings from the pretrained BERT encoder
        outputs = self.model.bert(input_ids, attention_mask=masks).last_hidden_state
        # Zero-initialized query tensor with 256 positions, created on the encoder's device
        tar = torch.zeros(outputs.size(0), 256, 768, device=outputs.device)
        # Apply the (weight-shared) decoder layer four times
        tar = self.decoder(outputs, tar)
        tar = self.decoder(outputs, tar)
        tar = self.decoder(outputs, tar)
        tar = self.decoder(outputs, tar)
        # Bidirectional LSTM over the decoder output
        out, _ = self.lstm(tar)
        # Concatenate decoder features (768) with LSTM features (2 * 128)
        out = torch.cat((tar, out), 2)
        out = self.relu(out)
        # Max-pool over the sequence dimension
        out = out.permute(0, 2, 1)
        pool = nn.MaxPool1d(out.size(-1))
        out = pool(out).squeeze(-1)
        # Final binary classification head
        out = self.fc(out)
        return out
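
# Illustrative end-to-end sketch for BirdModel_Attention_lstm (not part of the original
# file). It assumes the wrapped model is a BERT-style masked-LM checkpoint exposing a
# `.bert` submodule; the "bert-base-chinese" checkpoint name, the batch size and the
# sequence length of 256 are assumptions made only for this example.
def _demo_bird_model():
    backbone = AutoModelForMaskedLM.from_pretrained("bert-base-chinese")
    model = BirdModel_Attention_lstm(backbone, key=True)  # key=True freezes the backbone
    input_ids = torch.randint(0, backbone.config.vocab_size, (2, 256))
    masks = torch.ones(2, 256, dtype=torch.long)
    logits = model(input_ids, masks)   # (batch, 2) class logits
    return logits
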