import torch
import torch.nn as nn

# NOTE(review): ``BertModel`` is used below but is not imported anywhere in
# this file, so constructing SentimentClassifier raises NameError. It
# presumably comes from Hugging Face ``transformers``
# (``from transformers import BertModel``) -- add that import wherever the
# project declares its dependencies.


class SentimentClassifier(nn.Module):
    """Classify text into ``n_classes`` labels on top of a BERT encoder.

    The encoder's pooled output goes through dropout and then a single
    linear layer. ``forward`` returns the raw linear-layer output; no
    softmax or other activation is applied.

    Args:
        n_classes: Number of output classes (output size of the final
            linear layer).
        pretrained_model_name: Checkpoint name passed to
            ``BertModel.from_pretrained``. Defaults to ``"bert-base-cased"``.
        dropout: Dropout probability applied to the pooled output.
            Defaults to ``0.3``.
    """

    def __init__(
        self,
        n_classes: int,
        *,
        pretrained_model_name: str = "bert-base-cased",
        dropout: float = 0.3,
    ) -> None:
        super().__init__()
        # Attribute name kept as ``BERT`` so existing checkpoints whose
        # state_dict keys are prefixed "BERT." still load.
        self.BERT = BertModel.from_pretrained(pretrained_model_name)
        self.dropout = nn.Dropout(p=dropout)
        self.fc = nn.Linear(self.BERT.config.hidden_size, n_classes)

    def forward(
        self, input_ids: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """Return one score per class for the given token batch.

        Args:
            input_ids: Token ids forwarded to the encoder -- presumably
                shaped (batch, seq_len); TODO confirm against the data
                pipeline.
            attention_mask: Mask forwarded to the encoder alongside
                ``input_ids``.

        Returns:
            Output of the final linear layer (unnormalized scores).
        """
        outputs = self.BERT(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=False,
        )
        # With return_dict=False the encoder returns a tuple whose second
        # element is the pooled output. Index it rather than unpacking
        # exactly two values: the tuple can carry extra entries (hidden
        # states / attentions) when those outputs are enabled, which made
        # the previous two-way unpacking raise ValueError.
        pooled_output = outputs[1]
        return self.fc(self.dropout(pooled_output))