Spaces:
Running
Running
import pandas as pd | |
import re | |
import torch | |
import gradio as gr | |
import matplotlib.pyplot as plt | |
import seaborn as sns | |
from transformers import pipeline | |
# Result cache read and written by analyze_sentiment(): the last scored
# DataFrame and the path of the CSV file it was produced from.
cached_df = cached_file_name = None
# Sentiment model: Hugging Face text-classification pipeline for BigBird Flight 2.
# device=0 places it on the first CUDA GPU when one is present; -1 means CPU.
sentiment_pipeline = pipeline(
    task="text-classification",
    model="pvaluedotone/bigbird-flight-2",
    tokenizer="pvaluedotone/bigbird-flight-2",
    device=(0 if torch.cuda.is_available() else -1),
)
# Common English contractions (lowercase keys) and their expanded forms.
contractions_dict = {
    "don't": "do not",
    "can't": "cannot",
    "i'm": "i am",
    "it's": "it is",
    "he's": "he is",
    "she's": "she is",
    "they're": "they are",
    "we're": "we are",
    "you're": "you are",
    "that's": "that is",
    "there's": "there is",
    "what's": "what is",
    "won't": "will not",
    "isn't": "is not",
    "aren't": "are not",
    "wasn't": "was not",
    "weren't": "were not",
    "didn't": "did not",
    "doesn't": "does not",
    "haven't": "have not",
    "hasn't": "has not",
    "hadn't": "had not",
    "wouldn't": "would not",
    "shouldn't": "should not",
    "couldn't": "could not",
    "mustn't": "must not",
    "let's": "let us",
}
# Whole-word, case-sensitive alternation over every (regex-escaped) key above.
contractions_pattern = re.compile(
    rf"\b({'|'.join(map(re.escape, contractions_dict))})\b"
)
def expand_contractions(text: str) -> str:
    """Expand English contractions in *text* (e.g. "don't" -> "do not").

    Matching is case-insensitive, so sentence-initial forms such as "Don't"
    and "I'm" are expanded as well.  When the matched contraction starts with
    a capital letter, the expansion is capitalised to match
    ("Don't" -> "Do not", "I'm" -> "I am").
    """
    def replace(match):
        found = match.group(0)
        # Dictionary keys are lowercase; normalise the matched text to find them.
        expansion = contractions_dict[found.lower()]
        if found[0].isupper():
            expansion = expansion[0].upper() + expansion[1:]
        return expansion

    # contractions_pattern is compiled case-sensitively, so reuse its source
    # with IGNORECASE; the re module caches the compiled result across calls.
    return re.sub(contractions_pattern.pattern, replace, text, flags=re.IGNORECASE)
# ASCII emoticon -> replacement word, consumed by clean_text().
# Keys are lowercase; entry order is the same as before.
emoticon_dict = {
    ":)": "smile",
    ":-)": "smile",
    ":(": "sad",
    ":-(": "sad",
    ";)": "wink",
    ";-)": "wink",
    ":d": "laugh",
    ":-d": "laugh",
    ":p": "playful",
    ":-p": "playful",
    ":'(": "cry",
    ":/": "skeptical",
    ":'-)": "tears_of_joy",
}
def _emoticon_regex_source() -> str:
    """Return a regex alternation that matches any key of ``emoticon_dict``."""
    alternatives = []
    # Longest first, so a longer emoticon is preferred over a shorter one.
    for emoticon in sorted(emoticon_dict, key=len, reverse=True):
        alternative = re.escape(emoticon)
        if emoticon[-1].isalnum():
            # Letter-ended emoticons (":d", ":p") must not run into a following
            # word, otherwise "re:pair" would become "re playful air".
            alternative += r"(?!\w)"
        alternatives.append(alternative)
    return "|".join(alternatives)


def clean_text(text: str) -> str:
    """Normalise one raw text value before sentiment scoring.

    Steps, in order:
    1. strip URLs (``http...``) and @mentions;
    2. expand contractions;
    3. replace ASCII emoticons with words, case-insensitively (":D" -> " laugh ");
    4. turn emoji into ``:name:`` text, only when the optional ``emoji``
       package is installed;
    5. drop the ``#`` from hashtags and collapse whitespace.

    Non-string input (e.g. NaN from a pandas column) yields "".
    """
    if not isinstance(text, str):
        return ""
    text = re.sub(r"http\S+|@\w+", "", text)
    text = expand_contractions(text)
    # Emoticons are replaced BEFORE demojizing.  emoji.demojize emits ":name:"
    # tokens (e.g. ":pensive_face:"), which the ":d"/":p" emoticons would
    # otherwise corrupt.  IGNORECASE lets ":D"/":P" hit the lowercase keys.
    text = re.sub(
        _emoticon_regex_source(),
        lambda m: f" {emoticon_dict[m.group(0).lower()]} ",
        text,
        flags=re.IGNORECASE,
    )
    try:
        import emoji
    except ImportError:
        pass  # emoji is optional; without it emoji characters are left as-is
    else:
        text = emoji.demojize(text)
    text = re.sub(r"#(\w+)", r"\1", text)
    text = re.sub(r"\s+", " ", text).strip()
    return text
