# model.py — copied from a repository file view.
# Last change: "Update model.py" by joonkim (commit c4589ef, 572 bytes).
import torch
import torch.nn as nn
class SentimentClassifier(nn.Module):
    """Text classifier: a pretrained BERT encoder, dropout, and a linear head.

    The linear head is applied to BERT's pooled output. The attribute names
    ``BERT``, ``dropout`` and ``fc`` are kept unchanged on purpose: PyTorch
    derives ``state_dict`` keys from them, so renaming would break loading
    previously saved checkpoints.
    """

    def __init__(
        self,
        n_classes,
        *,
        pretrained_model_name="bert-base-cased",
        dropout_prob=0.3,
    ):
        """Build the encoder and the classification head.

        Args:
            n_classes: Number of output classes (output width of ``fc``).
            pretrained_model_name: Name or path passed to
                ``BertModel.from_pretrained``. Defaults to the checkpoint
                this class has always used, ``"bert-base-cased"``.
            dropout_prob: Dropout probability applied to the pooled output.
                Defaults to the previously hard-coded 0.3.
        """
        # BertModel comes from Hugging Face `transformers`; import it here so
        # the class does not rely on a module-level import that was missing.
        from transformers import BertModel

        super().__init__()
        self.BERT = BertModel.from_pretrained(pretrained_model_name)
        self.dropout = nn.Dropout(p=dropout_prob)
        # Size the head from the loaded encoder's config so any BERT variant
        # chosen via `pretrained_model_name` fits without further changes.
        self.fc = nn.Linear(self.BERT.config.hidden_size, n_classes)

    def forward(self, input_ids, attention_mask):
        """Return raw class scores (no softmax applied) for a batch.

        Args:
            input_ids: Token ids forwarded to BERT — presumably a
                (batch, seq_len) integer tensor from a BERT tokenizer;
                confirm with callers.
            attention_mask: Mask forwarded unchanged to BERT, presumably the
                tokenizer's mask with the same shape as ``input_ids``.

        Returns:
            Output of the final linear layer; its last dimension is
            ``n_classes``.
        """
        outputs = self.BERT(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=False,
        )
        # With return_dict=False the encoder returns a tuple whose second
        # element is the pooled output. It may carry extra trailing entries
        # (e.g. hidden states/attentions when enabled in the config), so
        # index it rather than unpacking exactly two values.
        pooled_output = outputs[1]
        output = self.dropout(pooled_output)
        return self.fc(output)