|
import torch |
|
import torch.nn as nn |
|
|
|
class SentimentClassifier(nn.Module):
    """Sequence classifier: a pretrained BERT encoder, dropout, then a linear head.

    The encoder's pooled output is regularised with dropout and projected to
    ``n_classes`` unnormalised scores (logits).
    """

    def __init__(
        self,
        n_classes: int,
        *,
        pretrained_model_name: str = "bert-base-cased",
        dropout_prob: float = 0.3,
    ) -> None:
        """Build the encoder, dropout layer and classification head.

        Args:
            n_classes: Number of output classes (width of the final linear layer).
            pretrained_model_name: Checkpoint identifier passed to
                ``BertModel.from_pretrained``. Defaults to the previously
                hard-coded ``"bert-base-cased"``.
            dropout_prob: Dropout probability applied to the pooled output.
                Defaults to the previously hard-coded ``0.3``.
        """
        super().__init__()

        # NOTE(review): BertModel is not imported in the visible file; it
        # presumably comes from ``from transformers import BertModel`` --
        # confirm that import exists, otherwise construction raises NameError.
        # The attribute names BERT / dropout / fc are kept unchanged so that
        # previously saved state_dict checkpoints still load.
        self.BERT = BertModel.from_pretrained(pretrained_model_name)
        self.dropout = nn.Dropout(p=dropout_prob)
        # Head input width follows the encoder's configured hidden size.
        self.fc = nn.Linear(self.BERT.config.hidden_size, n_classes)

    def forward(
        self, input_ids: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """Return class logits for a batch of tokenised sequences.

        Args:
            input_ids: Token ids fed to the encoder
                (presumably shape (batch, seq_len) -- TODO confirm against caller).
            attention_mask: Mask over ``input_ids`` marking real vs. padding tokens
                (presumably same shape as ``input_ids`` -- TODO confirm).

        Returns:
            Output of the linear head, one score per class along the last dim.
        """
        # With return_dict=False the encoder returns a tuple; only the second
        # element (the pooled output) is used for classification.
        _, pooled_output = self.BERT(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=False,
        )
        output = self.dropout(pooled_output)
        return self.fc(output)