# muhdfiqq's picture
# Upload app.py
# a00e3f2 verified
import nltk
import streamlit as st
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
import torch
import torch.nn.functional as F
import spacy
import re
from nltk.sentiment import SentimentIntensityAnalyzer
import emoji
import plotly.graph_objects as go
import plotly.express as px
from collections import Counter
import time
import numpy as np
# Configuration - model identifiers handed to from_pretrained, keyed by short name
MODELS = dict(
    helinivan="helinivan/English-sarcasm-detector",
    distilbert="dima806/sarcasm-detection-distilbert",
)
# Initialize NLTK VADER analyzer.
# `sia` stays None when the analyzer cannot be constructed; the failure is
# reported in the UI instead of aborting the script.
sia = None
try:
    nltk.data.path.append('/app/nltk_data')
    sia = SentimentIntensityAnalyzer()
except Exception as e:
    st.error(f"Error downloading NLTK data: {e}")
# Cache multiple models & tokenizers
@st.cache_resource
def load_models():
    """Load every checkpoint listed in ``MODELS`` together with its tokenizer.

    Returns:
        A ``(models, tokenizers)`` pair of dicts keyed by the names in
        ``MODELS``. When loading a checkpoint raises, both dicts hold ``None``
        for that name and the error is shown with ``st.error``.
    """
    hf_cache = "/tmp/hf_cache"
    loaded_models, loaded_tokenizers = {}, {}
    for name in MODELS:
        checkpoint = MODELS[name]
        try:
            net = AutoModelForSequenceClassification.from_pretrained(checkpoint, cache_dir=hf_cache)
            tok = AutoTokenizer.from_pretrained(checkpoint, cache_dir=hf_cache)
            net.eval()  # inference only
            loaded_models[name] = net
            loaded_tokenizers[name] = tok
            st.success(f"βœ… Loaded {name} model successfully")
        except Exception as e:
            st.error(f"❌ Failed to load {name} model: {str(e)}")
            loaded_models[name] = None
            loaded_tokenizers[name] = None
    return loaded_models, loaded_tokenizers
# Lazy-load SpaCy (optional - not used in current implementation)
def load_spacy():
    """Return the ``en_core_web_sm`` spaCy pipeline, or ``None`` if loading raises ``OSError``."""
    try:
        nlp = spacy.load("en_core_web_sm")
    except OSError:
        st.warning("SpaCy model 'en_core_web_sm' not found. Some features may be limited.")
        nlp = None
    return nlp
# Pattern detection functions with highlighting info

# Stock phrases treated as sarcasm cues (including Reddit-style patterns).
# They are matched case-insensitively as plain substrings.
_SARCASM_PHRASES = (
    "oh sure", "yeah right", "of course", "totally", "absolutely",
    "perfect", "wonderful", "fantastic", "amazing", "brilliant",
    "great job", "well done", "nice one", "good going", "way to go",
    "real smooth", "genius move", "solid plan", "makes sense",
    "just perfect", "exactly what i needed", "this is fine",
    # Reddit-style additions
    "thanks genius", "no shit sherlock", "well duh", "captain obvious",
    "groundbreaking", "revolutionary", "what a concept", "mind blown",
    "shocking", "who would have guessed", "truly inspiring",
)

# An intensifier followed, later on the same line, by a positive adjective.
_EXAGGERATED_RE = re.compile(
    r'\b(SO|TOTALLY|ABSOLUTELY|REALLY|VERY)\b.*\b(great|good|perfect|amazing|helpful|useful)\b',
    re.IGNORECASE,
)


def social_media_sarcasm_cues(text: str) -> tuple[float, list, list]:
    """Score stock sarcastic phrases and exaggerated praise found in *text*.

    Each occurrence of a phrase from ``_SARCASM_PHRASES`` adds 0.2 to the
    boost; the first "intensifier ... positive adjective" span adds 0.25.

    Matching runs against *text* itself (case-insensitively) rather than a
    lowercased copy: lowercasing can change the length of a string (for
    example "İ" becomes two code points), which would shift the reported
    offsets away from the characters they are meant to highlight.

    Args:
        text: The message to inspect.

    Returns:
        ``(boost, explanations, highlights)`` where every highlight is a dict
        with ``start``/``end`` offsets into *text*, a ``type`` and a ``text``.
    """
    explanations = []
    highlights = []
    boost = 0.0

    for phrase in _SARCASM_PHRASES:
        # Find all occurrences of the phrase
        for match in re.finditer(re.escape(phrase), text, re.IGNORECASE):
            boost += 0.2
            explanations.append(f"Sarcastic phrase: '{phrase}'")
            highlights.append({
                'start': match.start(),
                'end': match.end(),
                'type': 'sarcastic_phrase',
                'text': phrase,
            })

    # Exaggerated expressions (only the first match is reported)
    exaggerated_match = _EXAGGERATED_RE.search(text)
    if exaggerated_match:
        boost += 0.25
        explanations.append("Exaggerated positive expression")
        highlights.append({
            'start': exaggerated_match.start(),
            'end': exaggerated_match.end(),
            'type': 'exaggerated',
            'text': exaggerated_match.group(),
        })

    return boost, explanations, highlights
