# Jinyi-Guard / model.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForSeq2SeqLM, AutoModelForMaskedLM
from torch.nn import TransformerDecoder, TransformerDecoderLayer
import math
class MultiHeadAttention(nn.Module):
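    """Standard multi-head scaled dot-product attention with an output
    projection followed by layer normalization."""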
def __init__(self, n_heads, d_model, dropout=0.1):
super(MultiHeadAttention, self).__init__()
self.n_heads = n_heads
self.d_model = d_model
self.dropout = nn.Dropout(p=dropout)
        # Per-head parameter kept from the original design; only its trailing
        # dimension (d_model // n_heads) is read, as the attention scaling factor.
        self.h = nn.Parameter(torch.empty(n_heads, d_model // n_heads), requires_grad=True)
self.linear_q = nn.Linear(d_model, d_model)
self.linear_k = nn.Linear(d_model, d_model)
self.linear_v = nn.Linear(d_model, d_model)
self.norm = nn.LayerNorm(d_model)
self.fc = nn.Linear(d_model, d_model)
    def forward(self, query, key, value):
        batch_size = query.size(0)
        d_head = self.h.size(-1)  # per-head width: d_model // n_heads
        # Linear projections of query, key, and value
        q = self.linear_q(query)
        k = self.linear_k(key)
        v = self.linear_v(value)
        # Split into heads: (batch, seq, d_model) -> (batch, n_heads, seq, d_head)
        q = q.view(batch_size, -1, self.n_heads, d_head).transpose(1, 2)
        k = k.view(batch_size, -1, self.n_heads, d_head).transpose(1, 2)
        v = v.view(batch_size, -1, self.n_heads, d_head).transpose(1, 2)
        # Scaled dot-product attention scores
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_head)
        # Attention weights
        weights = self.dropout(F.softmax(scores, dim=-1))
        # Weighted sum of the values
        outputs = torch.matmul(weights, v)
        # Merge the heads back: (batch, n_heads, seq, d_head) -> (batch, seq, d_model)
        outputs = outputs.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        # Output projection
        outputs = self.fc(outputs)
        # Layer normalization
        return self.norm(outputs)
class SimplifiedTransformerDecoderLayer(nn.Module):
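    """Decoder-style block: self-attention over the decoder state `tar`,
    cross-attention to the encoder output `tgt`, and a feed-forward sub-layer,
    with residual connections. A single nn.MultiheadAttention module and a
    single LayerNorm are shared by both attention steps."""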
def __init__(self, d_model, output_dim, vocab_size, nhead, dim_feedforward, dropout=0.1):
super(SimplifiedTransformerDecoderLayer, self).__init__()
        self.self_attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=nhead, kdim=d_model, vdim=d_model,
                                               dropout=dropout, batch_first=True)
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.linear_vocab = nn.Linear(d_model, 1)
self.norm1 = nn.LayerNorm(d_model)
self.relu = nn.ReLU()
    def forward(self, tgt, tar):
        # Self-attention over the decoder state, with a residual connection
        tar1 = self.self_attn(query=tar, key=tar, value=tar)[0]
        tar = tar + tar1
        tar = self.norm1(tar)
        # Cross-attention: the decoder state attends to the encoder output `tgt`
        tgt2 = self.self_attn(query=tar, key=tgt, value=tgt)[0]
        tgt3 = tar + self.dropout(tgt2)
        tgt3 = self.norm1(tgt3)
        # Feed-forward block: linear -> ReLU -> dropout -> linear
        tgt2 = self.linear2(self.dropout(self.relu(self.linear1(tgt3))))
        tar = tar + self.dropout(tgt2)
        return tar
class BirdModel_Attention_lstm(nn.Module):
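    """Binary classifier built on a BERT-style backbone (accessed via `model.bert`):
    the encoder output is refined by four passes through a decoder block, fed to a
    bidirectional LSTM, max-pooled over the sequence, and projected to two logits.
    `key=True` freezes the whole backbone; `key=False` freezes only parameters
    whose name contains 'cls'."""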
def __init__(self, model, key):
super(BirdModel_Attention_lstm, self).__init__()
self.model = model
self.decoder = SimplifiedTransformerDecoderLayer(d_model=768, output_dim=1, vocab_size=1, nhead=8,
dim_feedforward=1024)
self.linear_vocab = nn.Linear(1024, 2)
self.relu = nn.ReLU()
self.lstm = nn.LSTM(768, 128, 5,
bidirectional=True, batch_first=True, dropout=0.1)
self.maxpool = nn.MaxPool1d(32)
self.fc = nn.Linear(128 * 2 + 768, 2)
        if key:
            # Freeze the entire backbone; only the newly added layers are trained
            print("begin-------Training from scratch !!!")
            for param in self.model.parameters():
                param.requires_grad = False
        else:
            # Fine-tune the backbone, but keep the MLM `cls` head frozen
            print("begin-------fine train !!!")
            for name, param in reversed(list(model.named_parameters())):
                if "cls" in name:
                    param.requires_grad = False
    def forward(self, input_ids, masks):
        # Encode the input with the BERT-style backbone
        outputs = self.model.bert(input_ids, attention_mask=masks).last_hidden_state
        # Zero-initialized decoder state of fixed length 256, created on the same
        # device and dtype as the encoder output
        tar = torch.zeros(outputs.size(0), 256, 768, device=outputs.device, dtype=outputs.dtype)
        # Refine the decoder state with four passes through the decoder block
        tar = self.decoder(outputs, tar)
        tar = self.decoder(outputs, tar)
        tar = self.decoder(outputs, tar)
        tar = self.decoder(outputs, tar)
        # Bidirectional LSTM over the decoder state, concatenated with the state itself
        out, _ = self.lstm(tar)
        out = torch.cat((tar, out), 2)
        out = self.relu(out)
        # Max-pool over the sequence dimension
        out = out.permute(0, 2, 1)
        pool = nn.MaxPool1d(out.size(-1))
        out = pool(out).squeeze(-1)
        # Project the pooled features to two class logits
        out = self.fc(out)
        return out
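

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original repository): a minimal
# example of how this module is presumably driven, assuming the backbone is a
# BERT-style masked-LM checkpoint with hidden size 768 that exposes a `.bert`
# attribute, and that inputs are padded/truncated to 256 tokens to match the
# decoder query length. The checkpoint name and tokenizer are assumptions.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoTokenizer

    backbone = AutoModelForMaskedLM.from_pretrained("bert-base-uncased")  # assumed backbone
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    guard = BirdModel_Attention_lstm(backbone, key=True)  # key=True freezes the backbone
    guard.eval()

    batch = tokenizer(
        ["example input text to classify"],
        padding="max_length",
        truncation=True,
        max_length=256,
        return_tensors="pt",
    )
    with torch.no_grad():
        logits = guard(batch["input_ids"], batch["attention_mask"])
    print(logits.shape)  # (batch_size, 2): two-class logits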