luciusssss's picture
Upload 22 files
a48216a verified
raw
history blame
1.54 kB
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel,AutoTokenizer
class Elect(nn.Module):
    """Classifier over a pretrained encoder, augmented with learned label attention.

    The first-position hidden state of the encoder is used as a sentence
    vector. It attends over a learned matrix of per-label embeddings
    (``self.la``). The attended label vector is concatenated with the sentence
    vector, projected back to ``hidden_size`` with a ReLU, and mapped to one
    logit per label.
    """

    def __init__(self, args, device):
        """Build the model.

        Args:
            args: config object; reads ``args.ckpt_dir`` (passed to
                ``from_pretrained`` for both model and tokenizer) and
                ``args.labels`` (only its length is used).
            device: device that ``forward`` moves input tensors to.
        """
        super().__init__()
        self.device = device
        self.plm = AutoModel.from_pretrained(args.ckpt_dir)
        self.hidden_size = self.plm.config.hidden_size
        # Not used by forward(); kept as an attribute for callers.
        self.tokenizer = AutoTokenizer.from_pretrained(args.ckpt_dir)
        # NOTE: submodules are created in the original order so that the
        # random initialization of each layer (and the state_dict keys) is
        # unchanged for existing checkpoints.
        self.clf = nn.Linear(self.hidden_size, len(args.labels))
        self.dropout = nn.Dropout(0.3)
        # NOTE(review): p2l and l2a are not used by forward(). They are kept
        # so strict load_state_dict on existing checkpoints still succeeds.
        # The meaning of the input width 11 in l2a is not visible here.
        self.p2l = nn.Linear(self.hidden_size, 256)
        self.proj = nn.Linear(self.hidden_size * 2, self.hidden_size)
        self.l2a = nn.Linear(11, 256)
        # One learned embedding row per label, zero-initialized, so the
        # initial attention is uniform and the attended vector starts at 0.
        self.la = nn.Parameter(torch.zeros(len(args.labels), self.hidden_size))

    def forward(self, batch):
        """Compute per-label logits for a batch.

        Args:
            batch: mapping with ``'ids'`` (token ids) and ``'mask'``
                (attention mask) tensors. A ``'token_type_ids'`` entry may be
                present but is ignored; it is never passed to the encoder.

        Returns:
            Tensor of logits, ``[batch_size, len(labels)]``.
        """
        ids = batch['ids'].to(self.device, dtype=torch.long)
        mask = batch['mask'].to(self.device, dtype=torch.long)
        # Assumes the encoder's first output is (batch, seq_len, hidden),
        # as for HF base models -- TODO confirm for the checkpoint in use.
        hidden_state = self.plm(input_ids=ids, attention_mask=mask)[0]
        sent = self.dropout(hidden_state[:, 0])  # [batch_size, hidden_size]
        # Attention of each sentence over the label embeddings.
        attn = torch.softmax(sent @ self.la.transpose(0, 1), dim=-1)  # [batch_size, num_labels]
        label_ctx = attn @ self.la  # [batch_size, hidden_size]
        fused = F.relu(self.proj(torch.cat([label_ctx, sent], dim=-1)))  # [batch_size, hidden_size]
        return self.clf(fused)  # [batch_size, num_labels]