def emoji_punctuation_analysis(text: str) -> tuple[float, list, list]:
    """Score sarcasm-flavoured emojis, repeated ``!``/``?`` and ellipses in *text*.

    A listed emoji adds 0.15, each run of two or more ``!``/``?`` adds 0.1 and
    each run of three or more dots adds 0.15.

    Returns:
        ``(boost, explanations, highlights)``; highlights carry offsets into
        *text*.
    """
    boost = 0.0
    explanations = []
    highlights = []

    # Emoji pass is best-effort: any exception raised while querying the
    # emoji library is ignored and the punctuation checks still run.
    try:
        # NOTE(review): these literals look like UTF-8 emoji that were decoded
        # with a legacy code page; verify the file encoding, otherwise they
        # presumably cannot match what emoji.emoji_list reports.
        sarcastic_emojis = ['πŸ™„', '😏', 'πŸ˜’', 'πŸ€”', '🀨', '😀', '🀷', 'πŸ‘', 'πŸ™ƒ', '🀑', 'πŸ’€', '🀯']
        for found in emoji.emoji_list(text):
            symbol = found['emoji']
            if symbol not in sarcastic_emojis:
                continue
            boost += 0.15
            explanations.append(f"Sarcastic emoji: {symbol}")
            highlights.append({
                'start': found['match_start'],
                'end': found['match_end'],
                'type': 'sarcastic_emoji',
                'text': symbol,
            })
    except Exception:
        # Fallback if emoji library has issues
        pass

    # (pattern, boost per match, explanation label, highlight type)
    punctuation_rules = (
        (r'[!?]{2,}', 0.1, "Excessive punctuation", 'excessive_punct'),
        (r'\.{3,}', 0.15, "Trailing ellipsis", 'ellipsis'),  # often sarcastic
    )
    for pattern, weight, label, kind in punctuation_rules:
        for hit in re.finditer(pattern, text):
            fragment = hit.group()
            boost += weight
            explanations.append(f"{label}: {fragment}")
            highlights.append({
                'start': hit.start(),
                'end': hit.end(),
                'type': kind,
                'text': fragment,
            })

    return boost, explanations, highlights
def rhetorical_questions_analysis(text: str) -> tuple[float, list, list]:
    """Detect stock rhetorical questions in *text* (case-insensitive).

    Every match of every pattern adds 0.3 to the boost.

    Returns:
        ``(boost, explanations, highlights)``; results are grouped by pattern
        in the order the patterns are listed.
    """
    rhetorical_patterns = [
        (r'what could possibly go wrong\?', "Rhetorical question"),
        (r'who would have thought\?', "Rhetorical question"),
        (r'seriously\?', "Emphatic question"),
        (r'really\?.*really\?', "Repeated question"),
        (r'no way\?', "Disbelief question"),
        (r'you don\'t say\?', "Sarcastic response"),
        (r'shocking.*\?', "Mock surprise"),
    ]

    # Collect (description, match) pairs first, then derive the three outputs.
    hits = [
        (description, found)
        for pattern, description in rhetorical_patterns
        for found in re.finditer(pattern, text, re.IGNORECASE)
    ]

    boost = 0.0
    for _ in hits:
        boost += 0.3

    explanations = [description for description, _ in hits]
    highlights = [
        {
            'start': found.start(),
            'end': found.end(),
            'type': 'rhetorical_question',
            'text': found.group(),
        }
        for _, found in hits
    ]
    return boost, explanations, highlights
def capitalization_analysis(text: str) -> tuple[float, list, list]:
    """Score emphatic ALL-CAPS words and stretched letters in *text*.

    Each ASCII all-caps word of three or more letters (except a few common
    short words) adds 0.1, and each run of the same letter repeated three or
    more times adds 0.1.

    The repetition rule only considers letters. Runs of punctuation, digits
    or whitespace ("!!!", "...", "   ") are not letter repetition, and the
    punctuation ones are already scored by the punctuation rules; counting
    them here would mislabel them and produce duplicate highlights.

    Returns:
        ``(boost, explanations, highlights)``; highlights carry offsets into
        *text*.
    """
    ignored_caps = {'AND', 'THE', 'FOR', 'BUT', 'YOU', 'ARE'}
    explanations = []
    highlights = []
    boost = 0.0

    # ALL CAPS words
    for match in re.finditer(r'\b[A-Z]{3,}\b', text):
        word = match.group()
        if word in ignored_caps:
            continue
        boost += 0.1
        explanations.append(f"Emphatic caps: {word}")
        highlights.append({
            'start': match.start(),
            'end': match.end(),
            'type': 'caps_emphasis',
            'text': word,
        })

    # Letter repetition: [^\W\d_] is "word character that is neither a digit
    # nor an underscore", i.e. a letter.
    for match in re.finditer(r'([^\W\d_])\1{2,}', text):
        boost += 0.1
        explanations.append(f"Letter repetition: {match.group()}")
        highlights.append({
            'start': match.start(),
            'end': match.end(),
            'type': 'repetition',
            'text': match.group(),
        })

    return boost, explanations, highlights
