import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer

class Elect(nn.Module):
    def __init__(self, args, device):
        super(Elect, self).__init__()
        self.device = device
        # Pretrained encoder and matching tokenizer loaded from the checkpoint directory.
        self.plm = AutoModel.from_pretrained(args.ckpt_dir)
        self.hidden_size = self.plm.config.hidden_size
        self.tokenizer = AutoTokenizer.from_pretrained(args.ckpt_dir)
        # Classification head over the label set.
        self.clf = nn.Linear(self.hidden_size, len(args.labels))
        self.dropout = nn.Dropout(0.3)

        self.p2l = nn.Linear(self.hidden_size, 256)  # auxiliary projection (not used in forward)
        self.proj = nn.Linear(self.hidden_size * 2, self.hidden_size)  # fuses label-attended and [CLS] features
        self.l2a = nn.Linear(11, 256)  # auxiliary projection (not used in forward)

        # Learnable label embedding matrix, one vector per label.
        self.la = nn.Parameter(torch.zeros(len(args.labels), self.hidden_size))

    def forward(self, batch):
        ids = batch['ids'].to(self.device, dtype=torch.long)
        mask = batch['mask'].to(self.device, dtype=torch.long)
        # token_type_ids is loaded here but not passed on to the encoder.
        token_type_ids = batch['token_type_ids'].to(self.device, dtype=torch.long)
        hidden_state = self.plm(input_ids=ids, attention_mask=mask)[0]  # [batch_size, seq_len, hidden_size]
        pooler = hidden_state[:, 0]    # [CLS] representation: [batch_size, hidden_size]
        pooler = self.dropout(pooler)  # [batch_size, hidden_size]

        # Attention of the [CLS] representation over the label embeddings.
        attn = torch.softmax(pooler @ self.la.transpose(0, 1), dim=-1)  # [batch_size, len(labels)]
        art = attn @ self.la                                            # [batch_size, hidden_size]
        # Fuse the label-attended vector with the [CLS] vector.
        oa = F.relu(self.proj(torch.cat([art, pooler], dim=-1)))        # [batch_size, hidden_size]

        output = self.clf(oa)  # logits: [batch_size, len(labels)]

        return output
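

# A minimal, hypothetical usage sketch (not part of the original module): it assumes
# `args` only needs `ckpt_dir` and `labels`, uses "bert-base-uncased" as a stand-in
# checkpoint, and builds the batch dict with the keys the forward pass expects.
if __name__ == "__main__":
    from types import SimpleNamespace

    args = SimpleNamespace(ckpt_dir="bert-base-uncased", labels=list(range(11)))
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Elect(args, device).to(device)
    model.eval()

    enc = model.tokenizer(["an example sentence"], padding=True,
                          truncation=True, return_tensors="pt")
    batch = {
        "ids": enc["input_ids"],
        "mask": enc["attention_mask"],
        "token_type_ids": enc.get("token_type_ids", torch.zeros_like(enc["input_ids"])),
    }
    with torch.no_grad():
        logits = model(batch)
    print(logits.shape)  # expected: torch.Size([1, len(args.labels)])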