|
import numpy as np |
|
import csv |
|
from typing import Optional |
|
from urllib.request import urlopen |
|
import gradio as gr |
|
|
|
|
|
class SentimentTransform:
    """Three-class sentiment analysis (negative / neutral / positive) built on a
    HuggingFace text-classification pipeline.

    Parameters
    -------------
    model_name: str
        The name of the HuggingFace model to load.
    highlight: bool
        If True, ``analyze_sentiment`` also returns SHAP token highlights.
        Requires an ``explainer`` attribute (supplied via ``**kwargs``) —
        TODO confirm callers actually provide one before enabling.
    positive_sentiment_name: str
        Label (compared lowercased/stripped) whose score is reported as
        positive; every other label is reported negated.
    max_number_of_shap_documents: Optional[int]
        Cap on the number of SHAP highlight chunks returned (None = no cap).
    min_abs_score: float
        Minimum absolute SHAP score for a highlight chunk to be kept.
    sensitivity: float
        How confident it is about being `neutral`. If you are dealing with news sources,
        you probably want less sensitivity.
    """

    def __init__(
        self,
        model_name: str = "cardiffnlp/twitter-roberta-base-sentiment",
        highlight: bool = False,
        positive_sentiment_name: str = "positive",
        max_number_of_shap_documents: Optional[int] = None,
        min_abs_score: float = 0.1,
        sensitivity: float = 0,
        **kwargs,
    ):
        self.model_name = model_name
        self.highlight = highlight
        self.positive_sentiment_name = positive_sentiment_name
        self.max_number_of_shap_documents = max_number_of_shap_documents
        self.min_abs_score = min_abs_score
        self.sensitivity = sensitivity
        # Any extra keyword arguments become attributes on the instance
        # (e.g. `explainer`, which `get_shap_values` reads).
        for k, v in kwargs.items():
            setattr(self, k, v)

    def preprocess(self, text: str) -> str:
        """Normalize tweet-like text: "@mention" -> "@user", URLs -> "http"."""
        new_text = []
        for t in text.split(" "):
            # A lone "@" is kept as-is; only real mentions are anonymized.
            t = "@user" if t.startswith("@") and len(t) > 1 else t
            t = "http" if t.startswith("http") else t
            new_text.append(t)
        return " ".join(new_text)

    @property
    def classifier(self):
        """Lazily-constructed HuggingFace pipeline.

        `transformers` is imported on first access so that merely
        instantiating this class does not require the dependency.
        """
        if not hasattr(self, "_classifier"):
            import transformers

            self._classifier = transformers.pipeline(
                return_all_scores=True,
                model=self.model_name,
            )
        return self._classifier

    def _get_label_mapping(self, task: str):
        """Download the tweeteval label mapping for `task`.

        Returns the list of label names (second column of the
        tab-separated mapping file). Performs a network request.
        """
        mapping_link = f"https://raw.githubusercontent.com/cardiffnlp/tweeteval/main/datasets/{task}/mapping.txt"
        with urlopen(mapping_link) as f:
            html = f.read().decode("utf-8").split("\n")
            csvreader = csv.reader(html, delimiter="\t")
            labels = [row[1] for row in csvreader if len(row) > 1]
        return labels

    @property
    def label_mapping(self):
        """Static mapping from raw pipeline labels to human-readable names."""
        return {"LABEL_0": "negative", "LABEL_1": "neutral", "LABEL_2": "positive"}

    def analyze_sentiment(
        self,
        text,
        highlight: bool = False,
        positive_sentiment_name: str = "positive",
        max_number_of_shap_documents: Optional[int] = None,
        min_abs_score: float = 0.1,
    ):
        """Classify `text` and return its sentiment plus a signed score.

        Returns None when `text` is None. Otherwise returns a dict with
        "sentiment" and "overall_sentiment_score" (positive > 0,
        negative < 0); with `highlight=True` it instead returns
        "sentiment", "score", "overall_sentiment" and "highlight_chunk_".
        """
        if text is None:
            return None
        labels = self.classifier([str(text)], truncation=True, max_length=512)
        ind_max = np.argmax([l["score"] for l in labels[0]])
        sentiment = labels[0][ind_max]["label"]
        max_score = labels[0][ind_max]["score"]
        sentiment = self.label_mapping.get(sentiment, sentiment)
        if sentiment.lower() == "neutral" and max_score > self.sensitivity:
            # Confident neutral: report a tiny epsilon instead of a true 0
            # so downstream consumers never see a zero score.
            overall_sentiment = 1e-5
        elif sentiment.lower() == "neutral":
            # Unconfident neutral: fall back to the strongest non-neutral label.
            new_labels = labels[0][:ind_max] + labels[0][(ind_max + 1):]
            new_ind_max = np.argmax([l["score"] for l in new_labels])
            new_max_score = new_labels[new_ind_max]["score"]
            new_sentiment = new_labels[new_ind_max]["label"]
            new_sentiment = self.label_mapping.get(new_sentiment, new_sentiment)
            overall_sentiment = self._calculate_overall_sentiment(
                new_max_score, new_sentiment
            )
        else:
            overall_sentiment = self._calculate_overall_sentiment(max_score, sentiment)

        if overall_sentiment == 0:
            overall_sentiment = 1e-5
        if not highlight:
            return {
                "sentiment": sentiment,
                "overall_sentiment_score": overall_sentiment,
            }
        shap_documents = self.get_shap_values(
            text,
            sentiment_ind=ind_max,
            max_number_of_shap_documents=max_number_of_shap_documents,
            min_abs_score=min_abs_score,
        )
        return {
            "sentiment": sentiment,
            "score": max_score,
            "overall_sentiment": overall_sentiment,
            "highlight_chunk_": shap_documents,
        }

    def _calculate_overall_sentiment(self, score: float, sentiment: str):
        """Sign `score`: positive for the configured positive label, else negative."""
        if sentiment.lower().strip() == self.positive_sentiment_name:
            return score
        else:
            return -score

    def get_shap_values(
        self,
        text: str,
        sentiment_ind: int = 2,
        max_number_of_shap_documents: Optional[int] = None,
        min_abs_score: float = 0.1,
    ):
        """Get SHAP values for `text`, highest score first.

        Requires `self.explainer` to have been supplied via `**kwargs` at
        construction time. Returns at most `max_number_of_shap_documents`
        chunks whose absolute score exceeds `min_abs_score`.
        """
        shap_values = self.explainer([text])
        cohorts = {"": shap_values}
        cohort_exps = list(cohorts.values())
        feature_names = cohort_exps[0].feature_names
        values = np.array([cohort_exps[i].values for i in range(len(cohort_exps))])
        shap_docs = [
            {"text": v, "score": f}
            for f, v in zip(
                [x[sentiment_ind] for x in values[0][0].tolist()], feature_names[0]
            )
        ]
        sorted_scores = sorted(shap_docs, key=lambda x: x["score"], reverse=True)
        # BUG FIX: the original applied the `[:max_number_of_shap_documents]`
        # slice only when the value was None (a no-op), so the cap was never
        # enforced. Apply it when a cap is actually given.
        if max_number_of_shap_documents is not None:
            sorted_scores = sorted_scores[:max_number_of_shap_documents]
        return [d for d in sorted_scores if abs(d["score"]) > min_abs_score]

    def transform(self, text):
        """Run `analyze_sentiment` on `text` using the instance configuration."""
        sentiment = self.analyze_sentiment(
            text,
            highlight=self.highlight,
            max_number_of_shap_documents=self.max_number_of_shap_documents,
            min_abs_score=self.min_abs_score,
        )
        return sentiment
