File size: 572 Bytes
fa892a6
 
 
0d5a432
 
 
 
 
 
c4589ef
0d5a432
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import torch
import torch.nn as nn

class SentimentClassifier(nn.Module):
    """Pretrained BERT encoder followed by dropout and a linear classification head.

    The encoder's pooled output is passed through dropout and projected to
    ``n_classes`` logits. The submodule attribute names (``BERT``, ``dropout``,
    ``fc``) are kept as-is because they determine the ``state_dict`` keys used
    when saving and loading checkpoints.
    """

    def __init__(
        self,
        n_classes: int,
        *,
        pretrained_model_name: str = "bert-base-cased",
        dropout_p: float = 0.3,
    ) -> None:
        """Build the classifier.

        Args:
            n_classes: Number of logits produced by the linear head.
            pretrained_model_name: Name or path passed to
                ``BertModel.from_pretrained``; defaults to the previously
                hard-coded ``"bert-base-cased"``.
            dropout_p: Dropout probability applied to the pooled output;
                defaults to the previously hard-coded ``0.3``.
        """
        super().__init__()

        # NOTE(review): BertModel is not imported anywhere in this file;
        # presumably it should be `from transformers import BertModel` —
        # confirm and add the import, otherwise this line raises NameError.
        self.BERT = BertModel.from_pretrained(pretrained_model_name)
        self.dropout = nn.Dropout(p=dropout_p)
        # The head's input width follows the encoder config, so a different
        # pretrained model with another hidden size still wires up correctly.
        self.fc = nn.Linear(self.BERT.config.hidden_size, n_classes)

    def forward(self, input_ids, attention_mask):
        """Return classification logits for a batch of token ids.

        Args:
            input_ids: Token ids forwarded unchanged to the encoder.
            attention_mask: Attention mask forwarded unchanged to the encoder.

        Returns:
            ``self.fc`` applied to the dropped-out pooled encoder output.
        """
        outputs = self.BERT(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=False,
        )
        # With return_dict=False the encoder returns a plain tuple whose second
        # item is the pooled output. Index instead of unpacking into exactly two
        # names, so that extra entries (hidden states / attentions, when the
        # model config enables them) do not raise a ValueError.
        pooled_output = outputs[1]
        return self.fc(self.dropout(pooled_output))