"""
Positivity predictor.
This module provides the functionality to predict
a tweet's positivity using a BERT model.
"""
import torch
from transformers import BertForSequenceClassification, BertTokenizer
# Load the pretrained tokenizer and a BERT classifier with a 5-class head.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True)
model = BertForSequenceClassification.from_pretrained(
    "bert-base-uncased",
    num_labels=5,
    output_attentions=False,
    output_hidden_states=False,
    local_files_only=False,
)

# Load the fine-tuned checkpoint on CPU so the module also works without a GPU.
model.load_state_dict(torch.load("data/BERT_ft_epoch5.model", map_location="cpu"))
model.eval()
def predict_positivity(text: str) -> str:
    """
    Predict the positivity of a given tweet.

    Args:
        text (str): Tweet's text.

    Returns:
        str: Predicted positivity label.
    """
    label_dict = {
        0: "Extremely Negative",
        1: "Negative",
        2: "Neutral",
        3: "Positive",
        4: "Extremely Positive",
    }
    # Tokenize the tweet; truncate so inputs longer than BERT's 512-token limit do not error.
    encoded = tokenizer(text, return_tensors="pt", truncation=True)
    # Run inference without tracking gradients.
    with torch.no_grad():
        logits = model(**encoded).logits
    predicted_class_id = logits.argmax().item()
    return label_dict[predicted_class_id]
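

# Illustrative usage sketch: running the module as a script prints the predicted
# label for a sample tweet. The sample text below is an assumption added for
# demonstration, not part of the original module.
if __name__ == "__main__":
    sample_tweet = "Grocery stores restocked quickly and the staff were incredibly helpful."
    print(predict_positivity(sample_tweet))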