from base_model import TextClassifier
import torch
from transformers import pipeline


class PretrainedSentimentAnalyzer(TextClassifier):
    """Sentiment classifier backed by a pretrained Hugging Face model.

    Wraps the ``cardiffnlp/twitter-roberta-base-sentiment-latest`` checkpoint in
    a ``text-classification`` pipeline. No fitting is performed here; the
    train/test splits are only forwarded to ``TextClassifier``.
    """

    def __init__(self, train_features, train_targets, test_features, test_targets, min_threshold=0.7):
        """Build the pretrained classification pipeline.

        Args:
            train_features, train_targets, test_features, test_targets:
                Forwarded unchanged to ``TextClassifier.__init__``.
            min_threshold: Confidence score a prediction must strictly exceed
                to keep its label in ``predict(..., proba=False)``; predictions
                at or below it are reported as ``'neutral'``.
        """
        super().__init__(train_features, train_targets, test_features, test_targets)
        # Use the GPU when one is available, otherwise fall back to the CPU.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = pipeline(
            "text-classification",
            model="cardiffnlp/twitter-roberta-base-sentiment-latest",
            device=device,
        )
        # Maps the pipeline's raw label strings to this classifier's class
        # names (currently an identity mapping over the three labels).
        self.prediction_map = {'positive': 'positive', 'negative': 'negative', 'neutral': 'neutral'}
        self.threshold = min_threshold

    def train(self):
        """Do nothing: the underlying model is used as-is, without fine-tuning."""
        pass

    def predict(self, text_samples: list, inverse_transform: bool = False, proba: bool = True,
                *, batch_size: int = 128) -> list:
        """Classify the sentiment of each text sample.

        Args:
            text_samples: Texts to classify.
            inverse_transform: Not used by this implementation. It is kept so
                existing callers that pass it (positionally or by keyword)
                keep working; presumably it mirrors the ``TextClassifier``
                interface — confirm against the base class.
            proba: When True, return the raw pipeline output (one
                ``{'label': ..., 'score': ...}`` dict per sample). When False,
                return one label string per sample.
            batch_size: Number of samples the pipeline processes per batch.

        Returns:
            The raw pipeline predictions if ``proba`` is True. Otherwise a list
            of label strings, in which any prediction whose score does not
            exceed ``self.threshold`` is replaced by ``'neutral'``.
        """
        # truncation=True: without it, inputs longer than the model's maximum
        # sequence length make the pipeline raise instead of returning a label.
        predictions = self.model(text_samples, batch_size=batch_size, truncation=True)
        if proba:
            return predictions
        return [
            self.prediction_map[pred['label']] if pred['score'] > self.threshold else 'neutral'
            for pred in predictions
        ]