import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer


class Elect(nn.Module):
    def __init__(self, args, device):
        super(Elect, self).__init__()
        self.device = device

        # Pretrained transformer backbone and its matching tokenizer.
        self.plm = AutoModel.from_pretrained(args.ckpt_dir)
        self.hidden_size = self.plm.config.hidden_size
        self.tokenizer = AutoTokenizer.from_pretrained(args.ckpt_dir)

        # Classification head over the label set.
        self.clf = nn.Linear(self.hidden_size, len(args.labels))
        self.dropout = nn.Dropout(0.3)

        # Projection layers (p2l and l2a are defined but not used in this forward pass).
        self.p2l = nn.Linear(self.hidden_size, 256)
        self.proj = nn.Linear(self.hidden_size * 2, self.hidden_size)
        self.l2a = nn.Linear(11, 256)

        # Learnable label embeddings, one row per label (zero-initialized).
        self.la = nn.Parameter(torch.zeros(len(args.labels), self.hidden_size))

    def forward(self, batch):
        ids = batch['ids'].to(self.device, dtype=torch.long)
        mask = batch['mask'].to(self.device, dtype=torch.long)
        # token_type_ids are loaded here but not passed to the backbone below.
        token_type_ids = batch['token_type_ids'].to(self.device, dtype=torch.long)

        # Encode the sequence and take the first ([CLS]) token as the pooled representation.
        hidden_state = self.plm(input_ids=ids, attention_mask=mask)[0]
        pooler = hidden_state[:, 0]
        pooler = self.dropout(pooler)

        # Label attention: score the pooled vector against every label embedding,
        # then build a label-aware representation as the attention-weighted sum.
        attn = torch.softmax(pooler @ self.la.transpose(0, 1), dim=-1)
        art = attn @ self.la

        # Fuse the label-aware and pooled representations, then classify.
        oa = F.relu(self.proj(torch.cat([art, pooler], dim=-1)))
        output = self.clf(oa)

        return output
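

# Minimal usage sketch, not part of the original module: it shows how Elect could be
# instantiated and run on a single sentence. Assumptions are marked below: `args` is
# assumed to only need `ckpt_dir` and `labels`; "bert-base-uncased" is a placeholder
# checkpoint, and the 11-label set is a guess matching the hard-coded l2a input size.
if __name__ == "__main__":
    from types import SimpleNamespace

    args = SimpleNamespace(
        ckpt_dir="bert-base-uncased",               # placeholder checkpoint
        labels=[f"label_{i}" for i in range(11)],   # placeholder label set
    )
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Elect(args, device).to(device)
    model.eval()

    # Tokenize with the model's own tokenizer and assemble the expected batch dict.
    enc = model.tokenizer(
        ["an example sentence"], padding=True, truncation=True, return_tensors="pt"
    )
    batch = {
        "ids": enc["input_ids"],
        "mask": enc["attention_mask"],
        # Some backbones (e.g. RoBERTa) return no token_type_ids; fall back to zeros.
        "token_type_ids": enc.get("token_type_ids", torch.zeros_like(enc["input_ids"])),
    }

    with torch.no_grad():
        logits = model(batch)  # shape: (batch_size, num_labels)
    print(logits.shape)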