import torch.nn as nn
from transformers import BertModel
class BertClassifier(nn.Module):
    """BERT model for classification tasks."""

    def __init__(self, freeze_bert=False):
        """
        @param freeze_bert (bool): Set `False` to fine-tune the BERT model.

        Attributes:
            bert: a BertModel object
            classifier: a torch.nn.Module classifier head
        """
        super(BertClassifier, self).__init__()
        # Hidden size of BERT, hidden size of our classifier, number of labels
        D_in, H, D_out = 768, 50, 2

        # Instantiate the pre-trained AraBERT model
        self.bert = BertModel.from_pretrained('aubmindlab/bert-base-arabertv02')

        # Instantiate a one-hidden-layer feed-forward classifier
        self.classifier = nn.Sequential(
            nn.Linear(D_in, H),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(H, D_out)
        )

        # Freeze the BERT weights so only the classifier head is trained
        if freeze_bert:
            for param in self.bert.parameters():
                param.requires_grad = False
    def forward(self, input_ids, attention_mask):
        """
        Feed input to BERT and the classifier to compute logits.
        @param input_ids (torch.Tensor): an input tensor with shape
            (batch_size, max_length)
        @param attention_mask (torch.Tensor): a tensor that holds attention
            mask information with shape (batch_size, max_length)
        @return logits (torch.Tensor): an output tensor with shape
            (batch_size, num_labels)
        """
        # Feed input to BERT
        outputs = self.bert(input_ids=input_ids,
                            attention_mask=attention_mask)

        # Extract the last hidden state of the `[CLS]` token and feed it to
        # the classifier to compute logits
        last_hidden_state_cls = outputs[0][:, 0, :]
        logits = self.classifier(last_hidden_state_cls)

        return logits
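

# --- Usage sketch, not part of the original file. It shows how the classifier
# --- might be driven with the tokenizer that matches the loaded checkpoint;
# --- the sample sentences and max_length are assumptions for illustration.
if __name__ == '__main__':
    import torch
    from transformers import AutoTokenizer

    # Tokenizer paired with the same AraBERT checkpoint the classifier loads
    tokenizer = AutoTokenizer.from_pretrained('aubmindlab/bert-base-arabertv02')
    model = BertClassifier(freeze_bert=True)
    model.eval()

    # Encode a small batch of example sentences (hypothetical inputs)
    encoded = tokenizer(
        ['مثال للتصنيف', 'جملة أخرى'],
        padding=True,
        truncation=True,
        max_length=64,
        return_tensors='pt',
    )

    with torch.no_grad():
        logits = model(encoded['input_ids'], encoded['attention_mask'])

    # logits has shape (batch_size, 2); argmax gives the predicted label index
    print(logits.argmax(dim=1))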