import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForSeq2SeqLM, AutoModelForMaskedLM
from torch.nn import TransformerDecoder, TransformerDecoderLayer
import math


class MultiHeadAttention(nn.Module):
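    """Scaled dot-product multi-head attention built from scratch: separate Q/K/V
    projections, per-head attention, head merging, an output projection and a final
    layer normalization. Note that this class is defined here but not used by the
    decoder layer below, which relies on nn.MultiheadAttention instead."""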
    def __init__(self, n_heads, d_model, dropout=0.1):
        super(MultiHeadAttention, self).__init__()
        self.n_heads = n_heads
        self.d_model = d_model
        self.dropout = nn.Dropout(p=dropout)
        self.linear_q = nn.Linear(d_model, d_model)
        self.linear_k = nn.Linear(d_model, d_model)
        self.linear_v = nn.Linear(d_model, d_model)
        self.norm = nn.LayerNorm(d_model)
        self.fc = nn.Linear(d_model, d_model)

    def forward(self, query, key, value):
        batch_size = query.size(0)
        d_head = self.d_model // self.n_heads
        # Linear projections of query, key and value
        q = self.linear_q(query)
        k = self.linear_k(key)
        v = self.linear_v(value)
        # Split into heads: (batch, n_heads, seq_len, d_head)
        q = q.view(batch_size, -1, self.n_heads, d_head).transpose(1, 2)
        k = k.view(batch_size, -1, self.n_heads, d_head).transpose(1, 2)
        v = v.view(batch_size, -1, self.n_heads, d_head).transpose(1, 2)
        # Scaled dot-product attention scores
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_head)
        # Softmax over the key positions, then dropout
        weights = self.dropout(F.softmax(scores, dim=-1))
        # Weighted sum of the values
        outputs = torch.matmul(weights, v)
        # Merge the heads back into a single d_model-sized representation
        outputs = outputs.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        # Output projection
        outputs = self.fc(outputs)
        # Layer normalization
        return self.norm(outputs)


class SimplifiedTransformerDecoderLayer(nn.Module):
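    """A single transformer decoder layer: self-attention over the decoder state
    `tar`, cross-attention from `tar` to the encoder output `tgt`, and a
    position-wise feed-forward block, with residual connections and a shared
    layer normalization (norm1) after the two attention sub-blocks."""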
    def __init__(self, d_model, output_dim, vocab_size, nhead, dim_feedforward, dropout=0.1):
        super(SimplifiedTransformerDecoderLayer, self).__init__()
        self.self_attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=nhead, kdim=d_model, vdim=d_model,
                                               dropout=dropout, batch_first=True)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.linear_vocab = nn.Linear(d_model, 1)
        self.norm1 = nn.LayerNorm(d_model)
        self.relu = nn.ReLU()

    def forward(self, tgt, tar):
        # Self-attention over the decoder state, with residual connection and layer norm
        tar1 = self.self_attn(query=tar, key=tar, value=tar)[0]
        tar = tar + tar1
        tar = self.norm1(tar)
        # Cross-attention: the decoder state attends to the encoder output tgt
        tgt2 = self.self_attn(query=tar, key=tgt, value=tgt)[0]
        tgt3 = tar + self.dropout(tgt2)
        tgt3 = self.norm1(tgt3)

        # Feed-forward block: linear + ReLU + dropout + linear, then a residual back onto tar
        tgt2 = self.linear2(self.dropout(self.relu(self.linear1(tgt3))))
        tar = tar + self.dropout(tgt2)
        return tar


class BirdModel_Attention_lstm(nn.Module):
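    """Binary classifier built on a pretrained BERT-style masked-LM backbone.

    The encoder's last hidden states are refined by four passes through a single
    SimplifiedTransformerDecoderLayer (starting from a zero decoder state), fed to a
    bidirectional LSTM, concatenated with the LSTM input, max-pooled over the
    sequence dimension and classified by a linear head. With key=True the whole
    pretrained backbone is frozen; with key=False it is fine-tuned except for
    parameters whose name contains "cls".
    """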
    def __init__(self, model, key):
        super(BirdModel_Attention_lstm, self).__init__()
        self.model = model
        self.decoder = SimplifiedTransformerDecoderLayer(d_model=768, output_dim=1, vocab_size=1, nhead=8,
                                                         dim_feedforward=1024)
        self.linear_vocab = nn.Linear(1024, 2)
        self.relu = nn.ReLU()
        self.lstm = nn.LSTM(768, 128, 5,
                            bidirectional=True, batch_first=True, dropout=0.1)
        self.maxpool = nn.MaxPool1d(32)
        self.fc = nn.Linear(128 * 2 + 768, 2)

        if key:
            print("begin-------Training from scratch !!!")
            for param in self.model.parameters():
                param.requires_grad = False
        else:
            print("begin-------fine train !!!")
            for name, param in reversed(list(model.named_parameters())):
                if "cls" in name:
                    param.requires_grad = False

    def forward(self, input_ids, masks):
        # Contextual token embeddings from the pretrained BERT encoder
        outputs = self.model.bert(input_ids, attention_mask=masks).last_hidden_state
        # Zero-initialized decoder state, created on the same device as the encoder output
        tar = torch.zeros(outputs.size(0), 256, 768, device=outputs.device)
        # Refine the state with four passes through the same decoder layer
        tar = self.decoder(outputs, tar)
        tar = self.decoder(outputs, tar)
        tar = self.decoder(outputs, tar)
        tar = self.decoder(outputs, tar)
        # Bidirectional LSTM over the refined state, concatenated with its input
        out, _ = self.lstm(tar)
        out = torch.cat((tar, out), 2)
        out = self.relu(out)
        # Global max-pooling over the sequence dimension, then the classification head
        out = out.permute(0, 2, 1)
        out = F.max_pool1d(out, kernel_size=out.size(-1)).squeeze(-1)
        out = self.fc(out)
        return out
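

# A minimal usage sketch, not part of the original training code. It assumes the
# backbone passed to BirdModel_Attention_lstm is a BERT-style masked-LM with hidden
# size 768 exposing a .bert submodule (e.g. "bert-base-uncased" loaded via
# AutoModelForMaskedLM), and that inputs are padded/truncated to length 256 to match
# the zero-initialized decoder state created in forward().
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")   # assumed checkpoint
    backbone = AutoModelForMaskedLM.from_pretrained("bert-base-uncased")
    model = BirdModel_Attention_lstm(backbone, key=True)  # key=True: freeze the backbone

    batch = tokenizer(["a toy sentence", "another toy sentence"],
                      padding="max_length", truncation=True, max_length=256,
                      return_tensors="pt")
    with torch.no_grad():
        logits = model(batch["input_ids"], batch["attention_mask"])
    print(logits.shape)  # expected: torch.Size([2, 2])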