# Combined analysis with highlighting
def enhanced_rule_analysis(text: str) -> tuple[float, list, list]:
    """Run all four rule-based analyzers on *text* and merge their results.

    Returns:
        ``(total_boost, explanations, highlights)`` where ``total_boost`` is
        the sum of the individual boosts capped at 0.8, and the two lists are
        the analyzers' lists concatenated in call order.
    """
    analyzers = (
        social_media_sarcasm_cues,
        emoji_punctuation_analysis,
        rhetorical_questions_analysis,
        capitalization_analysis,
    )

    total_boost = 0.0
    all_explanations = []
    all_highlights = []
    for analyze in analyzers:
        boost, explanations, highlights = analyze(text)
        total_boost += boost
        all_explanations.extend(explanations)
        all_highlights.extend(highlights)

    # Cap the total boost
    return min(total_boost, 0.8), all_explanations, all_highlights
# Multi-model prediction function
def get_model_predictions_current(text: str, models: dict, tokenizers: dict, device) -> dict:
    """Score a single message with every available model.

    Args:
        text: Message to classify.
        models: Name -> model; ``None`` marks a model that is unavailable.
        tokenizers: Name -> tokenizer, keyed like *models*.
        device: Device the inputs and the model are moved to.

    Returns:
        Name -> score. With two output logits the score is the softmax value
        at index 1; otherwise it is the sigmoid of the first logit. Missing
        models and models that raise are scored 0.0 (the latter also emit
        ``st.warning``).
    """
    scores = {}
    for name, model in models.items():
        if model is None or tokenizers[name] is None:
            scores[name] = 0.0
            continue
        try:
            encode = tokenizers[name]
            batch = encode([text], return_tensors="pt", truncation=True, padding=True)
            batch = batch.to(device)
            model.to(device)
            with torch.no_grad():
                logits = model(**batch).logits
            # Handle different output formats
            is_binary = logits.shape[-1] == 2
            if is_binary:
                value = F.softmax(logits, dim=-1)[0, 1].item()
            else:
                value = torch.sigmoid(logits)[0, 0].item()
            scores[name] = value
        except Exception as e:
            st.warning(f"Error with {name} model: {str(e)}")
            scores[name] = 0.0
    return scores
# Context-aware variant of get_model_predictions: accepts context and reply
def get_model_predictions_experiment(context: str, reply: str, models: dict, tokenizers: dict, device) -> dict:
    """Score a (context, reply) pair with every available model.

    The two strings are passed to each tokenizer as a sentence pair.

    Args:
        context: Previous message.
        reply: Message being classified.
        models: Name -> model; ``None`` marks a model that is unavailable.
        tokenizers: Name -> tokenizer, keyed like *models*.
        device: Device the inputs and the model are moved to.

    Returns:
        Name -> score. With two output logits the score is the softmax value
        at index 1; otherwise it is the sigmoid of the first logit. Missing
        models and models that raise are scored 0.0 (the latter also emit
        ``st.warning``).
    """
    scores = {}
    for name, model in models.items():
        if model is None or tokenizers[name] is None:
            scores[name] = 0.0
            continue
        try:
            # Use sentence-pair interface
            encode = tokenizers[name]
            encoded = encode(
                context,
                reply,
                return_tensors="pt",
                truncation=True,
                padding=True,
            )
            on_device = {}
            for key, tensor in encoded.items():
                on_device[key] = tensor.to(device)
            model.to(device)
            with torch.no_grad():
                logits = model(**on_device).logits
            is_binary = logits.shape[-1] == 2
            if is_binary:
                value = F.softmax(logits, dim=-1)[0, 1].item()
            else:
                value = torch.sigmoid(logits)[0, 0].item()
            scores[name] = value
        except Exception as e:
            st.warning(f"Error with {name} model: {str(e)}")
            scores[name] = 0.0
    return scores
# Enhanced ensemble prediction
def ensemble_prediction(model_scores: dict, rule_boost: float, weights: dict = None) -> float:
    """Blend model scores and the rule-based boost into one score.

    Scores that are not greater than zero are left out of the average
    (0.0 is what the prediction helpers report for unavailable models).

    Args:
        model_scores: Model name -> score.
        rule_boost: Boost produced by the rule-based analysis.
        weights: Optional name -> weight mapping; the ``'rules'`` entry scales
            *rule_boost*. Models without an entry use 0.3, a missing
            ``'rules'`` entry uses 0.1.

    Returns:
        Weighted average of the usable scores plus the scaled rule boost,
        capped at 1.0.
    """
    if weights is None:
        # Default weights - adjust based on model performance
        weights = {
            'helinivan': 0.4,
            'distilbert': 0.5,  # Higher weight for Reddit-trained model
            'rules': 0.1,
        }

    weighted_sum = 0.0
    weight_total = 0.0
    for model_name, score in model_scores.items():
        if not score > 0:  # Only include valid predictions
            continue
        model_weight = weights.get(model_name, 0.3)
        weighted_sum += score * model_weight
        weight_total += model_weight

    # Weighted average of model predictions (stays 0.0 when nothing was usable)
    average = weighted_sum / weight_total if weight_total > 0 else weighted_sum

    # Apply rule-based boost
    rule_contribution = rule_boost * weights.get('rules', 0.1)
    return min(average + rule_contribution, 1.0)
