import torch
from torch import nn
from transformers import AutoModel


class PhoBERTMultiTask(nn.Module):
    """Multi-task model on top of PhoBERT: sentiment + topic classification."""

    def __init__(self, num_sentiment=3, num_topic=4):
        super().__init__()
        # Pretrained Vietnamese encoder (RoBERTa architecture, hidden size 768).
        self.phobert = AutoModel.from_pretrained("vinai/phobert-base")
        self.dropout = nn.Dropout(0.1)

        # Two task-specific classification heads share the same encoder output.
        self.sentiment_head = nn.Sequential(
            nn.Linear(768, 768),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(768, num_sentiment),
        )
        self.topic_head = nn.Sequential(
            nn.Linear(768, 768),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(768, num_topic),
        )

    def forward(self, input_ids, attention_mask):
        outputs = self.phobert(input_ids=input_ids, attention_mask=attention_mask)
        # Use the hidden state of the <s> token (position 0) as the pooled
        # sentence representation, then feed it to both heads.
        pooled = outputs.last_hidden_state[:, 0, :]
        pooled = self.dropout(pooled)
        logits_sent = self.sentiment_head(pooled)
        logits_topic = self.topic_head(pooled)
        return logits_sent, logits_topic
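

# --- Minimal inference sketch (not from the original; the example sentence
# and variable names below are illustrative assumptions). PhoBERT expects
# word-segmented Vietnamese input, e.g. pre-processed with VnCoreNLP, so the
# sample text uses underscore-joined compounds.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("vinai/phobert-base")
model = PhoBERTMultiTask()
model.eval()

text = "Sản_phẩm này rất tốt"  # already word-segmented
inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=256)

with torch.no_grad():
    logits_sent, logits_topic = model(inputs["input_ids"], inputs["attention_mask"])

pred_sentiment = logits_sent.argmax(dim=-1).item()
pred_topic = logits_topic.argmax(dim=-1).item()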
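

# --- Hedged sketch of one joint training step, assuming a `batch` dict with
# keys "input_ids", "attention_mask", "sentiment_labels", and "topic_labels"
# (hypothetical names; your DataLoader may differ). The two cross-entropy
# losses are simply summed; weighting them per task is a common variation.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

def train_step(model, batch):
    model.train()
    logits_sent, logits_topic = model(batch["input_ids"], batch["attention_mask"])
    loss = (criterion(logits_sent, batch["sentiment_labels"])
            + criterion(logits_topic, batch["topic_labels"]))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()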