# Sen / app.py
# Uploaded by pouchedfox — commit 25d443b ("Upload app.py")
import numpy as np
import csv
from typing import Optional
from urllib.request import urlopen
import gradio as gr
class SentimentTransform:
    """Three-way (negative / neutral / positive) sentiment scorer.

    Wraps a ``transformers`` pipeline that is created lazily on first use
    (see ``classifier``). Optionally attributes the prediction to individual
    tokens with SHAP (``highlight=True``, requires the ``shap`` package).
    """

    def __init__(
        self,
        model_name: str = "cardiffnlp/twitter-roberta-base-sentiment",
        highlight: bool = False,
        positive_sentiment_name: str = "positive",
        max_number_of_shap_documents: Optional[int] = None,
        min_abs_score: float = 0.1,
        sensitivity: float = 0,
        **kwargs,
    ):
        """
        Sentiment Ops.

        Parameters
        -------------
        model_name: str
            Model id passed to ``transformers.pipeline``.
        highlight: bool
            If True, ``transform`` also returns SHAP token attributions.
        positive_sentiment_name: str
            Label treated as positive when signing the overall score
            (compared case-insensitively, ignoring surrounding whitespace).
        max_number_of_shap_documents: Optional[int]
            Maximum number of highlighted tokens returned; None = no limit.
        min_abs_score: float
            Highlighted tokens must have ``abs(score) > min_abs_score``.
        sensitivity: float
            How confident it is about being `neutral`. If you are dealing with news sources,
            you probably want less sensitivity. A neutral prediction whose
            score is not above this threshold is scored using the runner-up
            label instead of ~0.
        **kwargs
            Extra attributes set verbatim on the instance.
        """
        self.model_name = model_name
        self.highlight = highlight
        self.positive_sentiment_name = positive_sentiment_name
        self.max_number_of_shap_documents = max_number_of_shap_documents
        self.min_abs_score = min_abs_score
        self.sensitivity = sensitivity
        for k, v in kwargs.items():
            setattr(self, k, v)

    def preprocess(self, text: str):
        """Mask @-mentions and links in *text*.

        Splits on single spaces. Any token starting with ``@`` (other than a
        lone ``@``) becomes ``@user``, and any token starting with ``http``
        becomes ``http``. Note: ``analyze_sentiment`` does not call this.
        """
        new_text = []
        for t in text.split(" "):
            t = "@user" if t.startswith("@") and len(t) > 1 else t
            t = "http" if t.startswith("http") else t
            new_text.append(t)
        return " ".join(new_text)

    @property
    def classifier(self):
        """The ``transformers`` pipeline, created on first access and cached.

        ``return_all_scores=True`` makes the pipeline return a score for every
        label. ``analyze_sentiment`` relies on this to find runner-up labels.
        """
        if not hasattr(self, "_classifier"):
            import transformers  # heavy import, deferred until first use

            # NOTE(review): recent transformers releases deprecate
            # `return_all_scores` in favour of `top_k=None`; verify against
            # the pinned version before switching.
            self._classifier = transformers.pipeline(
                return_all_scores=True,
                model=self.model_name,
            )
        return self._classifier

    def _get_label_mapping(self, task: str):
        """Download the label names for a TweetEval *task*.

        Reads the tab-separated ``mapping.txt`` for *task* from the
        cardiffnlp/tweeteval GitHub repository over the network. Returns the
        second column of every row that has more than one column. Note: this
        is specific to the cardiffnlp models and is not used elsewhere in this
        class, which relies on the hard-coded ``label_mapping``.
        """
        mapping_link = f"https://raw.githubusercontent.com/cardiffnlp/tweeteval/main/datasets/{task}/mapping.txt"
        with urlopen(mapping_link) as f:
            lines = f.read().decode("utf-8").split("\n")
        csvreader = csv.reader(lines, delimiter="\t")
        return [row[1] for row in csvreader if len(row) > 1]

    @property
    def label_mapping(self):
        """Map generic ``LABEL_<i>`` pipeline ids to readable sentiment names."""
        return {"LABEL_0": "negative", "LABEL_1": "neutral", "LABEL_2": "positive"}

    def analyze_sentiment(
        self,
        text,
        highlight: bool = False,
        positive_sentiment_name: str = "positive",
        max_number_of_shap_documents: Optional[int] = None,
        min_abs_score: float = 0.1,
    ):
        """Classify *text* and derive a signed overall sentiment score.

        The top-scoring label, after passing ``LABEL_<i>`` ids through
        ``label_mapping``, is reported as ``sentiment``. The overall score is:

        * ``1e-5`` if the label is neutral and its score exceeds
          ``self.sensitivity``;
        * otherwise, for a neutral label, the runner-up label's score, negated
          unless that label is the positive one. The reported label stays
          "neutral";
        * otherwise the top score, negated unless the label is the positive one.

        An overall score of exactly 0 is replaced by ``1e-5``.

        Parameters
        ----------
        text:
            Input; passed to the classifier as ``str(text)``. ``None``
            short-circuits to ``None``.
        highlight: bool
            Also compute SHAP token attributions via ``get_shap_values``.
        positive_sentiment_name: str
            Ignored (``self.positive_sentiment_name`` is used); kept so
            existing callers keep working.
        max_number_of_shap_documents, min_abs_score:
            Forwarded to ``get_shap_values`` when *highlight* is True.

        Returns
        -------
        dict or None
            ``{"sentiment", "overall_sentiment_score"}``. With *highlight*,
            the dict also has ``score`` (the top label's score),
            ``overall_sentiment`` (same value as ``overall_sentiment_score``)
            and ``highlight_chunk_`` (the SHAP token list).
        """
        if text is None:
            return None
        label_scores = self.classifier([str(text)], truncation=True, max_length=512)[0]
        ind_max = int(np.argmax([l["score"] for l in label_scores]))
        max_score = label_scores[ind_max]["score"]
        sentiment = label_scores[ind_max]["label"]
        sentiment = self.label_mapping.get(sentiment, sentiment)

        if sentiment.lower() == "neutral" and max_score > self.sensitivity:
            # Confidently neutral: near-zero but non-zero score.
            overall_sentiment = 1e-5
        elif sentiment.lower() == "neutral":
            # Neutral but not above the sensitivity threshold: sign the
            # runner-up label's score instead (label still reported neutral).
            runner_up = label_scores[:ind_max] + label_scores[ind_max + 1:]
            new_ind_max = int(np.argmax([l["score"] for l in runner_up]))
            new_max_score = runner_up[new_ind_max]["score"]
            new_sentiment = runner_up[new_ind_max]["label"]
            new_sentiment = self.label_mapping.get(new_sentiment, new_sentiment)
            overall_sentiment = self._calculate_overall_sentiment(
                new_max_score, new_sentiment
            )
        else:
            overall_sentiment = self._calculate_overall_sentiment(max_score, sentiment)

        # Original note: "Adjust to avoid bug". Presumably a 0 score is treated
        # as missing/falsy somewhere downstream (unverified).
        if overall_sentiment == 0:
            overall_sentiment = 1e-5

        if not highlight:
            return {
                "sentiment": sentiment,
                "overall_sentiment_score": overall_sentiment,
            }
        shap_documents = self.get_shap_values(
            text,
            sentiment_ind=ind_max,
            max_number_of_shap_documents=max_number_of_shap_documents,
            min_abs_score=min_abs_score,
        )
        return {
            "sentiment": sentiment,
            "score": max_score,
            "overall_sentiment": overall_sentiment,
            # Same value under the key used by the non-highlight result, so
            # callers can read one key regardless of `highlight`.
            "overall_sentiment_score": overall_sentiment,
            "highlight_chunk_": shap_documents,
        }

    def _calculate_overall_sentiment(self, score: float, sentiment: str):
        """Return *score* for the positive label and ``-score`` for any other.

        Both sides are lower-cased and stripped before comparing, so a
        configured ``positive_sentiment_name`` such as "Positive" still matches.
        """
        positive = self.positive_sentiment_name.lower().strip()
        return score if sentiment.lower().strip() == positive else -score

    @property
    def explainer(self):
        """SHAP explainer wrapping ``classifier``, created on first access and cached.

        Raises
        ------
        ModuleNotFoundError
            If the optional ``shap`` package is not installed.
        """
        if not hasattr(self, "_explainer"):
            try:
                import shap
            except ModuleNotFoundError as err:
                raise ModuleNotFoundError(
                    "The 'shap' package is required for highlight=True; "
                    "install it with `pip install shap`."
                ) from err
            self._explainer = shap.Explainer(self.classifier)
        return self._explainer

    def get_shap_values(
        self,
        text: str,
        sentiment_ind: int = 2,
        max_number_of_shap_documents: Optional[int] = None,
        min_abs_score: float = 0.1,
    ):
        """Get SHAP values for each token of *text* toward output label *sentiment_ind*.

        Returns
        -------
        list of dict
            ``{"text": token, "score": value}`` entries with
            ``abs(value) > min_abs_score``, sorted by value (descending), and
            capped at ``max_number_of_shap_documents`` entries when that is
            not None.
        """
        explanation = self.explainer([text])
        # Attributions for the single input text. Assumes one row per token
        # and one column per output label (indexed by sentiment_ind).
        token_values = np.asarray(explanation.values[0]).tolist()
        tokens = explanation.feature_names[0]
        shap_docs = [
            {"text": token, "score": row[sentiment_ind]}
            for row, token in zip(token_values, tokens)
        ]
        ranked = sorted(shap_docs, key=lambda d: d["score"], reverse=True)
        kept = [d for d in ranked if abs(d["score"]) > min_abs_score]
        if max_number_of_shap_documents is not None:
            kept = kept[:max_number_of_shap_documents]
        return kept

    def transform(self, text):
        """Run ``analyze_sentiment`` on *text* using this instance's settings."""
        return self.analyze_sentiment(
            text,
            highlight=self.highlight,
            max_number_of_shap_documents=self.max_number_of_shap_documents,
            min_abs_score=self.min_abs_score,
        )