# Create highlighted text HTML
def create_highlighted_text(text: str, highlights: list) -> str:
    """Render *text* as HTML with each highlight wrapped in a coloured ``<span>``.

    ``&``, ``<`` and ``>`` are escaped everywhere. Highlights are processed in
    order of their ``start`` offset. Overlapping highlights (for example "!!!"
    reported both as punctuation and as repetition) are clipped to the part
    not yet emitted, so every character of *text* appears exactly once; a
    highlight that is fully covered by an earlier one is dropped.

    Args:
        text: Original message; anything that is not a ``str`` yields ``""``.
        highlights: Dicts with ``start``, ``end`` (offsets into *text*) and
            ``type`` (selects the colour; unknown types use a default).

    Returns:
        The HTML fragment.
    """
    import html  # local import keeps this block self-contained

    if not isinstance(text, str):
        return ""

    def escape(fragment: str) -> str:
        # quote=False limits escaping to &, < and >
        return html.escape(fragment, quote=False)

    if not highlights:
        return escape(text)

    color_map = {
        'sarcastic_phrase': '#ff6b6b',
        'sarcastic_emoji': '#4ecdc4',
        'excessive_punct': '#45b7d1',
        'rhetorical_question': '#96ceb4',
        'caps_emphasis': '#feca57',
        'repetition': '#ff9ff3',
        'exaggerated': '#54a0ff',
        'ellipsis': '#fd79a8',
    }

    pieces = []
    cursor = 0  # index of the first character of text not yet emitted
    for highlight in sorted(highlights, key=lambda item: item['start']):
        # Clip the span to the text and to what has already been written.
        start = max(highlight['start'], cursor)
        end = min(highlight['end'], len(text))
        if end <= start:
            continue

        color = color_map.get(highlight['type'], '#dda0dd')
        # Plain text between the previous span and this one
        pieces.append(escape(text[cursor:start]))
        # Highlighted text
        safe_text = escape(text[start:end])
        pieces.append(
            f'<span style="background-color: {color}; padding: 2px 4px; border-radius: 3px; color: black;">{safe_text}</span>'
        )
        cursor = end

    # Remaining text after the last span
    pieces.append(escape(text[cursor:]))
    return "".join(pieces)
# Enhanced confidence gauge
def create_confidence_gauge(score: float) -> go.Figure:
    """Build a 300px-high Plotly gauge indicator showing *score*.

    The gauge shows the number plus its delta against 0.5, has three coloured
    bands (0-0.3, 0.3-0.6, 0.6-1) and a red threshold line at 0.7.
    """
    color_bands = [
        {'range': [0, 0.3], 'color': "lightgray"},
        {'range': [0.3, 0.6], 'color': "yellow"},
        {'range': [0.6, 1], 'color': "red"},
    ]
    threshold_marker = {
        'line': {'color': "red", 'width': 4},
        'thickness': 0.75,
        'value': 0.7,
    }
    gauge_spec = {
        'axis': {'range': [None, 1]},
        'bar': {'color': "darkblue"},
        'steps': color_bands,
        'threshold': threshold_marker,
    }
    indicator = go.Indicator(
        mode="gauge+number+delta",
        value=score,
        domain={'x': [0, 1], 'y': [0, 1]},
        title={'text': "Ensemble Sarcasm Score"},
        delta={'reference': 0.5},
        gauge=gauge_spec,
    )
    figure = go.Figure(indicator)
    figure.update_layout(height=300)
    return figure
# Multi-model feature importance visualization
def create_model_comparison_chart(model_scores: dict, rule_boost: float, final_score: float) -> go.Figure:
    """Build a bar chart comparing each model's score with the rule boost and final score.

    Bars appear in the order of *model_scores*, followed by 'Rule-based' and
    'Final Ensemble'. Each bar is labelled with its value to three decimals.
    """
    # Model bars first, then the rule-based and final scores
    labels = [*model_scores.keys(), 'Rule-based', 'Final Ensemble']
    values = [*model_scores.values(), rule_boost, final_score]
    palette = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd']

    bars = go.Bar(
        x=labels,
        y=values,
        marker_color=palette[:len(labels)],
        text=[f'{value:.3f}' for value in values],
        textposition='auto',
    )
    figure = go.Figure(bars)
    figure.update_layout(
        title="Model Comparison & Ensemble Result",
        yaxis_title="Sarcasm Score",
        height=400,
        showlegend=False,
    )
    return figure