def predict_sentiment(texts, batch_size=32):
    """Score *texts* with ``sentiment_pipeline``.

    Args:
        texts: Sequence of strings to classify.
        batch_size: Pipeline batch size (defaults to the previous fixed 32).

    Returns:
        ``(sentiments, confidences)``: the integer class parsed from each
        predicted label and the pipeline's score for that label, in input order.
        Both lists are empty when *texts* is empty.
    """
    texts = list(texts)
    if not texts:
        # Nothing to score: skip the model call entirely.
        return [], []
    # truncation=True cuts over-length inputs down to the model's maximum
    # length instead of letting a single long text make the call fail.
    results = sentiment_pipeline(texts, truncation=True, batch_size=batch_size)
    # NOTE(review): assumes each label ends in "_<int>" (e.g. "LABEL_7") or is a
    # bare integer, and that this integer is the 1-10 score - confirm against the model config.
    sentiments = [int(result["label"].split("_")[-1]) for result in results]
    confidences = [result["score"] for result in results]
    return sentiments, confidences
def recategorize(labels, mode, pos_threshold, neg_threshold):
    """Map 1-10 sentiment scores onto the label scheme selected by *mode*.

    Args:
        labels: Iterable of integer scores.
        mode: Output-type label whose leading word is "Original", "Binary"
            or "Ternary".
        pos_threshold: Scores >= this are "Positive" (Binary and Ternary).
        neg_threshold: Scores <= this are "Negative" (Ternary only).

    Returns:
        *labels* unchanged for Original mode.  Otherwise a list of
        "Positive"/"Negative" (Binary) or "Positive"/"Neutral"/"Negative"
        (Ternary).  In Ternary mode, Positive wins when a score satisfies both
        thresholds.

    Raises:
        ValueError: if *mode* is not one of the three known output types.
    """
    # Compare on the leading word only: the full "Original" label contains a
    # non-ASCII dash whose encoding may differ between copies of the string.
    kind = mode.split(" ", 1)[0] if isinstance(mode, str) else None
    if kind == "Original":
        return labels
    if kind == "Binary":
        return ["Positive" if lbl >= pos_threshold else "Negative" for lbl in labels]
    if kind == "Ternary":
        return [
            "Positive" if lbl >= pos_threshold else
            "Negative" if lbl <= neg_threshold else
            "Neutral"
            for lbl in labels
        ]
    raise ValueError(f"Unknown sentiment output mode: {mode!r}")
def _failure(message):
    """Build the six-item Gradio response for an error: status text plus five empty outputs."""
    return (message, None, None, None, None, None)


def _check_thresholds(mode, pos_thresh, neg_thresh, auto_fix):
    """Validate (and optionally repair) the Ternary thresholds.

    Returns ``(neg_thresh, error)``.  *neg_thresh* is the value to use: it is
    lowered to ``pos_thresh - 1`` when *auto_fix* repairs an overlap.  *error*
    is a user-facing message, or ``None`` when the thresholds are usable.
    Only Ternary mode uses the negative threshold, so other modes pass through.
    """
    if mode != "Ternary (Pos/Neu/Neg)" or pos_thresh > neg_thresh:
        return neg_thresh, None
    if not auto_fix:
        return neg_thresh, (
            f"⚠️ Invalid thresholds: Positive min ({pos_thresh}) must be "
            f"greater than Negative max ({neg_thresh})."
        )
    neg_thresh = pos_thresh - 1
    if neg_thresh < 1:
        return neg_thresh, "⚠️ Unable to auto-correct: thresholds out of valid range (1–10)."
    return neg_thresh, None