# Model ids behind the choices offered in the UI dropdown.
_MODEL_NAMES = {
    "Social Media Model": "cardiffnlp/twitter-roberta-base-sentiment",
    "Survey Model": "j-hartmann/sentiment-roberta-large-english-3-classes",
}
# Any other choice falls back to the survey model.
_DEFAULT_MODEL_NAME = "j-hartmann/sentiment-roberta-large-english-3-classes"

# Loaded pipelines keyed by model id. Each model is loaded once per process
# instead of on every request (bounded by the number of model ids above).
_CLASSIFIER_CACHE = {}


def _get_classifier(model_name):
    """Return the pipeline for *model_name*, loading it on first request."""
    if model_name not in _CLASSIFIER_CACHE:
        _CLASSIFIER_CACHE[model_name] = SentimentTransform(model_name=model_name).classifier
    return _CLASSIFIER_CACHE[model_name]


def sentiment_classifier(text, model_type, sensitivity):
    """Gradio callback: classify *text* with the model chosen in the UI.

    Parameters
    ----------
    text: str
        Text entered by the user.
    model_type: str
        Dropdown choice ("Social Media Model" or "Survey Model").
    sensitivity: float
        Neutral-confidence threshold forwarded to ``SentimentTransform``.

    Returns
    -------
    tuple
        ``(sentiment label, signed overall sentiment score)``.
    """
    model_name = _MODEL_NAMES.get(model_type, _DEFAULT_MODEL_NAME)
    model = SentimentTransform(model_name=model_name, sensitivity=sensitivity)
    # Reuse the already-loaded pipeline. A fresh instance per request keeps
    # `sensitivity` request-local.
    model._classifier = _get_classifier(model_name)
    res_dict = model.transform(text)
    return res_dict['sentiment'], res_dict['overall_sentiment_score']
# Input widgets: free text, model choice, and the neutral-sensitivity threshold.
text_input = gr.Textbox(
    placeholder="Put the text here and click 'submit' to predict its sentiment",
    label="Input Text",
)
model_choice = gr.Dropdown(
    ["Social Media Model", "Survey Model"],
    value="Survey Model",
    label="Select the Model that you want to use.",
)
sensitivity_slider = gr.Slider(
    0,
    1,
    step=0.01,
    label="Sensitivity (How confident it is about being `neutral`. If you are dealing with news sources, you probably want less sensitivity.)",
)

# Output widgets: predicted label and its signed score.
output_widgets = [gr.Textbox(label='Sentiment'), gr.Textbox(label='Sentiment Score')]

demo = gr.Interface(
    fn=sentiment_classifier,
    inputs=[text_input, model_choice, sensitivity_slider],
    outputs=output_widgets,
)
demo.launch(debug=True)