# Real-time analysis function with multiple models
def analyze_text_realtime_current(text: str, models: dict, tokenizers: dict, device) -> dict:
    """Analyze a single message with the rules and every model.

    Args:
        text: Message to analyze.
        models: Name -> model mapping handed to the prediction helper.
        tokenizers: Name -> tokenizer mapping handed to the prediction helper.
        device: Device used for inference.

    Returns:
        A dict that always has the keys ``score``, ``label``, ``explanations``,
        ``highlights``, ``model_scores`` and ``rule_boost``. Blank input and
        any exception raised during analysis produce a neutral result (score
        0.0, empty collections, ``rule_boost`` 0.0) whose label is the prompt
        or ``'Error: ...'`` message.
    """
    def neutral_result(label: str) -> dict:
        # Same keys as a successful result so callers can read any of them
        # without checking which path produced the dict.
        return {
            'score': 0.0,
            'label': label,
            'explanations': [],
            'highlights': [],
            'model_scores': {},
            'rule_boost': 0.0,
        }

    if not text.strip():
        return neutral_result('Enter text to analyze')

    try:
        # Get rule-based analysis
        rule_boost, explanations, highlights = enhanced_rule_analysis(text)
        # Get predictions from all models
        model_scores = get_model_predictions_current(text, models, tokenizers, device)
        # Ensemble prediction
        final_score = ensemble_prediction(model_scores, rule_boost)

        # Determine label
        if final_score > 0.8:
            label = "Extremely Sarcastic 🀨😏"
        elif final_score > 0.7:
            label = "Highly Sarcastic 🀨"
        elif final_score > 0.6:
            label = "Likely Sarcastic 😏"
        elif final_score > 0.4:
            label = "Possibly Sarcastic πŸ€”"
        elif final_score > 0.3:
            label = "Probably Sincere πŸ™‚"
        else:
            label = "Sincere 😊"

        return {
            'score': final_score,
            'model_scores': model_scores,
            'rule_boost': rule_boost,
            'label': label,
            'explanations': explanations,
            'highlights': highlights,
        }
    except Exception as e:
        return neutral_result(f'Error: {str(e)}')
# Context-aware variant of analyze_text_realtime: accepts context and reply
def analyze_text_realtime_experiment(context: str, reply: str, models: dict, tokenizers: dict, device) -> dict:
    """Analyze a reply together with the message it responds to.

    The rule-based analysis looks at *reply* only; the models receive
    *context* and *reply* as a pair.

    Args:
        context: Previous message.
        reply: Message to analyze.
        models: Name -> model mapping handed to the prediction helper.
        tokenizers: Name -> tokenizer mapping handed to the prediction helper.
        device: Device used for inference.

    Returns:
        A dict that always has the keys ``score``, ``label``, ``explanations``,
        ``highlights``, ``model_scores`` and ``rule_boost``. A blank reply and
        any exception raised during analysis produce a neutral result (score
        0.0, empty collections, ``rule_boost`` 0.0) whose label is the prompt
        or ``'Error: ...'`` message.
    """
    def neutral_result(label: str) -> dict:
        # Same keys as a successful result so callers can read any of them
        # without checking which path produced the dict.
        return {
            'score': 0.0,
            'label': label,
            'explanations': [],
            'highlights': [],
            'model_scores': {},
            'rule_boost': 0.0,
        }

    if not reply.strip():
        return neutral_result('Enter context and reply to analyze')

    try:
        # Use reply for rule-based analysis (context is not used in rules)
        rule_boost, explanations, highlights = enhanced_rule_analysis(reply)
        # Get predictions from all models using context and reply
        model_scores = get_model_predictions_experiment(context, reply, models, tokenizers, device)
        final_score = ensemble_prediction(model_scores, rule_boost)

        # Determine label (same thresholds as the single-message analysis)
        if final_score > 0.8:
            label = "Extremely Sarcastic 🀨😏"
        elif final_score > 0.7:
            label = "Highly Sarcastic 🀨"
        elif final_score > 0.6:
            label = "Likely Sarcastic 😏"
        elif final_score > 0.4:
            label = "Possibly Sarcastic πŸ€”"
        elif final_score > 0.3:
            label = "Probably Sincere πŸ™‚"
        else:
            label = "Sincere 😊"

        return {
            'score': final_score,
            'model_scores': model_scores,
            'rule_boost': rule_boost,
            'label': label,
            'explanations': explanations,
            'highlights': highlights,
        }
    except Exception as e:
        return neutral_result(f'Error: {str(e)}')
