|
from base_model import TextClassifier |
|
import torch |
|
from transformers import pipeline |
|
|
|
class PretrainedSentimentAnalyzer(TextClassifier):
    """Sentiment classifier backed by a ready-made ``text-classification`` pipeline.

    The pipeline is built once in ``__init__`` and placed on CUDA when
    ``torch.cuda.is_available()`` reports a GPU, otherwise on the CPU.
    ``train`` is a no-op; ``predict`` returns either the raw pipeline output
    or plain labels filtered by a confidence threshold.
    """

    # Checkpoint handed to the pipeline; override in a subclass to swap models.
    MODEL_NAME = "cardiffnlp/twitter-roberta-base-sentiment-latest"
    # Number of texts sent through the pipeline per forward pass.
    BATCH_SIZE = 128
    # Longest tokenized input kept. RoBERTa-base accepts at most 512 tokens;
    # without truncation a single over-long text makes the whole call fail.
    MAX_TOKENS = 512

    def __init__(self, train_features, train_targets, test_features, test_targets, min_threshold=0.7):
        """Forward the data splits to ``TextClassifier`` and load the pipeline.

        Args:
            train_features: Passed through unchanged to ``TextClassifier.__init__``.
            train_targets: Passed through unchanged to ``TextClassifier.__init__``.
            test_features: Passed through unchanged to ``TextClassifier.__init__``.
            test_targets: Passed through unchanged to ``TextClassifier.__init__``.
            min_threshold: Score a prediction must strictly exceed for its
                label to be kept by ``predict``; otherwise 'neutral' is used.
        """
        super().__init__(train_features, train_targets, test_features, test_targets)

        device = "cuda" if torch.cuda.is_available() else "cpu"

        self.model = pipeline(
            "text-classification",
            model=self.MODEL_NAME,
            device=device,
        )

        # Identity mapping from pipeline labels to the labels this class emits.
        self.prediction_map = {
            'positive': 'positive',
            'negative': 'negative',
            'neutral': 'neutral',
        }

        self.threshold = min_threshold

    def train(self):
        """Do nothing; the pipeline loaded in ``__init__`` is used as-is."""
        pass

    def predict(self, text_samples: list, inverse_transform: bool, proba: bool = True) -> list:
        """Run the pipeline over ``text_samples``.

        Args:
            text_samples: Texts to classify.
            inverse_transform: Accepted but not read by this implementation.
            proba: When true (the default), return the pipeline output
                untouched; each item is indexed by 'label' and 'score' below.
                When false, return one label string per input.

        Returns:
            The raw pipeline predictions if ``proba`` is true; otherwise a
            list of labels where any prediction whose score does not strictly
            exceed ``self.threshold`` is reported as 'neutral'.

        Raises:
            KeyError: If ``proba`` is false and a confident prediction carries
                a label missing from ``self.prediction_map``.
        """
        predictions = self.model(
            text_samples,
            batch_size=self.BATCH_SIZE,
            # Cut over-long texts to the model's limit instead of crashing.
            truncation=True,
            max_length=self.MAX_TOKENS,
        )
        if proba:
            return predictions

        return [
            self.prediction_map[prediction['label']]
            if prediction['score'] > self.threshold
            else 'neutral'
            for prediction in predictions
        ]
|
|
|
|