import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer


class Elect(nn.Module):
    def __init__(self, args, device):
        super(Elect, self).__init__()
        self.device = device
        self.plm = AutoModel.from_pretrained(args.ckpt_dir)           # pretrained language model backbone
        self.hidden_size = self.plm.config.hidden_size
        self.tokenizer = AutoTokenizer.from_pretrained(args.ckpt_dir)
        self.clf = nn.Linear(self.hidden_size, len(args.labels))      # classification head
        self.dropout = nn.Dropout(0.3)
        self.p2l = nn.Linear(self.hidden_size, 256)                   # defined but not used in forward()
        self.proj = nn.Linear(self.hidden_size * 2, self.hidden_size)
        self.l2a = nn.Linear(11, 256)                                 # defined but not used in forward()
        self.la = nn.Parameter(torch.zeros(len(args.labels), self.hidden_size))  # learnable label embeddings

    def forward(self, batch):
        ids = batch['ids'].to(self.device, dtype=torch.long)
        mask = batch['mask'].to(self.device, dtype=torch.long)
        token_type_ids = batch['token_type_ids'].to(self.device, dtype=torch.long)  # not passed to the encoder below

        hidden_state = self.plm(input_ids=ids, attention_mask=mask)[0]  # [batch_size, seq_len, hidden_size]
        pooler = hidden_state[:, 0]                                     # [CLS] token: [batch_size, hidden_size]
        pooler = self.dropout(pooler)                                   # [batch_size, hidden_size]

        # Attend over the label embeddings and fuse the result with the [CLS] representation.
        attn = torch.softmax(pooler @ self.la.transpose(0, 1), dim=-1)  # [batch_size, num_labels]
        art = attn @ self.la                                            # [batch_size, hidden_size]
        oa = F.relu(self.proj(torch.cat([art, pooler], dim=-1)))        # [batch_size, hidden_size]
        output = self.clf(oa)                                           # [batch_size, num_labels]
        return output
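

# ---------------------------------------------------------------------------
# Minimal usage sketch. The checkpoint name, label set, and SimpleNamespace
# `args` below are illustrative assumptions, not the project's real config;
# the actual training script presumably builds `args` and the batches itself.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from types import SimpleNamespace

    args = SimpleNamespace(ckpt_dir='bert-base-uncased',                # assumed checkpoint
                           labels=['negative', 'neutral', 'positive'])  # hypothetical label set
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = Elect(args, device).to(device)
    model.eval()

    # Tokenize a toy sentence into the batch format forward() expects.
    enc = model.tokenizer('a short example sentence', padding='max_length',
                          truncation=True, max_length=32, return_tensors='pt')
    batch = {'ids': enc['input_ids'],
             'mask': enc['attention_mask'],
             'token_type_ids': enc.get('token_type_ids',
                                       torch.zeros_like(enc['input_ids']))}

    with torch.no_grad():
        logits = model(batch)   # [1, len(args.labels)]
    print(logits.shape)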