# Streamlit UI
# Page chrome, model loading and a two-column load-status summary.
# NOTE(review): the NLTK fallback near the top of the file can call st.error
# before this point; confirm that is acceptable for st.set_page_config in the
# Streamlit version being deployed.
st.set_page_config(page_title="Enhanced Sarcasm Detector", page_icon="🀨", layout="wide")
st.title("πŸ—¨οΈ Enhanced Multi-Model Sarcasm Detector")
st.markdown("*Combining DistilBERT (Reddit-trained) + HelinIvan + Rule-based Analysis*")
# Load models
# Use the GPU when torch reports one, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
st.markdown(f"**Device:** {device}")
with st.spinner("Loading AI models..."):
    models, tokenizers = load_models()
# Model status display
# load_models stores None for a model that failed to load, so truthiness of
# the dict entry is used as the "loaded" signal.
st.markdown("### πŸ€– Model Status")
status_cols = st.columns(2)
with status_cols[0]:
    helinivan_status = "βœ… Loaded" if models.get('helinivan') else "❌ Failed"
    st.markdown(f"**HelinIvan Model:** {helinivan_status}")
with status_cols[1]:
    distilbert_status = "βœ… Loaded" if models.get('distilbert') else "❌ Failed"
    st.markdown(f"**DistilBERT Model:** {distilbert_status}")
# Sidebar with examples and tips (shared)
# Each entry is (example sentence, unique widget key). Clicking a button stores
# the sentence in st.session_state.example_text, which the text areas in the
# tabs read as their default value.
with st.sidebar:
    st.markdown("### πŸ’‘ **Quick Examples**")
    example_buttons = [
        ("Oh great, more traffic πŸ™„", "social_media"),
        ("Yeah, I just LOVE waiting in line", "emphasis"),
        ("What could possibly go wrong?", "rhetorical"),
        ("Perfect timing as always...", "timing"),
        ("Thanks for the help genius", "reddit_style"),
        ("WOW so helpful!!!", "caps_sarcasm"),
        ("No shit Sherlock 🀑", "reddit_sarcasm"),
        ("Truly groundbreaking stuff here", "mock_praise")
    ]
    for example_text, example_type in example_buttons:
        # The button label shows only the first 25 characters of the example.
        if st.button(f"πŸ“ {example_text[:25]}...", key=example_type):
            st.session_state.example_text = example_text
# --- Tabs for navigation ---
tab1, tab2 = st.tabs(["Single Message (Current)", "Context-Aware (Experimental)"])
# --- Tab 1: Single Message (Current) ---
# Layout: input + highlighted result (left), gauge (right), then the per-model
# breakdown and the highlight colour legend underneath.
with tab1:
    st.markdown("### Single Message Sarcasm Detection")
    st.markdown("Analyze sarcasm in a single message (no conversational context).")
    col1, col2 = st.columns([2, 1])
    with col1:
        # Text input with real-time analysis; pre-filled with the last sidebar
        # example that was clicked, if any.
        default_text = st.session_state.get('example_text', '')
        user_text = st.text_area(
            "Enter a paragraph for multi-model sarcasm analysis:",
            value=default_text,
            height=120,
            placeholder="Try: 'Oh fantastic, another meeting that could have been an email πŸ™„ What a brilliant use of everyone's time...'"
        )
        # Real-time analysis
        if user_text:
            with st.spinner("Analyzing with multiple models..."):
                analysis = analyze_text_realtime_current(user_text, models, tokenizers, device)
            # Display highlighted text. create_highlighted_text escapes the
            # user's text before it is embedded in the raw HTML below.
            st.markdown("### 🎯 **Analysis Results**")
            highlighted_html = create_highlighted_text(user_text, analysis['highlights'])
            st.markdown(f'<div style="padding: 10px; border: 1px solid #ddd; border-radius: 5px; background-color: #f9f9f9;">{highlighted_html}</div>', unsafe_allow_html=True)
            # Prediction and confidence
            st.markdown(f"### **Prediction: {analysis['label']}**")
            # Progress bar with a coloured marker: > 0.7, > 0.4, otherwise
            progress_color = "πŸ”΄" if analysis['score'] > 0.7 else "🟑" if analysis['score'] > 0.4 else "🟒"
            st.write(f"**Ensemble Score: {analysis['score']:.3f}** {progress_color}")
            st.progress(analysis['score'])
    with col2:
        # `'analysis' in locals()` guards against `analysis` not having been
        # assigned during this run of the script.
        if user_text and 'analysis' in locals():
            # Confidence gauge
            st.markdown("### πŸ“Š **Confidence Gauge**")
            gauge_fig = create_confidence_gauge(analysis['score'])
            st.plotly_chart(gauge_fig, use_container_width=True)
    # Multi-model analysis section
    # Skipped when model_scores is empty, which is what the analysis function
    # returns for blank input and for errors; those results also lack
    # 'rule_boost', so this guard is what keeps the lookups below safe.
    if user_text and 'analysis' in locals() and analysis['model_scores']:
        st.markdown("### πŸ” **Multi-Model Analysis**")
        col3, col4 = st.columns([1, 1])
        with col3:
            st.markdown("#### πŸ“‹ **Individual Model Scores:**")
            for model_name, score in analysis['model_scores'].items():
                # Friendly names for known models; unknown names are shown as-is.
                model_display = {
                    'helinivan': 'HelinIvan Model',
                    'distilbert': 'DistilBERT Model'
                }
                display_name = model_display.get(model_name, model_name)
                st.write(f"β€’ **{display_name}:** {score:.3f}")
            st.write(f"β€’ **Rule-based boost:** +{analysis['rule_boost']:.3f}")
            st.write(f"β€’ **🎯 Final ensemble:** {analysis['score']:.3f}")
            if analysis['explanations']:
                st.markdown("#### πŸ” **Detected Patterns:**")
                for i, explanation in enumerate(analysis['explanations'], 1):
                    st.write(f"{i}. {explanation}")
        with col4:
            st.markdown("#### πŸ“ˆ **Model Comparison:**")
            comparison_fig = create_model_comparison_chart(
                analysis['model_scores'],
                analysis['rule_boost'],
                analysis['score']
            )
            st.plotly_chart(comparison_fig, use_container_width=True)
    # Pattern legend
    # Colours mirror the color_map used by create_highlighted_text.
    if user_text and 'analysis' in locals() and analysis['highlights']:
        st.markdown("### 🎨 **Highlighting Legend**")
        legend_cols = st.columns(4)
        legend_items = [
            ("Sarcastic Phrases", "#ff6b6b"),
            ("Emojis", "#4ecdc4"),
            ("Punctuation", "#45b7d1"),
            ("Questions", "#96ceb4"),
            ("Emphasis", "#feca57"),
            ("Repetition", "#ff9ff3"),
            ("Exaggeration", "#54a0ff"),
            ("Ellipsis", "#fd79a8")
        ]
        # Eight items spread round-robin over the four columns.
        for i, (label, color) in enumerate(legend_items):
            with legend_cols[i % 4]:
                safe_label = label.replace("<", "&lt;").replace(">", "&gt;")
                st.markdown(f'<span style="background-color: {color}; padding: 2px 6px; border-radius: 3px; color: black; font-size: 12px;">{safe_label}</span>', unsafe_allow_html=True)
