# modeling_phobert_attn.py
import torch
import torch.nn as nn
from transformers import AutoModel

class PhoBERT_Attention(nn.Module):
    """PhoBERT encoder with a learned attention-pooling classification head."""

    def __init__(self, num_classes=2, dropout=0.3):
        super().__init__()
        # PhoBERT is a RoBERTa-style encoder pre-trained on Vietnamese text.
        self.phobert = AutoModel.from_pretrained("vinai/phobert-base")
        hidden = self.phobert.config.hidden_size
        # Scores each token position with a single scalar, used for pooling.
        self.attention = nn.Linear(hidden, 1)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden, num_classes)

    def forward(self, input_ids, attention_mask):
        out = self.phobert(input_ids=input_ids, attention_mask=attention_mask)
        H = out.last_hidden_state                       # [B, T, H]
        scores = self.attention(H)                      # [B, T, 1]
        # Mask padding positions before the softmax so padding tokens
        # receive zero attention weight.
        scores = scores.masked_fill(attention_mask.unsqueeze(-1) == 0, float("-inf"))
        attn = torch.softmax(scores, dim=1)             # [B, T, 1]
        ctx = (attn * H).sum(dim=1)                     # [B, H]
        logits = self.fc(self.dropout(ctx))             # [B, C]
        return logits
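

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original file): tokenize one
# sentence and run a forward pass. PhoBERT expects word-segmented Vietnamese
# input (e.g. via VnCoreNLP or underthesea); the sample sentence below is a
# hypothetical placeholder.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("vinai/phobert-base")
    model = PhoBERT_Attention(num_classes=2)
    model.eval()

    # Word-segmented sample: syllables of a compound word joined by "_".
    batch = tokenizer(
        ["Tôi rất thích sản_phẩm này ."],
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        logits = model(batch["input_ids"], batch["attention_mask"])
    print(logits.shape)  # torch.Size([1, 2])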