|
|
|
|
|
def sentiment_classifier(text, model_type, sensitivity):
    """Predict the sentiment of `text` with the UI-selected model.

    Parameters
    -------------
    text: str
        Input text to classify.
    model_type: str
        UI choice; "Social Media Model" selects the tweet model, anything
        else (including "Survey Model") falls back to the survey model.
    sensitivity: float
        Confidence threshold for treating a prediction as truly `neutral`.

    Returns a (sentiment_label, overall_sentiment_score) tuple.
    """
    # The original if/elif/else assigned the same survey model in two
    # branches; a single lookup with a fallback is equivalent and clearer.
    model_name = {
        "Social Media Model": "cardiffnlp/twitter-roberta-base-sentiment",
        "Survey Model": "j-hartmann/sentiment-roberta-large-english-3-classes",
    }.get(model_type, "j-hartmann/sentiment-roberta-large-english-3-classes")
    model = SentimentTransform(model_name=model_name, sensitivity=sensitivity)
    res_dict = model.transform(text)
    if res_dict is None:
        # transform() returns None when `text` is None; avoid a TypeError
        # and report a neutral result instead.
        return "neutral", 0.0
    return res_dict['sentiment'], res_dict['overall_sentiment_score']
|
|
|
|
|
# Build the Gradio UI: one text input, a model selector, and a sensitivity
# slider, wired to `sentiment_classifier`; outputs are two read-only textboxes.
_text_input = gr.Textbox(
    placeholder="Put the text here and click 'submit' to predict its sentiment",
    label="Input Text",
)
_model_input = gr.Dropdown(
    ["Social Media Model", "Survey Model"],
    value="Survey Model",
    label="Select the Model that you want to use.",
)
_sensitivity_input = gr.Slider(
    0,
    1,
    step=0.01,
    label="Sensitivity (How confident it is about being `neutral`. If you are dealing with news sources, you probably want less sensitivity.)",
)

demo = gr.Interface(
    fn=sentiment_classifier,
    inputs=[_text_input, _model_input, _sensitivity_input],
    outputs=[gr.Textbox(label="Sentiment"), gr.Textbox(label="Sentiment Score")],
)
demo.launch(debug=True)