# --- Tab 2: Context-Aware (Experimental) ---
# Same layout as the single-message tab, but the models receive the context
# and the reply as a pair; highlighting is computed on the reply only.
with tab2:
    st.markdown("### Context-Aware Sarcasm Detection (Experimental)")
    st.info("This feature is experimental. The models are **not yet trained** on context+reply pairs. Predictions are based on formatting the input as a sentence pair, but results may not be reliable.")
    col1, col2 = st.columns([2, 1])
    with col1:
        context_text = st.text_area(
            "Context (previous message):",
            value=st.session_state.get('context_text', ''),
            height=68,
            placeholder="e.g. 'Can you finish this by today?'"
        )
        # Default falls back to the last clicked sidebar example when no
        # 'reply_text' has been stored in the session.
        reply_text = st.text_area(
            "Reply (current message):",
            value=st.session_state.get('reply_text', st.session_state.get('example_text', '')),
            height=80,
            placeholder="e.g. 'Oh sure, because I have nothing else to do.'"
        )
        if reply_text:
            with st.spinner("Analyzing with experimental context-aware input..."):
                analysis_ctx = analyze_text_realtime_experiment(context_text, reply_text, models, tokenizers, device)
            # create_highlighted_text escapes the user's text before it is
            # embedded in the raw HTML below.
            st.markdown("### 🎯 **Analysis Results**")
            highlighted_html = create_highlighted_text(reply_text, analysis_ctx['highlights'])
            st.markdown(f'<div style="padding: 10px; border: 1px solid #ddd; border-radius: 5px; background-color: #f9f9f9;">{highlighted_html}</div>', unsafe_allow_html=True)
            st.markdown(f"### **Prediction: {analysis_ctx['label']}**")
            # Coloured marker: > 0.7, > 0.4, otherwise
            progress_color = "πŸ”΄" if analysis_ctx['score'] > 0.7 else "🟑" if analysis_ctx['score'] > 0.4 else "🟒"
            st.write(f"**Ensemble Score: {analysis_ctx['score']:.3f}** {progress_color}")
            st.progress(analysis_ctx['score'])
    with col2:
        # `'analysis_ctx' in locals()` guards against `analysis_ctx` not
        # having been assigned during this run of the script.
        if reply_text and 'analysis_ctx' in locals():
            st.markdown("### πŸ“Š **Confidence Gauge**")
            gauge_fig = create_confidence_gauge(analysis_ctx['score'])
            st.plotly_chart(gauge_fig, use_container_width=True)
    # Per-model breakdown. Skipped when model_scores is empty, which is what
    # the analysis function returns for blank input and for errors; those
    # results also lack 'rule_boost', so this guard keeps the lookups safe.
    if reply_text and 'analysis_ctx' in locals() and analysis_ctx['model_scores']:
        st.markdown("### πŸ” **Multi-Model Analysis**")
        col3, col4 = st.columns([1, 1])
        with col3:
            st.markdown("#### πŸ“‹ **Individual Model Scores:**")
            for model_name, score in analysis_ctx['model_scores'].items():
                # Friendly names for known models; unknown names are shown as-is.
                model_display = {
                    'helinivan': 'HelinIvan Model',
                    'distilbert': 'DistilBERT Model'
                }
                display_name = model_display.get(model_name, model_name)
                st.write(f"β€’ **{display_name}:** {score:.3f}")
            st.write(f"β€’ **Rule-based boost:** +{analysis_ctx['rule_boost']:.3f}")
            st.write(f"β€’ **🎯 Final ensemble:** {analysis_ctx['score']:.3f}")
            if analysis_ctx['explanations']:
                st.markdown("#### πŸ” **Detected Patterns:**")
                for i, explanation in enumerate(analysis_ctx['explanations'], 1):
                    st.write(f"{i}. {explanation}")
        with col4:
            st.markdown("#### πŸ“ˆ **Model Comparison:**")
            comparison_fig = create_model_comparison_chart(
                analysis_ctx['model_scores'],
                analysis_ctx['rule_boost'],
                analysis_ctx['score']
            )
            st.plotly_chart(comparison_fig, use_container_width=True)
    # Highlight colour legend; colours mirror the color_map used by
    # create_highlighted_text.
    if reply_text and 'analysis_ctx' in locals() and analysis_ctx['highlights']:
        st.markdown("### 🎨 **Highlighting Legend**")
        legend_cols = st.columns(4)
        legend_items = [
            ("Sarcastic Phrases", "#ff6b6b"),
            ("Emojis", "#4ecdc4"),
            ("Punctuation", "#45b7d1"),
            ("Questions", "#96ceb4"),
            ("Emphasis", "#feca57"),
            ("Repetition", "#ff9ff3"),
            ("Exaggeration", "#54a0ff"),
            ("Ellipsis", "#fd79a8")
        ]
        # Eight items spread round-robin over the four columns.
        for i, (label, color) in enumerate(legend_items):
            with legend_cols[i % 4]:
                safe_label = label.replace("<", "&lt;").replace(">", "&gt;")
                st.markdown(f'<span style="background-color: {color}; padding: 2px 6px; border-radius: 3px; color: black; font-size: 12px;">{safe_label}</span>', unsafe_allow_html=True)
