# Author: medhabi
# Commit message: add model
# Commit: a756e26
# NOTE(review): stray top-level fragment, presumably `optimizer = AdamW(...)`
# from a training script. As written, the constructed optimizer is discarded.
# `model` is not defined at module scope in the visible code, so this line
# would raise NameError on import unless it is defined elsewhere.
# TODO confirm, then either bind/move this into the training code or delete it.
AdamW(model.parameters(), lr=5e-5)
class BERTTextToRating(PreTrainedModel):
    """Map text to five rating scores using a DistilBERT encoder.

    The encoder is the ``distilbert`` sub-module of a masked-LM checkpoint.
    The MLM head is loaded but not kept. A dropout + linear head maps each
    example's first-token hidden state to five unnormalized scores.
    No softmax is applied.
    """

    config_class = BERTTextToRatingConfig

    # Hub checkpoint whose DistilBERT encoder is reused as the backbone.
    _BACKBONE_CHECKPOINT = "medhabi/distilbert-base-uncased-mlm-ta-local"
    # Number of output scores; presumably one per rating level (1-5) — confirm.
    _NUM_RATINGS = 5

    def __init__(self, config):
        super().__init__(config)
        # Loads the full masked-LM model, then keeps only its encoder.
        mlm_model = AutoModelForMaskedLM.from_pretrained(self._BACKBONE_CHECKPOINT)
        self.bert_model = mlm_model.distilbert
        self.dropout = torch.nn.Dropout(0.3)
        # Size the head from the backbone rather than hard-coding 768.
        self.linear = torch.nn.Linear(
            self.bert_model.config.dim, self._NUM_RATINGS
        )

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """Return rating scores for a batch of token sequences.

        Args:
            input_ids: Token ids, expected shape ``(batch, seq_len)``.
            attention_mask: Mask aligned with ``input_ids``.
            token_type_ids: Accepted for call-site compatibility. It is not
                forwarded to the backbone and has no effect.

        Returns:
            Tensor of shape ``(batch, 5)`` holding unnormalized scores.
        """
        output = self.bert_model(input_ids, attention_mask=attention_mask)
        # Take the first-token state of *every* example, so that a batch of
        # N inputs yields N rows (for N == 1 this is still shape (1, dim)).
        first_token_state = output.last_hidden_state[:, 0, :]
        return self.linear(self.dropout(first_token_state))