import nltk
import streamlit as st
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
import torch
import torch.nn.functional as F
import spacy
import re
from nltk.sentiment import SentimentIntensityAnalyzer
import emoji
import plotly.graph_objects as go
import plotly.express as px
from collections import Counter
import time
import numpy as np


# Configuration - Multiple Models
MODELS = {
    "helinivan": "helinivan/English-sarcasm-detector",
    "distilbert": "dima806/sarcasm-detection-distilbert"  
}

# Initialize NLTK VADER analyzer (not referenced by the current analysis pipeline)
try:
    nltk.data.path.append('/app/nltk_data')
    sia = SentimentIntensityAnalyzer()
except Exception as e:
    st.error(f"Error downloading NLTK data: {e}")
    sia = None

# Cache multiple models & tokenizers
@st.cache_resource
def load_models():
    models = {}
    tokenizers = {}
    
    for name, model_path in MODELS.items():
        try:
            model = AutoModelForSequenceClassification.from_pretrained(model_path, cache_dir="/tmp/hf_cache")
            tokenizer = AutoTokenizer.from_pretrained(model_path, cache_dir="/tmp/hf_cache")
            model.eval()
            models[name] = model
            tokenizers[name] = tokenizer
            st.success(f"βœ… Loaded {name} model successfully")
        except Exception as e:
            st.error(f"❌ Failed to load {name} model: {str(e)}")
            models[name] = None
            tokenizers[name] = None
    
    return models, tokenizers
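# Note: @st.cache_resource keeps the loaded models and tokenizers in memory across
# Streamlit reruns, so the download/initialization cost is paid once per process
# rather than on every widget interaction.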

# Lazy-load SpaCy (optional - not used in current implementation)
def load_spacy():
    try:
        return spacy.load("en_core_web_sm")
    except OSError:
        st.warning("SpaCy model 'en_core_web_sm' not found. Some features may be limited.")
        return None

# Pattern detection functions with highlighting info
def social_media_sarcasm_cues(text: str) -> tuple[float, list, list]:
    explanations = []
    highlights = []
    boost = 0.0
    text_lower = text.lower()
    
    # Enhanced sarcasm phrases (including Reddit-style patterns)
    sarcasm_phrases = [
        "oh sure", "yeah right", "of course", "totally", "absolutely", 
        "perfect", "wonderful", "fantastic", "amazing", "brilliant",
        "great job", "well done", "nice one", "good going", "way to go",
        "real smooth", "genius move", "solid plan", "makes sense",
        "just perfect", "exactly what i needed", "this is fine",
        # Reddit-style additions
        "thanks genius", "no shit sherlock", "well duh", "captain obvious",
        "groundbreaking", "revolutionary", "what a concept", "mind blown",
        "shocking", "who would have guessed", "truly inspiring"
    ]
    
    for phrase in sarcasm_phrases:
        # Find all occurrences of the phrase
        for match in re.finditer(re.escape(phrase), text_lower):
            boost += 0.2
            explanations.append(f"Sarcastic phrase: '{phrase}'")
            highlights.append({
                'start': match.start(),
                'end': match.end(),
                'type': 'sarcastic_phrase',
                'text': phrase
            })
    
    # Exaggerated expressions
    exaggerated_match = re.search(r'\b(SO|TOTALLY|ABSOLUTELY|REALLY|VERY)\b.*\b(great|good|perfect|amazing|helpful|useful)\b', text, re.IGNORECASE)
    if exaggerated_match:
        boost += 0.25
        explanations.append("Exaggerated positive expression")
        highlights.append({
            'start': exaggerated_match.start(),
            'end': exaggerated_match.end(),
            'type': 'exaggerated',
            'text': exaggerated_match.group()
        })
    
    return boost, explanations, highlights
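# Illustrative example (not executed): social_media_sarcasm_cues("oh sure, great job")
# matches two phrases, returning a boost of 0.4, two explanation strings, and two
# highlight dicts whose offsets index into the lowercased copy of the text (same
# length as the original for typical English input).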

def emoji_punctuation_analysis(text: str) -> tuple[float, list, list]:
    explanations = []
    highlights = []
    boost = 0.0
    
    # Extract emojis with positions
    try:
        emojis = emoji.emoji_list(text)
        sarcastic_emojis = ['πŸ™„', '😏', 'πŸ˜’', 'πŸ€”', '🀨', '😀', '🀷', 'πŸ‘', 'πŸ™ƒ', '🀑', 'πŸ’€', '🀯']
        
        for emoji_info in emojis:
            if emoji_info['emoji'] in sarcastic_emojis:
                boost += 0.15
                explanations.append(f"Sarcastic emoji: {emoji_info['emoji']}")
                highlights.append({
                    'start': emoji_info['match_start'],
                    'end': emoji_info['match_end'],
                    'type': 'sarcastic_emoji',
                    'text': emoji_info['emoji']
                })
    except Exception:
        # Fall back silently if the emoji library cannot parse the text
        pass
    
    # Excessive punctuation
    for match in re.finditer(r'[!?]{2,}', text):
        boost += 0.1
        explanations.append(f"Excessive punctuation: {match.group()}")
        highlights.append({
            'start': match.start(),
            'end': match.end(),
            'type': 'excessive_punct',
            'text': match.group()
        })
    
    # Ellipsis (often sarcastic)
    for match in re.finditer(r'\.{3,}', text):
        boost += 0.15
        explanations.append(f"Trailing ellipsis: {match.group()}")
        highlights.append({
            'start': match.start(),
            'end': match.end(),
            'type': 'ellipsis',
            'text': match.group()
        })
    
    return boost, explanations, highlights
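# Illustrative patterns: r'[!?]{2,}' flags runs such as "!!!" or "?!" (+0.1 each),
# while r'\.{3,}' flags trailing ellipses such as "..." (+0.15 each); both record
# character offsets so the matched span can be highlighted later.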

def rhetorical_questions_analysis(text: str) -> tuple[float, list, list]:
    explanations = []
    highlights = []
    boost = 0.0
    
    rhetorical_patterns = [
        (r'what could possibly go wrong\?', "Rhetorical question"),
        (r'who would have thought\?', "Rhetorical question"),
        (r'seriously\?', "Emphatic question"),
        (r'really\?.*really\?', "Repeated question"),
        (r'no way\?', "Disbelief question"),
        (r'you don\'t say\?', "Sarcastic response"),
        (r'shocking.*\?', "Mock surprise")
    ]
    
    for pattern, description in rhetorical_patterns:
        for match in re.finditer(pattern, text, re.IGNORECASE):
            boost += 0.3
            explanations.append(description)
            highlights.append({
                'start': match.start(),
                'end': match.end(),
                'type': 'rhetorical_question',
                'text': match.group()
            })
    
    return boost, explanations, highlights

def capitalization_analysis(text: str) -> tuple[float, list, list]:
    explanations = []
    highlights = []
    boost = 0.0
    
    # ALL CAPS words
    for match in re.finditer(r'\b[A-Z]{3,}\b', text):
        if match.group() not in ['AND', 'THE', 'FOR', 'BUT', 'YOU', 'ARE']:
            boost += 0.1
            explanations.append(f"Emphatic caps: {match.group()}")
            highlights.append({
                'start': match.start(),
                'end': match.end(),
                'type': 'caps_emphasis',
                'text': match.group()
            })
    
    # Letter repetition
    for match in re.finditer(r'(.)\1{2,}', text):
        boost += 0.1
        explanations.append(f"Letter repetition: {match.group()}")
        highlights.append({
            'start': match.start(),
            'end': match.end(),
            'type': 'repetition',
            'text': match.group()
        })
    
    return boost, explanations, highlights
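# Note: the letter-repetition pattern r'(.)\1{2,}' matches any character repeated
# three or more times (e.g. "soooo", but also "!!!"), so its matches can overlap
# with the excessive-punctuation check in emoji_punctuation_analysis above.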

# Combined analysis with highlighting
def enhanced_rule_analysis(text: str) -> tuple[float, list, list]:
    all_explanations = []
    all_highlights = []
    total_boost = 0.0
    
    # Apply all analysis functions
    boost1, exp1, high1 = social_media_sarcasm_cues(text)
    boost2, exp2, high2 = emoji_punctuation_analysis(text)
    boost3, exp3, high3 = rhetorical_questions_analysis(text)
    boost4, exp4, high4 = capitalization_analysis(text)
    
    total_boost = boost1 + boost2 + boost3 + boost4
    all_explanations.extend(exp1 + exp2 + exp3 + exp4)
    all_highlights.extend(high1 + high2 + high3 + high4)
    
    # Cap the total boost
    total_boost = min(total_boost, 0.8)
    
    return total_boost, all_explanations, all_highlights
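# Worked example (illustrative only): a text containing "oh sure" (+0.2), "?!" (+0.1)
# and a πŸ™„ emoji (+0.15) yields a rule boost of 0.45; the min() above caps the
# aggregate boost at 0.8 regardless of how many cues fire.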

# Multi-model prediction function
def get_model_predictions_current(text: str, models: dict, tokenizers: dict, device) -> dict:
    predictions = {}
    
    for name, model in models.items():
        if model is None or tokenizers[name] is None:
            predictions[name] = 0.0
            continue
            
        try:
            inputs = tokenizers[name]([text], return_tensors="pt", truncation=True, padding=True).to(device)
            model.to(device)
            
            with torch.no_grad():
                logits = model(**inputs).logits
                # Handle different output formats
                if logits.shape[-1] == 2:  # Binary classification
                    score = F.softmax(logits, dim=-1)[0, 1].item()
                else:  # Single output
                    score = torch.sigmoid(logits)[0, 0].item()
                
                predictions[name] = score
        except Exception as e:
            st.warning(f"Error with {name} model: {str(e)}")
            predictions[name] = 0.0
    
    return predictions

# Context-aware variant of get_model_predictions: scores a (context, reply) sentence pair
def get_model_predictions_experiment(context: str, reply: str, models: dict, tokenizers: dict, device) -> dict:
    predictions = {}
    for name, model in models.items():
        if model is None or tokenizers[name] is None:
            predictions[name] = 0.0
            continue
        try:
            # Use sentence-pair interface
            inputs = tokenizers[name](
                context,
                reply,
                return_tensors="pt",
                truncation=True,
                padding=True
            )
            inputs = {k: v.to(device) for k, v in inputs.items()}
            model.to(device)
            with torch.no_grad():
                logits = model(**inputs).logits
                if logits.shape[-1] == 2:  # Binary classification
                    score = F.softmax(logits, dim=-1)[0, 1].item()
                else:
                    score = torch.sigmoid(logits)[0, 0].item()
                predictions[name] = score
        except Exception as e:
            st.warning(f"Error with {name} model: {str(e)}")
            predictions[name] = 0.0
    return predictions
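# Assumption worth noting: passing (context, reply) as two arguments relies on the
# tokenizer's sentence-pair interface; for BERT-family tokenizers this typically
# produces a single sequence of the form "[CLS] context [SEP] reply [SEP]". Neither
# model here was fine-tuned on such pairs, so these scores are experimental.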

# Enhanced ensemble prediction
def ensemble_prediction(model_scores: dict, rule_boost: float, weights: dict = None) -> float:
    if weights is None:
        # Default weights - adjust based on model performance
        weights = {
            'helinivan': 0.4,
            'distilbert': 0.5,  # Higher weight for Reddit-trained model
            'rules': 0.1
        }
    
    ensemble_score = 0.0
    total_weight = 0.0
    
    # Weighted average of model predictions
    for model_name, score in model_scores.items():
        if score > 0:  # Skip models that failed to load or predict (failures are recorded as 0.0)
            weight = weights.get(model_name, 0.3)
            ensemble_score += score * weight
            total_weight += weight
    
    # Normalize the weighted sum by the total weight of the models that contributed
    if total_weight > 0:
        ensemble_score = ensemble_score / total_weight
    
    # Apply rule-based boost
    final_score = min(ensemble_score + (rule_boost * weights.get('rules', 0.1)), 1.0)
    
    return final_score
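# Worked example (illustrative only): with the default weights and model scores
# {'helinivan': 0.8, 'distilbert': 0.6} plus rule_boost = 0.3, the weighted average
# is (0.8*0.4 + 0.6*0.5) / (0.4 + 0.5) β‰ˆ 0.689, and the final score is
# min(0.689 + 0.3*0.1, 1.0) β‰ˆ 0.719.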

# Create highlighted text HTML
def create_highlighted_text(text: str, highlights: list) -> str:
    if not isinstance(text, str):
        return ""
    
    if not highlights:
        return text.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;")
    
    # Sort highlights by start position
    sorted_highlights = sorted(highlights, key=lambda x: x['start'])
    
    color_map = {
        'sarcastic_phrase': '#ff6b6b',
        'sarcastic_emoji': '#4ecdc4',
        'excessive_punct': '#45b7d1',
        'rhetorical_question': '#96ceb4',
        'caps_emphasis': '#feca57',
        'repetition': '#ff9ff3',
        'exaggerated': '#54a0ff',
        'ellipsis': '#fd79a8'
    }
    
    result = ""
    last_end = 0
    
    for highlight in sorted_highlights:
        start, end = highlight['start'], highlight['end']
        highlight_type = highlight['type']
        color = color_map.get(highlight_type, '#dda0dd')
        
        # Add text before highlight
        if start > last_end:
            before_text = text[last_end:start]
            result += before_text.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;")
        
        # Add highlighted text
        highlighted_text = text[start:end]
        safe_text = highlighted_text.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;")
        result += f'<span style="background-color: {color}; padding: 2px 4px; border-radius: 3px; color: black;">{safe_text}</span>'
        
        last_end = end
    
    # Add remaining text
    if last_end < len(text):
        remaining_text = text[last_end:]
        result += remaining_text.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;")
    
    return result
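# Illustrative output (assumed input): a 'sarcastic_phrase' highlight covering
# characters 0-7 of "Oh sure, whatever" renders as
# '<span style="background-color: #ff6b6b; ...">Oh sure</span>, whatever',
# with &, < and > in the surrounding text HTML-escaped.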

# Enhanced confidence gauge
def create_confidence_gauge(score: float) -> go.Figure:
    fig = go.Figure(go.Indicator(
        mode = "gauge+number+delta",
        value = score,
        domain = {'x': [0, 1], 'y': [0, 1]},
        title = {'text': "Ensemble Sarcasm Score"},
        delta = {'reference': 0.5},
        gauge = {
            'axis': {'range': [None, 1]},
            'bar': {'color': "darkblue"},
            'steps': [
                {'range': [0, 0.3], 'color': "lightgray"},
                {'range': [0.3, 0.6], 'color': "yellow"},
                {'range': [0.6, 1], 'color': "red"}
            ],
            'threshold': {
                'line': {'color': "red", 'width': 4},
                'thickness': 0.75,
                'value': 0.7
            }
        }
    ))
    
    fig.update_layout(height=300)
    return fig
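# The gauge's colored steps (gray below 0.3, yellow 0.3-0.6, red above 0.6) roughly
# mirror the label thresholds used in analyze_text_realtime_* below; the red
# threshold line at 0.7 marks the "Highly Sarcastic" cutoff.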

# Multi-model feature importance visualization
def create_model_comparison_chart(model_scores: dict, rule_boost: float, final_score: float) -> go.Figure:
    models = list(model_scores.keys())
    scores = list(model_scores.values())
    
    # Add rule-based and final scores
    models.extend(['Rule-based', 'Final Ensemble'])
    scores.extend([rule_boost, final_score])
    
    colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd']
    
    fig = go.Figure(go.Bar(
        x=models,
        y=scores,
        marker_color=colors[:len(models)],
        text=[f'{score:.3f}' for score in scores],
        textposition='auto',
    ))
    
    fig.update_layout(
        title="Model Comparison & Ensemble Result",
        yaxis_title="Sarcasm Score",
        height=400,
        showlegend=False
    )
    
    return fig

# Real-time analysis function with multiple models
def analyze_text_realtime_current(text: str, models: dict, tokenizers: dict, device) -> dict:
    if not text.strip():
        return {
            'score': 0.0,
            'label': 'Enter text to analyze',
            'explanations': [],
            'highlights': [],
            'model_scores': {}
        }
    
    try:
        # Get rule-based analysis
        rule_boost, explanations, highlights = enhanced_rule_analysis(text)
        
        # Get predictions from all models
        model_scores = get_model_predictions_current(text, models, tokenizers, device)
        
        # Ensemble prediction
        final_score = ensemble_prediction(model_scores, rule_boost)
        
        # Determine label
        if final_score > 0.8:
            label = "Extremely Sarcastic 🀨😏"
        elif final_score > 0.7:
            label = "Highly Sarcastic 🀨"
        elif final_score > 0.6:
            label = "Likely Sarcastic 😏"
        elif final_score > 0.4:
            label = "Possibly Sarcastic πŸ€”"
        elif final_score > 0.3:
            label = "Probably Sincere πŸ™‚"
        else:
            label = "Sincere 😊"
        
        return {
            'score': final_score,
            'model_scores': model_scores,
            'rule_boost': rule_boost,
            'label': label,
            'explanations': explanations,
            'highlights': highlights
        }
    
    except Exception as e:
        return {
            'score': 0.0,
            'label': f'Error: {str(e)}',
            'explanations': [],
            'highlights': [],
            'model_scores': {}
        }

# Context-aware variant of analyze_text_realtime: analyzes a reply in light of its context
def analyze_text_realtime_experiment(context: str, reply: str, models: dict, tokenizers: dict, device) -> dict:
    if not reply.strip():
        return {
            'score': 0.0,
            'label': 'Enter context and reply to analyze',
            'explanations': [],
            'highlights': [],
            'model_scores': {}
        }
    try:
        # Use reply for rule-based analysis (context is not used in rules)
        rule_boost, explanations, highlights = enhanced_rule_analysis(reply)
        # Get predictions from all models using context and reply
        model_scores = get_model_predictions_experiment(context, reply, models, tokenizers, device)
        final_score = ensemble_prediction(model_scores, rule_boost)
        # Label assignment mirrors the single-message path
        if final_score > 0.8:
            label = "Extremely Sarcastic 🀨😏"
        elif final_score > 0.7:
            label = "Highly Sarcastic 🀨"
        elif final_score > 0.6:
            label = "Likely Sarcastic 😏"
        elif final_score > 0.4:
            label = "Possibly Sarcastic πŸ€”"
        elif final_score > 0.3:
            label = "Probably Sincere πŸ™‚"
        else:
            label = "Sincere 😊"
        return {
            'score': final_score,
            'model_scores': model_scores,
            'rule_boost': rule_boost,
            'label': label,
            'explanations': explanations,
            'highlights': highlights
        }
    except Exception as e:
        return {
            'score': 0.0,
            'label': f'Error: {str(e)}',
            'explanations': [],
            'highlights': [],
            'model_scores': {}
        }

# Streamlit UI
st.set_page_config(page_title="Enhanced Sarcasm Detector", page_icon="🀨", layout="wide")

st.title("πŸ—¨οΈ Enhanced Multi-Model Sarcasm Detector")
st.markdown("*Combining DistilBERT (Reddit-trained) + HelinIvan + Rule-based Analysis*")

# Load models
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
st.markdown(f"**Device:** {device}")

with st.spinner("Loading AI models..."):
    models, tokenizers = load_models()

# Model status display
st.markdown("### πŸ€– Model Status")
status_cols = st.columns(2)
with status_cols[0]:
    helinivan_status = "βœ… Loaded" if models.get('helinivan') else "❌ Failed"
    st.markdown(f"**HelinIvan Model:** {helinivan_status}")
with status_cols[1]:
    distilbert_status = "βœ… Loaded" if models.get('distilbert') else "❌ Failed"
    st.markdown(f"**DistilBERT Model:** {distilbert_status}")

# Sidebar with examples and tips (shared)
with st.sidebar:
    st.markdown("### πŸ’‘ **Quick Examples**")
    example_buttons = [
        ("Oh great, more traffic πŸ™„", "social_media"),
        ("Yeah, I just LOVE waiting in line", "emphasis"),
        ("What could possibly go wrong?", "rhetorical"),
        ("Perfect timing as always...", "timing"),
        ("Thanks for the help genius", "reddit_style"),
        ("WOW so helpful!!!", "caps_sarcasm"),
        ("No shit Sherlock 🀑", "reddit_sarcasm"),
        ("Truly groundbreaking stuff here", "mock_praise")
    ]
    for example_text, example_type in example_buttons:
        if st.button(f"πŸ“ {example_text[:25]}...", key=example_type):
            st.session_state.example_text = example_text

# --- Tabs for navigation ---
tab1, tab2 = st.tabs(["Single Message (Current)", "Context-Aware (Experimental)"])

# --- Tab 1: Single Message (Current) ---
with tab1:
    st.markdown("### Single Message Sarcasm Detection")
    st.markdown("Analyze sarcasm in a single message (no conversational context).")
    col1, col2 = st.columns([2, 1])
    with col1:
        # Text input with real-time analysis
        default_text = st.session_state.get('example_text', '')
        user_text = st.text_area(
            "Enter a paragraph for multi-model sarcasm analysis:",
            value=default_text,
            height=120,
            placeholder="Try: 'Oh fantastic, another meeting that could have been an email πŸ™„ What a brilliant use of everyone's time...'"
        )
        
        # Real-time analysis
        if user_text:
            with st.spinner("Analyzing with multiple models..."):
                analysis = analyze_text_realtime_current(user_text, models, tokenizers, device)
                
            # Display highlighted text
            st.markdown("### 🎯 **Analysis Results**")
            highlighted_html = create_highlighted_text(user_text, analysis['highlights'])
            st.markdown(f'<div style="padding: 10px; border: 1px solid #ddd; border-radius: 5px; background-color: #f9f9f9;">{highlighted_html}</div>', unsafe_allow_html=True)
            
            # Prediction and confidence
            st.markdown(f"### **Prediction: {analysis['label']}**")
            
            # Progress bar with custom colors
            progress_color = "πŸ”΄" if analysis['score'] > 0.7 else "🟑" if analysis['score'] > 0.4 else "🟒"
            st.write(f"**Ensemble Score: {analysis['score']:.3f}** {progress_color}")
            st.progress(analysis['score'])
            
    with col2:
        if user_text and 'analysis' in locals():
            # Confidence gauge
            st.markdown("### πŸ“Š **Confidence Gauge**")
            gauge_fig = create_confidence_gauge(analysis['score'])
            st.plotly_chart(gauge_fig, use_container_width=True)
            
    # Multi-model analysis section
    if user_text and 'analysis' in locals() and analysis['model_scores']:
        st.markdown("### πŸ” **Multi-Model Analysis**")
        col3, col4 = st.columns([1, 1])
        
        with col3:
            st.markdown("#### πŸ“‹ **Individual Model Scores:**")
            for model_name, score in analysis['model_scores'].items():
                model_display = {
                    'helinivan': 'HelinIvan Model',
                    'distilbert': 'DistilBERT Model'
                }
                display_name = model_display.get(model_name, model_name)
                st.write(f"β€’ **{display_name}:** {score:.3f}")
                
            st.write(f"β€’ **Rule-based boost:** +{analysis['rule_boost']:.3f}")
            st.write(f"β€’ **🎯 Final ensemble:** {analysis['score']:.3f}")
            
            if analysis['explanations']:
                st.markdown("#### πŸ” **Detected Patterns:**")
                for i, explanation in enumerate(analysis['explanations'], 1):
                    st.write(f"{i}. {explanation}")
                    
        with col4:
            st.markdown("#### πŸ“ˆ **Model Comparison:**")
            comparison_fig = create_model_comparison_chart(
                analysis['model_scores'], 
                analysis['rule_boost'], 
                analysis['score']
            )
            st.plotly_chart(comparison_fig, use_container_width=True)
            
    # Pattern legend
    if user_text and 'analysis' in locals() and analysis['highlights']:
        st.markdown("### 🎨 **Highlighting Legend**")
        legend_cols = st.columns(4)
        legend_items = [
            ("Sarcastic Phrases", "#ff6b6b"),
            ("Emojis", "#4ecdc4"),
            ("Punctuation", "#45b7d1"),
            ("Questions", "#96ceb4"),
            ("Emphasis", "#feca57"),
            ("Repetition", "#ff9ff3"),
            ("Exaggeration", "#54a0ff"),
            ("Ellipsis", "#fd79a8")
        ]
        for i, (label, color) in enumerate(legend_items):
            with legend_cols[i % 4]:
                safe_label = label.replace("<", "&lt;").replace(">", "&gt;")
                st.markdown(f'<span style="background-color: {color}; padding: 2px 6px; border-radius: 3px; color: black; font-size: 12px;">{safe_label}</span>', unsafe_allow_html=True)

# --- Tab 2: Context-Aware (Experimental) ---
with tab2:
    st.markdown("### Context-Aware Sarcasm Detection (Experimental)")
    st.info("This feature is experimental. The models are **not yet trained** on context+reply pairs. Predictions are based on formatting the input as a sentence pair, but results may not be reliable.")
    col1, col2 = st.columns([2, 1])
    with col1:
        context_text = st.text_area(
            "Context (previous message):",
            value=st.session_state.get('context_text', ''),
            height=68,
            placeholder="e.g. 'Can you finish this by today?'"
        )
        reply_text = st.text_area(
            "Reply (current message):",
            value=st.session_state.get('reply_text', st.session_state.get('example_text', '')),
            height=80,
            placeholder="e.g. 'Oh sure, because I have nothing else to do.'"
        )
        if reply_text:
            with st.spinner("Analyzing with experimental context-aware input..."):
                analysis_ctx = analyze_text_realtime_experiment(context_text, reply_text, models, tokenizers, device)
            st.markdown("### 🎯 **Analysis Results**")
            highlighted_html = create_highlighted_text(reply_text, analysis_ctx['highlights'])
            st.markdown(f'<div style="padding: 10px; border: 1px solid #ddd; border-radius: 5px; background-color: #f9f9f9;">{highlighted_html}</div>', unsafe_allow_html=True)
            st.markdown(f"### **Prediction: {analysis_ctx['label']}**")
            progress_color = "πŸ”΄" if analysis_ctx['score'] > 0.7 else "🟑" if analysis_ctx['score'] > 0.4 else "🟒"
            st.write(f"**Ensemble Score: {analysis_ctx['score']:.3f}** {progress_color}")
            st.progress(analysis_ctx['score'])
    with col2:
        if reply_text and 'analysis_ctx' in locals():
            st.markdown("### πŸ“Š **Confidence Gauge**")
            gauge_fig = create_confidence_gauge(analysis_ctx['score'])
            st.plotly_chart(gauge_fig, use_container_width=True)
    if reply_text and 'analysis_ctx' in locals() and analysis_ctx['model_scores']:
        st.markdown("### πŸ” **Multi-Model Analysis**")
        col3, col4 = st.columns([1, 1])
        with col3:
            st.markdown("#### πŸ“‹ **Individual Model Scores:**")
            for model_name, score in analysis_ctx['model_scores'].items():
                model_display = {
                    'helinivan': 'HelinIvan Model',
                    'distilbert': 'DistilBERT Model'
                }
                display_name = model_display.get(model_name, model_name)
                st.write(f"β€’ **{display_name}:** {score:.3f}")
            st.write(f"β€’ **Rule-based boost:** +{analysis_ctx['rule_boost']:.3f}")
            st.write(f"β€’ **🎯 Final ensemble:** {analysis_ctx['score']:.3f}")
            if analysis_ctx['explanations']:
                st.markdown("#### πŸ” **Detected Patterns:**")
                for i, explanation in enumerate(analysis_ctx['explanations'], 1):
                    st.write(f"{i}. {explanation}")
        with col4:
            st.markdown("#### πŸ“ˆ **Model Comparison:**")
            comparison_fig = create_model_comparison_chart(
                analysis_ctx['model_scores'], 
                analysis_ctx['rule_boost'], 
                analysis_ctx['score']
            )
            st.plotly_chart(comparison_fig, use_container_width=True)
    if reply_text and 'analysis_ctx' in locals() and analysis_ctx['highlights']:
        st.markdown("### 🎨 **Highlighting Legend**")
        legend_cols = st.columns(4)
        legend_items = [
            ("Sarcastic Phrases", "#ff6b6b"),
            ("Emojis", "#4ecdc4"),
            ("Punctuation", "#45b7d1"),
            ("Questions", "#96ceb4"),
            ("Emphasis", "#feca57"),
            ("Repetition", "#ff9ff3"),
            ("Exaggeration", "#54a0ff"),
            ("Ellipsis", "#fd79a8")
        ]
        for i, (label, color) in enumerate(legend_items):
            with legend_cols[i % 4]:
                safe_label = label.replace("<", "&lt;").replace(">", "&gt;")
                st.markdown(f'<span style="background-color: {color}; padding: 2px 6px; border-radius: 3px; color: black; font-size: 12px;">{safe_label}</span>', unsafe_allow_html=True)

# --- Shared tutorial and advanced settings ---
with st.expander("πŸŽ“ **Multi-Model Sarcasm Detection Guide**"):
    st.markdown("""
    ### How the Enhanced Detection Works:
    1. **πŸ€– HelinIvan Model**: General English sarcasm detection
    2. **πŸ€– DistilBERT Model**: Specialized Reddit-trained sarcasm detector
    3. **πŸ“ Rule-Based Analysis**: Linguistic patterns and social media cues
    4. **🎯 Ensemble Method**: Combines all approaches with weighted averaging

    ### Model Advantages:
    - **HelinIvan**: Good for formal and general sarcasm
    - **DistilBERT (Reddit)**: Excellent for informal, social media style sarcasm
    - **Rule-based**: Catches obvious patterns and cultural references

    ### Why Ensemble Works Better:
    - **Robustness**: Multiple models reduce individual model weaknesses
    - **Coverage**: Different training data covers different sarcasm styles
    - **Confidence**: Agreement between models increases reliability

    ### Try These Reddit-Style Examples:
    - "No shit Sherlock 🀑"
    - "Thanks Captain Obvious"
    - "Groundbreaking discovery there genius"
    - "What a concept... mind blown 🀯"
    """)

# Model weights adjustment (advanced users)
with st.expander("βš™οΈ **Advanced: Adjust Model Weights**"):
    st.markdown("Fine-tune the ensemble by adjusting model importance:")
    col_w1, col_w2, col_w3 = st.columns(3)
    with col_w1:
        helinivan_weight = st.slider("HelinIvan Weight", 0.0, 1.0, 0.4, 0.1)
    with col_w2:
        distilbert_weight = st.slider("DistilBERT Weight", 0.0, 1.0, 0.5, 0.1)
    with col_w3:
        rules_weight = st.slider("Rules Weight", 0.0, 0.5, 0.1, 0.05)
    st.info(f"Weights - HelinIvan: {helinivan_weight}, DistilBERT: {distilbert_weight}, Rules: {rules_weight}")

# Clear session state
if st.button("πŸ”„ Clear Text", key="clear_main"):
    if 'example_text' in st.session_state:
        del st.session_state.example_text
    st.rerun()