|
AdamW(model.parameters(), lr=5e-5) |
|
|
|
class BERTTextToRating(PreTrainedModel):
    """DistilBERT encoder followed by dropout and a 5-way linear head.

    The encoder is taken from a masked-LM checkpoint; only its
    ``distilbert`` sub-module is kept. The head maps the hidden state at
    sequence position 0 of every example to five output scores
    (presumably one per rating value -- confirm against the training code).
    """

    config_class = BERTTextToRatingConfig

    def __init__(self, config):
        """Build the encoder, the dropout layer and the linear head.

        Args:
            config: Configuration object forwarded to ``PreTrainedModel``.
        """
        super().__init__(config)

        # NOTE(review): the checkpoint is loaded on every construction,
        # independent of ``config``; confirm this is intended when the model
        # is restored through ``from_pretrained``.
        model_checkpoint = "medhabi/distilbert-base-uncased-mlm-ta-local"
        mlm_model = AutoModelForMaskedLM.from_pretrained(model_checkpoint)

        # Keep only the encoder; the masked-LM head is not used.
        self.bert_model = mlm_model.distilbert
        self.dropout = torch.nn.Dropout(0.3)
        # Size the head from the encoder's hidden width (768 for the base
        # checkpoint above) instead of hard-coding it.
        self.linear = torch.nn.Linear(self.bert_model.config.dim, 5)

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """Compute the five output scores for each sequence in the batch.

        Args:
            input_ids: Token ids passed to the encoder.
            attention_mask: Attention mask passed to the encoder.
            token_type_ids: Accepted for tokenizer-output compatibility and
                ignored; it is never forwarded to the encoder.

        Returns:
            Tensor with one row of 5 scores per input sequence.
        """
        encoded = self.bert_model(
            input_ids,
            attention_mask=attention_mask,
        )
        # Take position 0 of EVERY sequence. The previous ``[0][0]`` indexing
        # kept only the first example, so batches larger than one silently
        # produced a single output row.
        first_token_state = encoded.last_hidden_state[:, 0]
        return self.linear(self.dropout(first_token_state))