def _save_figure(path, title, draw, xlabel=None):
    """Draw one 6x4 chart via *draw()*, save it as *path* and return *path*."""
    plt.figure(figsize=(6, 4))
    try:
        draw()
        plt.title(title)
        if xlabel is not None:
            plt.xlabel(xlabel)
        plt.tight_layout()
        plt.savefig(path)
    finally:
        # Always release the figure, even if drawing fails.
        plt.close()
    return path


def analyze_sentiment(file, text_column, mode, pos_thresh, neg_thresh, auto_fix, apply_cleaning):
    """Score one CSV column with the sentiment model and build the result artefacts.

    Args:
        file: Uploaded CSV from the ``gr.File`` component.  This is an object
            with a ``.name`` path or a plain path string, and ``None`` when
            nothing is uploaded.
        text_column: Column whose text is scored.
        mode: Output-type label from the radio; passed to ``recategorize``.
        pos_thresh: Minimum 1-10 score counted as Positive (Binary/Ternary).
        neg_thresh: Maximum 1-10 score counted as Negative (Ternary only).
        auto_fix: In Ternary mode, lower *neg_thresh* to ``pos_thresh - 1``
            instead of failing when the thresholds overlap.
        apply_cleaning: Run ``clean_text`` on each value before scoring.

    Returns:
        ``(status, preview, csv_path, original_plot, recategorised_plot,
        confidence_plot)``.  On failure only *status* is set; the other five
        items are ``None``.

    Model predictions are cached in the module globals ``cached_df`` and
    ``cached_file_name``.  The cache is reused only when the file, the text
    column and the cleaning toggle are all unchanged, so changing the mode or
    thresholds does not re-run the model.
    """
    global cached_df, cached_file_name

    if file is None:
        return _failure("⚠️ Please upload a CSV file first.")
    # gr.File may hand over a tempfile-like object (with .name) or a plain path.
    path = getattr(file, "name", file)
    try:
        df = pd.read_csv(path)
    except Exception as e:
        return _failure(f"Error reading CSV file: {e}")
    if text_column not in df.columns:
        return _failure("Selected column not found.")

    # Validate thresholds before inference so a bad setting fails immediately.
    neg_thresh, error = _check_thresholds(mode, pos_thresh, neg_thresh, auto_fix)
    if error is not None:
        return _failure(error)

    # Predictions depend on the file, the scored column and the cleaning toggle,
    # so all three form the cache key, stored on the cached frame's attrs.
    cache_key = (path, text_column, bool(apply_cleaning))
    used_cache = (
        cached_df is not None
        and cached_file_name == path
        and cached_df.attrs.get("cache_key") == cache_key
    )
    if used_cache:
        df = cached_df.copy()
    else:
        if apply_cleaning:
            df["processed_text"] = df[text_column].apply(clean_text)
        else:
            # Without fillna, blank cells would be scored as the literal string "nan".
            df["processed_text"] = df[text_column].fillna("").astype(str)
        predictions, confidences = predict_sentiment(df["processed_text"].tolist())
        df["sentiment_1to10"] = predictions
        df["confidence"] = confidences
        cached_df = df.copy()
        cached_df.attrs["cache_key"] = cache_key
        cached_file_name = path

    df["sentiment_recategorised"] = recategorize(df["sentiment_1to10"], mode, pos_thresh, neg_thresh)
    output_file = "bigbird_sentiment_results.csv"
    df.to_csv(output_file, index=False)

    # All three plots are redrawn on every run so they reflect the current data.
    plot1_path = _save_figure(
        "original_dist.png",
        "Original 10-Class Sentiment Distribution",
        lambda: sns.countplot(x=df["sentiment_1to10"], palette="Blues"),
    )
    plot2_path = _save_figure(
        "recategorised_dist.png",
        f"Recategorised Sentiment Distribution ({mode})",
        lambda: sns.countplot(x=df["sentiment_recategorised"], palette="Set2"),
    )
    plot3_path = _save_figure(
        "confidence_dist.png",
        "Confidence Score Distribution",
        lambda: sns.histplot(df["confidence"], bins=20, color="skyblue", kde=True),
        xlabel="Confidence",
    )

    preview = df[[text_column, "processed_text", "sentiment_1to10", "confidence", "sentiment_recategorised"]].head(10)
    return (
        f"✅ Sentiment analysis complete. Used cache: {used_cache}",
        preview,
        output_file,
        plot1_path,
        plot2_path,
        plot3_path,
    )