# --- Shared tutorial and advanced settings ---
# Collapsible help panel; the body is static markdown and reads no analysis state.
with st.expander("πŸŽ“ **Multi-Model Sarcasm Detection Guide**"):
    st.markdown("""
### How the Enhanced Detection Works:
1. **πŸ€– HelinIvan Model**: General English sarcasm detection
2. **πŸ€– DistilBERT Model**: Specialized Reddit-trained sarcasm detector
3. **πŸ“ Rule-Based Analysis**: Linguistic patterns and social media cues
4. **🎯 Ensemble Method**: Combines all approaches with weighted averaging
### Model Advantages:
- **HelinIvan**: Good for formal and general sarcasm
- **DistilBERT (Reddit)**: Excellent for informal, social media style sarcasm
- **Rule-based**: Catches obvious patterns and cultural references
### Why Ensemble Works Better:
- **Robustness**: Multiple models reduce individual model weaknesses
- **Coverage**: Different training data covers different sarcasm styles
- **Confidence**: Agreement between models increases reliability
### Try These Reddit-Style Examples:
- "No shit Sherlock 🀑"
- "Thanks Captain Obvious"
- "Groundbreaking discovery there genius"
- "What a concept... mind blown 🀯"
""")
# Model weights adjustment (advanced users)
# Slider arguments are (label, min, max, default, step); the defaults match the
# default weights inside ensemble_prediction.
# NOTE(review): in the code visible here the slider values are only echoed by
# st.info; nothing passes them to ensemble_prediction, so moving a slider does
# not appear to change any score. Confirm whether wiring them in is intended.
with st.expander("βš™οΈ **Advanced: Adjust Model Weights**"):
    st.markdown("Fine-tune the ensemble by adjusting model importance:")
    col_w1, col_w2, col_w3 = st.columns(3)
    with col_w1:
        helinivan_weight = st.slider("HelinIvan Weight", 0.0, 1.0, 0.4, 0.1)
    with col_w2:
        distilbert_weight = st.slider("DistilBERT Weight", 0.0, 1.0, 0.5, 0.1)
    with col_w3:
        rules_weight = st.slider("Rules Weight", 0.0, 0.5, 0.1, 0.05)
    st.info(f"Weights - HelinIvan: {helinivan_weight}, DistilBERT: {distilbert_weight}, Rules: {rules_weight}")
# Clear session state
# Removes the stored sidebar example (if any) and reruns the script.
# NOTE(review): only 'example_text' is removed; 'context_text' / 'reply_text'
# are left in the session. Confirm that is intended.
if st.button("πŸ”„ Clear Text", key="clear_main"):
    if 'example_text' in st.session_state:
        del st.session_state.example_text
    st.rerun()