import nltk
import streamlit as st
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
import torch
import torch.nn.functional as F
import spacy
import re
from nltk.sentiment import SentimentIntensityAnalyzer
import emoji
import plotly.graph_objects as go
import plotly.express as px
from collections import Counter
import time
import numpy as np
# Configuration - multiple models
MODELS = {
    "helinivan": "helinivan/English-sarcasm-detector",
    "distilbert": "dima806/sarcasm-detection-distilbert"
}
# Initialize the NLTK VADER sentiment analyzer
try:
    nltk.data.path.append('/app/nltk_data')
    sia = SentimentIntensityAnalyzer()
except Exception as e:
    st.error(f"Error initializing NLTK VADER: {e}")
    sia = None
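# NOTE: this assumes the VADER lexicon was pre-downloaded into /app/nltk_data
# (e.g. `python -m nltk.downloader -d /app/nltk_data vader_lexicon`); `sia` is
# initialized here but is not yet wired into the ensemble score below.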
# Cache multiple models & tokenizers | |
def load_models(): | |
models = {} | |
tokenizers = {} | |
for name, model_path in MODELS.items(): | |
try: | |
model = AutoModelForSequenceClassification.from_pretrained(model_path, cache_dir="/tmp/hf_cache") | |
tokenizer = AutoTokenizer.from_pretrained(model_path, cache_dir="/tmp/hf_cache") | |
model.eval() | |
models[name] = model | |
tokenizers[name] = tokenizer | |
st.success(f"β Loaded {name} model successfully") | |
except Exception as e: | |
st.error(f"β Failed to load {name} model: {str(e)}") | |
models[name] = None | |
tokenizers[name] = None | |
return models, tokenizers | |
# Lazy-load SpaCy (optional - not used in the current implementation)
def load_spacy():
    try:
        return spacy.load("en_core_web_sm")
    except OSError:
        st.warning("SpaCy model 'en_core_web_sm' not found. Some features may be limited.")
        return None
# Pattern detection functions with highlighting info
def social_media_sarcasm_cues(text: str) -> tuple[float, list, list]:
    explanations = []
    highlights = []
    boost = 0.0
    text_lower = text.lower()
    # Enhanced sarcasm phrases (including Reddit-style patterns)
    sarcasm_phrases = [
        "oh sure", "yeah right", "of course", "totally", "absolutely",
        "perfect", "wonderful", "fantastic", "amazing", "brilliant",
        "great job", "well done", "nice one", "good going", "way to go",
        "real smooth", "genius move", "solid plan", "makes sense",
        "just perfect", "exactly what i needed", "this is fine",
        # Reddit-style additions
        "thanks genius", "no shit sherlock", "well duh", "captain obvious",
        "groundbreaking", "revolutionary", "what a concept", "mind blown",
        "shocking", "who would have guessed", "truly inspiring"
    ]
    for phrase in sarcasm_phrases:
        # Find all occurrences of the phrase
        for match in re.finditer(re.escape(phrase), text_lower):
            boost += 0.2
            explanations.append(f"Sarcastic phrase: '{phrase}'")
            highlights.append({
                'start': match.start(),
                'end': match.end(),
                'type': 'sarcastic_phrase',
                'text': phrase
            })
    # Exaggerated expressions
    exaggerated_match = re.search(r'\b(SO|TOTALLY|ABSOLUTELY|REALLY|VERY)\b.*\b(great|good|perfect|amazing|helpful|useful)\b', text, re.IGNORECASE)
    if exaggerated_match:
        boost += 0.25
        explanations.append("Exaggerated positive expression")
        highlights.append({
            'start': exaggerated_match.start(),
            'end': exaggerated_match.end(),
            'type': 'exaggerated',
            'text': exaggerated_match.group()
        })
    return boost, explanations, highlights
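# Example: "oh sure, that makes sense" hits two phrases ("oh sure" and
# "makes sense"), so this function alone contributes a boost of 0.2 + 0.2 = 0.4.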
def emoji_punctuation_analysis(text: str) -> tuple[float, list, list]:
    explanations = []
    highlights = []
    boost = 0.0
    # Extract emojis with positions
    try:
        emojis = emoji.emoji_list(text)
        # NOTE: this list was partially corrupted in the source; 🤨, 🤷, 🤡 and 🤯
        # are certain, the remaining entries are plausible reconstructions.
        sarcastic_emojis = ['🙄', '😏', '😒', '🤔', '🤨', '🤭', '🤷', '🙃', '😉', '🤡', '😑', '🤯']
        for emoji_info in emojis:
            if emoji_info['emoji'] in sarcastic_emojis:
                boost += 0.15
                explanations.append(f"Sarcastic emoji: {emoji_info['emoji']}")
                highlights.append({
                    'start': emoji_info['match_start'],
                    'end': emoji_info['match_end'],
                    'type': 'sarcastic_emoji',
                    'text': emoji_info['emoji']
                })
    except Exception:
        # Fall back gracefully if the emoji library has issues
        pass
    # Excessive punctuation
    for match in re.finditer(r'[!?]{2,}', text):
        boost += 0.1
        explanations.append(f"Excessive punctuation: {match.group()}")
        highlights.append({
            'start': match.start(),
            'end': match.end(),
            'type': 'excessive_punct',
            'text': match.group()
        })
    # Ellipsis (often sarcastic)
    for match in re.finditer(r'\.{3,}', text):
        boost += 0.15
        explanations.append(f"Trailing ellipsis: {match.group()}")
        highlights.append({
            'start': match.start(),
            'end': match.end(),
            'type': 'ellipsis',
            'text': match.group()
        })
    return boost, explanations, highlights
def rhetorical_questions_analysis(text: str) -> tuple[float, list, list]:
    explanations = []
    highlights = []
    boost = 0.0
    rhetorical_patterns = [
        (r'what could possibly go wrong\?', "Rhetorical question"),
        (r'who would have thought\?', "Rhetorical question"),
        (r'seriously\?', "Emphatic question"),
        (r'really\?.*really\?', "Repeated question"),
        (r'no way\?', "Disbelief question"),
        (r'you don\'t say\?', "Sarcastic response"),
        (r'shocking.*\?', "Mock surprise")
    ]
    for pattern, description in rhetorical_patterns:
        for match in re.finditer(pattern, text, re.IGNORECASE):
            boost += 0.3
            explanations.append(description)
            highlights.append({
                'start': match.start(),
                'end': match.end(),
                'type': 'rhetorical_question',
                'text': match.group()
            })
    return boost, explanations, highlights
def capitalization_analysis(text: str) -> tuple[float, list, list]:
    explanations = []
    highlights = []
    boost = 0.0
    # ALL CAPS words
    for match in re.finditer(r'\b[A-Z]{3,}\b', text):
        if match.group() not in ['AND', 'THE', 'FOR', 'BUT', 'YOU', 'ARE']:
            boost += 0.1
            explanations.append(f"Emphatic caps: {match.group()}")
            highlights.append({
                'start': match.start(),
                'end': match.end(),
                'type': 'caps_emphasis',
                'text': match.group()
            })
    # Letter repetition
    for match in re.finditer(r'(.)\1{2,}', text):
        boost += 0.1
        explanations.append(f"Letter repetition: {match.group()}")
        highlights.append({
            'start': match.start(),
            'end': match.end(),
            'type': 'repetition',
            'text': match.group()
        })
    return boost, explanations, highlights
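# Example: "This is SOOO helpful" trips both checks here: "SOOO" counts as
# emphatic caps (+0.1) and the "OOO" run counts as letter repetition (+0.1).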
# Combined analysis with highlighting
def enhanced_rule_analysis(text: str) -> tuple[float, list, list]:
    all_explanations = []
    all_highlights = []
    # Apply all analysis functions
    boost1, exp1, high1 = social_media_sarcasm_cues(text)
    boost2, exp2, high2 = emoji_punctuation_analysis(text)
    boost3, exp3, high3 = rhetorical_questions_analysis(text)
    boost4, exp4, high4 = capitalization_analysis(text)
    total_boost = boost1 + boost2 + boost3 + boost4
    all_explanations.extend(exp1 + exp2 + exp3 + exp4)
    all_highlights.extend(high1 + high2 + high3 + high4)
    # Cap the total boost
    total_boost = min(total_boost, 0.8)
    return total_boost, all_explanations, all_highlights
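# Example: one sarcastic phrase (0.2) + "!!!" (0.1) + a trailing "..." (0.15)
# gives a raw boost of 0.45; totals above 0.8 are clipped to 0.8.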
# Multi-model prediction function
def get_model_predictions_current(text: str, models: dict, tokenizers: dict, device) -> dict:
    predictions = {}
    for name, model in models.items():
        if model is None or tokenizers[name] is None:
            predictions[name] = 0.0
            continue
        try:
            inputs = tokenizers[name]([text], return_tensors="pt", truncation=True, padding=True).to(device)
            model.to(device)
            with torch.no_grad():
                logits = model(**inputs).logits
            # Handle different output formats
            if logits.shape[-1] == 2:  # Binary classification head
                score = F.softmax(logits, dim=-1)[0, 1].item()
            else:  # Single-logit head
                score = torch.sigmoid(logits)[0, 0].item()
            predictions[name] = score
        except Exception as e:
            st.warning(f"Error with {name} model: {str(e)}")
            predictions[name] = 0.0
    return predictions
# Experimental variant of get_model_predictions that scores a (context, reply) pair
def get_model_predictions_experiment(context: str, reply: str, models: dict, tokenizers: dict, device) -> dict:
    predictions = {}
    for name, model in models.items():
        if model is None or tokenizers[name] is None:
            predictions[name] = 0.0
            continue
        try:
            # Use the tokenizer's sentence-pair interface
            inputs = tokenizers[name](
                context,
                reply,
                return_tensors="pt",
                truncation=True,
                padding=True
            )
            inputs = {k: v.to(device) for k, v in inputs.items()}
            model.to(device)
            with torch.no_grad():
                logits = model(**inputs).logits
            if logits.shape[-1] == 2:  # Binary classification head
                score = F.softmax(logits, dim=-1)[0, 1].item()
            else:  # Single-logit head
                score = torch.sigmoid(logits)[0, 0].item()
            predictions[name] = score
        except Exception as e:
            st.warning(f"Error with {name} model: {str(e)}")
            predictions[name] = 0.0
    return predictions
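# Passing (context, reply) as a pair makes BERT-style tokenizers emit
# "[CLS] context [SEP] reply [SEP]", so the model sees both turns; note, though,
# that these checkpoints were fine-tuned on single sentences (see the tab-2 caveat).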
# Enhanced ensemble prediction
def ensemble_prediction(model_scores: dict, rule_boost: float, weights: dict = None) -> float:
    if weights is None:
        # Default weights - adjust based on model performance
        weights = {
            'helinivan': 0.4,
            'distilbert': 0.5,  # Higher weight for the Reddit-trained model
            'rules': 0.1
        }
    ensemble_score = 0.0
    total_weight = 0.0
    # Weighted average of model predictions
    for model_name, score in model_scores.items():
        if score > 0:  # Only include valid predictions
            weight = weights.get(model_name, 0.3)
            ensemble_score += score * weight
            total_weight += weight
    # Normalize by the weight actually used
    if total_weight > 0:
        ensemble_score = ensemble_score / total_weight
    # Apply the rule-based boost and clamp to [0, 1]
    final_score = min(ensemble_score + (rule_boost * weights.get('rules', 0.1)), 1.0)
    return final_score
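# Worked example with the default weights: helinivan=0.6, distilbert=0.8,
# rule_boost=0.3 -> weighted mean = (0.6*0.4 + 0.8*0.5) / 0.9 ≈ 0.711,
# final = min(0.711 + 0.3*0.1, 1.0) ≈ 0.741.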
# Create highlighted text as HTML
def create_highlighted_text(text: str, highlights: list) -> str:
    if not isinstance(text, str):
        return ""
    if not highlights:
        return text.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;")
    # Sort highlights by start position
    sorted_highlights = sorted(highlights, key=lambda x: x['start'])
    color_map = {
        'sarcastic_phrase': '#ff6b6b',
        'sarcastic_emoji': '#4ecdc4',
        'excessive_punct': '#45b7d1',
        'rhetorical_question': '#96ceb4',
        'caps_emphasis': '#feca57',
        'repetition': '#ff9ff3',
        'exaggerated': '#54a0ff',
        'ellipsis': '#fd79a8'
    }
    result = ""
    last_end = 0
    for highlight in sorted_highlights:
        start, end = highlight['start'], highlight['end']
        highlight_type = highlight['type']
        color = color_map.get(highlight_type, '#dda0dd')
        # Add text before the highlight
        if start > last_end:
            before_text = text[last_end:start]
            result += before_text.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;")
        # Add the highlighted span
        highlighted_text = text[start:end]
        safe_text = highlighted_text.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;")
        result += f'<span style="background-color: {color}; padding: 2px 4px; border-radius: 3px; color: black;">{safe_text}</span>'
        last_end = end
    # Add any remaining text
    if last_end < len(text):
        remaining_text = text[last_end:]
        result += remaining_text.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;")
    return result
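# Example: for "oh sure!" with one 'sarcastic_phrase' highlight over "oh sure",
# this returns '<span style="background-color: #ff6b6b; ...">oh sure</span>!'
# (style attributes abbreviated here).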
# Enhanced confidence gauge
def create_confidence_gauge(score: float) -> go.Figure:
    fig = go.Figure(go.Indicator(
        mode="gauge+number+delta",
        value=score,
        domain={'x': [0, 1], 'y': [0, 1]},
        title={'text': "Ensemble Sarcasm Score"},
        delta={'reference': 0.5},
        gauge={
            'axis': {'range': [None, 1]},
            'bar': {'color': "darkblue"},
            'steps': [
                {'range': [0, 0.3], 'color': "lightgray"},
                {'range': [0.3, 0.6], 'color': "yellow"},
                {'range': [0.6, 1], 'color': "red"}
            ],
            'threshold': {
                'line': {'color': "red", 'width': 4},
                'thickness': 0.75,
                'value': 0.7
            }
        }
    ))
    fig.update_layout(height=300)
    return fig
# Multi-model feature importance visualization
def create_model_comparison_chart(model_scores: dict, rule_boost: float, final_score: float) -> go.Figure:
    models = list(model_scores.keys())
    scores = list(model_scores.values())
    # Add rule-based and final scores
    models.extend(['Rule-based', 'Final Ensemble'])
    scores.extend([rule_boost, final_score])
    colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd']
    fig = go.Figure(go.Bar(
        x=models,
        y=scores,
        marker_color=colors[:len(models)],
        text=[f'{score:.3f}' for score in scores],
        textposition='auto',
    ))
    fig.update_layout(
        title="Model Comparison & Ensemble Result",
        yaxis_title="Sarcasm Score",
        height=400,
        showlegend=False
    )
    return fig
# Real-time analysis function with multiple models (single-message path)
def analyze_text_realtime_current(text: str, models: dict, tokenizers: dict, device) -> dict:
    if not text.strip():
        return {
            'score': 0.0,
            'label': 'Enter text to analyze',
            'explanations': [],
            'highlights': [],
            'model_scores': {}
        }
    try:
        # Get rule-based analysis
        rule_boost, explanations, highlights = enhanced_rule_analysis(text)
        # Get predictions from all models
        model_scores = get_model_predictions_current(text, models, tokenizers, device)
        # Ensemble prediction
        final_score = ensemble_prediction(model_scores, rule_boost)
        # Determine label (emojis reconstructed from a corrupted source; only 🤨 is certain)
        if final_score > 0.8:
            label = "Extremely Sarcastic 🤨🙄"
        elif final_score > 0.7:
            label = "Highly Sarcastic 🤨"
        elif final_score > 0.6:
            label = "Likely Sarcastic 😏"
        elif final_score > 0.4:
            label = "Possibly Sarcastic 🤔"
        elif final_score > 0.3:
            label = "Probably Sincere 🙂"
        else:
            label = "Sincere 😊"
        return {
            'score': final_score,
            'model_scores': model_scores,
            'rule_boost': rule_boost,
            'label': label,
            'explanations': explanations,
            'highlights': highlights
        }
    except Exception as e:
        return {
            'score': 0.0,
            'label': f'Error: {str(e)}',
            'explanations': [],
            'highlights': [],
            'model_scores': {}
        }
# Experimental variant of analyze_text_realtime that accepts context and reply
def analyze_text_realtime_experiment(context: str, reply: str, models: dict, tokenizers: dict, device) -> dict:
    if not reply.strip():
        return {
            'score': 0.0,
            'label': 'Enter context and reply to analyze',
            'explanations': [],
            'highlights': [],
            'model_scores': {}
        }
    try:
        # Run rule-based analysis on the reply only (context is not used by the rules)
        rule_boost, explanations, highlights = enhanced_rule_analysis(reply)
        # Get predictions from all models using context and reply
        model_scores = get_model_predictions_experiment(context, reply, models, tokenizers, device)
        final_score = ensemble_prediction(model_scores, rule_boost)
        # Label assignment (same thresholds as the single-message path)
        if final_score > 0.8:
            label = "Extremely Sarcastic 🤨🙄"
        elif final_score > 0.7:
            label = "Highly Sarcastic 🤨"
        elif final_score > 0.6:
            label = "Likely Sarcastic 😏"
        elif final_score > 0.4:
            label = "Possibly Sarcastic 🤔"
        elif final_score > 0.3:
            label = "Probably Sincere 🙂"
        else:
            label = "Sincere 😊"
        return {
            'score': final_score,
            'model_scores': model_scores,
            'rule_boost': rule_boost,
            'label': label,
            'explanations': explanations,
            'highlights': highlights
        }
    except Exception as e:
        return {
            'score': 0.0,
            'label': f'Error: {str(e)}',
            'explanations': [],
            'highlights': [],
            'model_scores': {}
        }
# Streamlit UI
st.set_page_config(page_title="Enhanced Sarcasm Detector", page_icon="🤨", layout="wide")
st.title("🗨️ Enhanced Multi-Model Sarcasm Detector")
st.markdown("*Combining DistilBERT (Reddit-trained) + HelinIvan + Rule-based Analysis*")
# Load models
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
st.markdown(f"**Device:** {device}")
with st.spinner("Loading AI models..."):
    models, tokenizers = load_models()
# Model status display
st.markdown("### 🤖 Model Status")
status_cols = st.columns(2)
with status_cols[0]:
    helinivan_status = "✅ Loaded" if models.get('helinivan') else "❌ Failed"
    st.markdown(f"**HelinIvan Model:** {helinivan_status}")
with status_cols[1]:
    distilbert_status = "✅ Loaded" if models.get('distilbert') else "❌ Failed"
    st.markdown(f"**DistilBERT Model:** {distilbert_status}")
# Sidebar with examples and tips (shared across tabs)
with st.sidebar:
    st.markdown("### 💡 **Quick Examples**")
    example_buttons = [
        ("Oh great, more traffic 🙄", "social_media"),
        ("Yeah, I just LOVE waiting in line", "emphasis"),
        ("What could possibly go wrong?", "rhetorical"),
        ("Perfect timing as always...", "timing"),
        ("Thanks for the help genius", "reddit_style"),
        ("WOW so helpful!!!", "caps_sarcasm"),
        ("No shit Sherlock 🤡", "reddit_sarcasm"),
        ("Truly groundbreaking stuff here", "mock_praise")
    ]
    for example_text, example_type in example_buttons:
        if st.button(f"📝 {example_text[:25]}...", key=example_type):
            st.session_state.example_text = example_text
# --- Tabs for navigation ---
tab1, tab2 = st.tabs(["Single Message (Current)", "Context-Aware (Experimental)"])
# --- Tab 1: Single Message (Current) ---
with tab1:
    st.markdown("### Single Message Sarcasm Detection")
    st.markdown("Analyze sarcasm in a single message (no conversational context).")
    col1, col2 = st.columns([2, 1])
    with col1:
        # Text input with real-time analysis
        default_text = st.session_state.get('example_text', '')
        user_text = st.text_area(
            "Enter a paragraph for multi-model sarcasm analysis:",
            value=default_text,
            height=120,
            placeholder="Try: 'Oh fantastic, another meeting that could have been an email 🙄 What a brilliant use of everyone's time...'"
        )
        # Real-time analysis
        if user_text:
            with st.spinner("Analyzing with multiple models..."):
                analysis = analyze_text_realtime_current(user_text, models, tokenizers, device)
            # Display highlighted text
            st.markdown("### 🎯 **Analysis Results**")
            highlighted_html = create_highlighted_text(user_text, analysis['highlights'])
            st.markdown(f'<div style="padding: 10px; border: 1px solid #ddd; border-radius: 5px; background-color: #f9f9f9;">{highlighted_html}</div>', unsafe_allow_html=True)
            # Prediction and confidence
            st.markdown(f"### **Prediction: {analysis['label']}**")
            # Progress bar with a color cue
            progress_color = "🔴" if analysis['score'] > 0.7 else "🟡" if analysis['score'] > 0.4 else "🟢"
            st.write(f"**Ensemble Score: {analysis['score']:.3f}** {progress_color}")
            st.progress(analysis['score'])
    with col2:
        if user_text and 'analysis' in locals():
            # Confidence gauge
            st.markdown("### 📊 **Confidence Gauge**")
            gauge_fig = create_confidence_gauge(analysis['score'])
            st.plotly_chart(gauge_fig, use_container_width=True)
    # Multi-model analysis section
    if user_text and 'analysis' in locals() and analysis['model_scores']:
        st.markdown("### 🔍 **Multi-Model Analysis**")
        col3, col4 = st.columns([1, 1])
        with col3:
            st.markdown("#### 📈 **Individual Model Scores:**")
            for model_name, score in analysis['model_scores'].items():
                model_display = {
                    'helinivan': 'HelinIvan Model',
                    'distilbert': 'DistilBERT Model'
                }
                display_name = model_display.get(model_name, model_name)
                st.write(f"• **{display_name}:** {score:.3f}")
            st.write(f"• **Rule-based boost:** +{analysis['rule_boost']:.3f}")
            st.write(f"• **🎯 Final ensemble:** {analysis['score']:.3f}")
            if analysis['explanations']:
                st.markdown("#### 🔍 **Detected Patterns:**")
                for i, explanation in enumerate(analysis['explanations'], 1):
                    st.write(f"{i}. {explanation}")
        with col4:
            st.markdown("#### 📊 **Model Comparison:**")
            comparison_fig = create_model_comparison_chart(
                analysis['model_scores'],
                analysis['rule_boost'],
                analysis['score']
            )
            st.plotly_chart(comparison_fig, use_container_width=True)
    # Pattern legend
    if user_text and 'analysis' in locals() and analysis['highlights']:
        st.markdown("### 🎨 **Highlighting Legend**")
        legend_cols = st.columns(4)
        legend_items = [
            ("Sarcastic Phrases", "#ff6b6b"),
            ("Emojis", "#4ecdc4"),
            ("Punctuation", "#45b7d1"),
            ("Questions", "#96ceb4"),
            ("Emphasis", "#feca57"),
            ("Repetition", "#ff9ff3"),
            ("Exaggeration", "#54a0ff"),
            ("Ellipsis", "#fd79a8")
        ]
        for i, (label, color) in enumerate(legend_items):
            with legend_cols[i % 4]:
                safe_label = label.replace("<", "&lt;").replace(">", "&gt;")
                st.markdown(f'<span style="background-color: {color}; padding: 2px 6px; border-radius: 3px; color: black; font-size: 12px;">{safe_label}</span>', unsafe_allow_html=True)
# --- Tab 2: Context-Aware (Experimental) ---
with tab2:
    st.markdown("### Context-Aware Sarcasm Detection (Experimental)")
    st.info("This feature is experimental. The models are **not yet trained** on context+reply pairs. Predictions are based on formatting the input as a sentence pair, but results may not be reliable.")
    col1, col2 = st.columns([2, 1])
    with col1:
        context_text = st.text_area(
            "Context (previous message):",
            value=st.session_state.get('context_text', ''),
            height=68,
            placeholder="e.g. 'Can you finish this by today?'"
        )
        reply_text = st.text_area(
            "Reply (current message):",
            value=st.session_state.get('reply_text', st.session_state.get('example_text', '')),
            height=80,
            placeholder="e.g. 'Oh sure, because I have nothing else to do.'"
        )
        if reply_text:
            with st.spinner("Analyzing with experimental context-aware input..."):
                analysis_ctx = analyze_text_realtime_experiment(context_text, reply_text, models, tokenizers, device)
            st.markdown("### 🎯 **Analysis Results**")
            highlighted_html = create_highlighted_text(reply_text, analysis_ctx['highlights'])
            st.markdown(f'<div style="padding: 10px; border: 1px solid #ddd; border-radius: 5px; background-color: #f9f9f9;">{highlighted_html}</div>', unsafe_allow_html=True)
            st.markdown(f"### **Prediction: {analysis_ctx['label']}**")
            progress_color = "🔴" if analysis_ctx['score'] > 0.7 else "🟡" if analysis_ctx['score'] > 0.4 else "🟢"
            st.write(f"**Ensemble Score: {analysis_ctx['score']:.3f}** {progress_color}")
            st.progress(analysis_ctx['score'])
    with col2:
        if reply_text and 'analysis_ctx' in locals():
            st.markdown("### 📊 **Confidence Gauge**")
            gauge_fig = create_confidence_gauge(analysis_ctx['score'])
            st.plotly_chart(gauge_fig, use_container_width=True)
    if reply_text and 'analysis_ctx' in locals() and analysis_ctx['model_scores']:
        st.markdown("### 🔍 **Multi-Model Analysis**")
        col3, col4 = st.columns([1, 1])
        with col3:
            st.markdown("#### 📈 **Individual Model Scores:**")
            for model_name, score in analysis_ctx['model_scores'].items():
                model_display = {
                    'helinivan': 'HelinIvan Model',
                    'distilbert': 'DistilBERT Model'
                }
                display_name = model_display.get(model_name, model_name)
                st.write(f"• **{display_name}:** {score:.3f}")
            st.write(f"• **Rule-based boost:** +{analysis_ctx['rule_boost']:.3f}")
            st.write(f"• **🎯 Final ensemble:** {analysis_ctx['score']:.3f}")
            if analysis_ctx['explanations']:
                st.markdown("#### 🔍 **Detected Patterns:**")
                for i, explanation in enumerate(analysis_ctx['explanations'], 1):
                    st.write(f"{i}. {explanation}")
        with col4:
            st.markdown("#### 📊 **Model Comparison:**")
            comparison_fig = create_model_comparison_chart(
                analysis_ctx['model_scores'],
                analysis_ctx['rule_boost'],
                analysis_ctx['score']
            )
            st.plotly_chart(comparison_fig, use_container_width=True)
    if reply_text and 'analysis_ctx' in locals() and analysis_ctx['highlights']:
        st.markdown("### 🎨 **Highlighting Legend**")
        legend_cols = st.columns(4)
        legend_items = [
            ("Sarcastic Phrases", "#ff6b6b"),
            ("Emojis", "#4ecdc4"),
            ("Punctuation", "#45b7d1"),
            ("Questions", "#96ceb4"),
            ("Emphasis", "#feca57"),
            ("Repetition", "#ff9ff3"),
            ("Exaggeration", "#54a0ff"),
            ("Ellipsis", "#fd79a8")
        ]
        for i, (label, color) in enumerate(legend_items):
            with legend_cols[i % 4]:
                safe_label = label.replace("<", "&lt;").replace(">", "&gt;")
                st.markdown(f'<span style="background-color: {color}; padding: 2px 6px; border-radius: 3px; color: black; font-size: 12px;">{safe_label}</span>', unsafe_allow_html=True)
# --- Shared tutorial and advanced settings ---
with st.expander("📚 **Multi-Model Sarcasm Detection Guide**"):
    st.markdown("""
### How the Enhanced Detection Works:
1. **🤖 HelinIvan Model**: General English sarcasm detection
2. **🤖 DistilBERT Model**: Specialized Reddit-trained sarcasm detector
3. **📏 Rule-Based Analysis**: Linguistic patterns and social media cues
4. **🎯 Ensemble Method**: Combines all approaches with weighted averaging

### Model Advantages:
- **HelinIvan**: Good for formal and general sarcasm
- **DistilBERT (Reddit)**: Excellent for informal, social-media-style sarcasm
- **Rule-based**: Catches obvious patterns and cultural references

### Why an Ensemble Works Better:
- **Robustness**: Multiple models offset each other's weaknesses
- **Coverage**: Different training data covers different sarcasm styles
- **Confidence**: Agreement between models increases reliability

### Try These Reddit-Style Examples:
- "No shit Sherlock 🤡"
- "Thanks Captain Obvious"
- "Groundbreaking discovery there genius"
- "What a concept... mind blown 🤯"
""")
# Model weight adjustment (advanced users)
with st.expander("⚙️ **Advanced: Adjust Model Weights**"):
    st.markdown("Fine-tune the ensemble by adjusting model importance:")
    col_w1, col_w2, col_w3 = st.columns(3)
    with col_w1:
        helinivan_weight = st.slider("HelinIvan Weight", 0.0, 1.0, 0.4, 0.1)
    with col_w2:
        distilbert_weight = st.slider("DistilBERT Weight", 0.0, 1.0, 0.5, 0.1)
    with col_w3:
        rules_weight = st.slider("Rules Weight", 0.0, 0.5, 0.1, 0.05)
    st.info(f"Weights - HelinIvan: {helinivan_weight}, DistilBERT: {distilbert_weight}, Rules: {rules_weight}")
# Clear session state
if st.button("🔄 Clear Text", key="clear_main"):
    if 'example_text' in st.session_state:
        del st.session_state.example_text
    st.rerun()