def get_text_columns(file):
    """Build the dropdown update listing the text-like columns of the uploaded CSV.

    Returns a ``gr.update`` for the column dropdown:
    - the text columns with the first one pre-selected, on success;
    - an empty dropdown with its default label when the upload was cleared;
    - an empty dropdown with a warning label when the file cannot be read or
      has no text columns.
    """
    if file is None:
        # Upload cleared: empty the dropdown and restore its normal label.
        return gr.update(choices=[], value=None, label="Select Text Column")
    # gr.File may hand over a tempfile-like object (with .name) or a plain path.
    path = getattr(file, "name", file)
    try:
        # Sample many rows: with a single row, a blank first cell makes pandas
        # infer float64 for what is really a text column.
        sample = pd.read_csv(path, nrows=1000)
    except Exception:
        return gr.update(choices=[], value=None, label="⚠️ Error reading file")
    # "string" covers pandas' dedicated string dtype in addition to object.
    text_columns = sample.select_dtypes(include=["object", "string"]).columns.tolist()
    if not text_columns:
        return gr.update(choices=[], value=None, label="⚠️ No text columns found!")
    # Reset the label explicitly; an earlier warning label would otherwise persist.
    return gr.update(choices=text_columns, value=text_columns[0], label="Select Text Column")
# Output-type labels used by the radio, its default value and
# toggle_thresholds().  recategorize() and analyze_sentiment() also compare
# against these labels, so check those comparisons if you rename one.
# NOTE(review): "1β10" looks like a mis-decoded "1–10"; it is kept as-is so
# the label keeps matching the other copies of the string in this file.
OUTPUT_MODES = [
    "Original (1β10)",
    "Binary (Positive vs Negative)",
    "Ternary (Pos/Neu/Neg)",
]

with gr.Blocks() as app:
    gr.Markdown("## ✈️ Sentiment analysis with Big Bird Flight 2")
    gr.Markdown("**Citation:** Mat Roni, S. (2025). *Sentiment analysis with Big Bird Flight 2 on Gradio* (version 1.0) [software]. https://huggingface.co/spaces/pvaluedotone/bigbird-flight-2 DOI: https://doi.org/10.57967/hf/5780")

    # Upload plus column picker; the picker is refilled whenever the file changes.
    with gr.Row():
        file_input = gr.File(label="Upload CSV", file_types=[".csv"])
        column_dropdown = gr.Dropdown(label="Select Text Column", choices=[], interactive=True)
    file_input.change(get_text_columns, inputs=file_input, outputs=column_dropdown)

    output_mode = gr.Radio(
        label="Sentiment Output Type",
        choices=OUTPUT_MODES,
        value=OUTPUT_MODES[0],
        interactive=True,
    )
    # Threshold sliders start hidden; toggle_thresholds() reveals them per mode.
    pos_thresh_slider = gr.Slider(3, 10, value=7, step=1, label="Positive min", visible=False)
    neg_thresh_slider = gr.Slider(1, 7, value=4, step=1, label="Negative max", visible=False)
    auto_fix_checkbox = gr.Checkbox(label="Auto-correct thresholds if overlapping?", value=True)
    # When unchecked, the column text is scored as-is instead of via clean_text().
    cleaning_checkbox = gr.Checkbox(label="Apply Text Cleaning", value=True)

    def toggle_thresholds(mode):
        """Show "Positive min" for Binary/Ternary and "Negative max" for Ternary only."""
        show_pos = mode != OUTPUT_MODES[0]
        show_neg = mode == OUTPUT_MODES[2]
        return gr.update(visible=show_pos), gr.update(visible=show_neg)

    output_mode.change(toggle_thresholds, inputs=output_mode, outputs=[pos_thresh_slider, neg_thresh_slider])

    run_button = gr.Button("Process sentiment")
    status = gr.Textbox(label="Status")
    df_output = gr.Dataframe(label="Sample Output (Top 10)")
    file_result = gr.File(label="Download Full Results")
    plot_orig = gr.Image(label="Original Sentiment Distribution")
    plot_recat = gr.Image(label="Recategorised Sentiment Distribution")
    plot_conf = gr.Image(label="Confidence Score Distribution")

    # Input order must match analyze_sentiment's parameter order.
    run_button.click(
        analyze_sentiment,
        inputs=[
            file_input, column_dropdown, output_mode,
            pos_thresh_slider, neg_thresh_slider, auto_fix_checkbox,
            cleaning_checkbox,
        ],
        outputs=[status, df_output, file_result, plot_orig, plot_recat, plot_conf],
    )

# Start the server only when run as a script, so importing this module
# does not launch the app.
if __name__ == "__main__":
    app.launch(share=True, debug=True)