profplate committed on
Commit
a022204
·
verified ·
1 Parent(s): 065ba35

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +99 -0
app.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from html import escape
3
+ from transformers import pipeline, BlipProcessor, BlipForConditionalGeneration
4
+ import torch
5
+
6
# --- Model setup -----------------------------------------------------------

# BLIP base: vision-to-text captioner used to describe the uploaded image.
_BLIP_CHECKPOINT = "Salesforce/blip-image-captioning-base"
blip_processor = BlipProcessor.from_pretrained(_BLIP_CHECKPOINT)
blip_model = BlipForConditionalGeneration.from_pretrained(_BLIP_CHECKPOINT)

# Text classifier for Ekman's 6 basic emotions plus neutral.
# top_k=None makes the pipeline return a score for every label.
classifier = pipeline(
    "text-classification",
    model="j-hartmann/emotion-english-distilroberta-base",
    top_k=None,
)
12
+
13
# Hex CSS color assigned to each emotion label (Ekman 6 + neutral);
# used for the verdict badge and the per-emotion score bars.
EMOTION_COLORS = dict(
    anger="#ef4444",
    disgust="#a3e635",
    fear="#a855f7",
    joy="#facc15",
    sadness="#3b82f6",
    surprise="#fb923c",
    neutral="#94a3b8",
)
22
+
23
def analyze(image):
    """Caption the uploaded image with BLIP, then classify the caption's emotions.

    Args:
        image: PIL image supplied by the Gradio input component, or None
            when the input is cleared.

    Returns:
        An HTML fragment string containing the BLIP caption, a verdict badge
        for the top-scoring emotion, and one score bar per emotion label.
    """
    if image is None:
        return "<p class='empty'>Upload an image to detect its basic emotions.</p>"

    # Generate caption — BLIP expects a 3-channel RGB image.
    image = image.convert("RGB")
    inputs = blip_processor(image, return_tensors="pt")
    with torch.no_grad():
        caption_ids = blip_model.generate(**inputs, max_new_tokens=50)
    caption = blip_processor.decode(caption_ids[0], skip_special_tokens=True)
    # Escape before embedding in HTML so a caption can't inject markup.
    safe_caption = escape(caption)

    # Classify emotions. With top_k=None the pipeline returns one list of
    # {"label", "score"} dicts per input; [0] selects our single caption.
    results = classifier(caption)[0]
    results.sort(key=lambda x: x["score"], reverse=True)

    # Highest-scoring emotion drives the verdict badge color.
    top = results[0]
    top_color = EMOTION_COLORS.get(top["label"], "#666")

    bars = []
    for r in results:
        color = EMOTION_COLORS.get(r["label"], "#666")
        pct = r["score"] * 100
        safe_label = escape(r["label"].upper())
        bars.append(f"""
        <div class="bar-row">
          <span class="bar-label">{safe_label}</span>
          <div class="bar-track">
            <div class="bar-fill" style="width:{pct:.1f}%;background:{color}"></div>
          </div>
          <span class="bar-pct">{pct:.1f}%</span>
        </div>""")

    # 22/44 suffixes add alpha to the hex color for the badge tint/border.
    return f"""
    <div class="caption-box">
      <div class="caption-label">BLIP sees:</div>
      <div class="caption-text">"{safe_caption}"</div>
    </div>
    <div class="verdict" style="background:{top_color}22;color:{top_color};border:1px solid {top_color}44">
      {escape(top['label'].upper())} ({top['score']*100:.1f}%)
    </div>
    <div class="bars">{"".join(bars)}</div>
    """
66
+
67
# Custom styling for the HTML report produced by analyze().
# BUG FIX: the original passed this as gr.HTML(css_template=...), but gr.HTML
# has no `css_template` parameter, so the app crashed with a TypeError at
# startup. Custom CSS must be supplied to gr.Blocks(css=...) instead.
_CUSTOM_CSS = """
.caption-box {
  background: #f0f4ff; border-radius: 10px; padding: 14px 18px;
  margin-bottom: 16px; border: 1px solid #d0d8f0;
}
.caption-label { font-size: 0.75em; color: #888; text-transform: uppercase; letter-spacing: 0.05em; }
.caption-text { font-size: 1.1em; margin-top: 4px; color: #333; }
.verdict {
  text-align: center; font-weight: 700; font-size: 1.3em;
  padding: 10px; border-radius: 8px; margin-bottom: 14px;
}
.bars { display: flex; flex-direction: column; gap: 8px; }
.bar-row { display: flex; align-items: center; gap: 10px; }
.bar-label { width: 80px; font-weight: 600; font-size: 0.8em; text-align: right; }
.bar-track {
  flex: 1; height: 22px; background: #f0f0f0; border-radius: 6px; overflow: hidden;
}
.bar-fill { height: 100%; border-radius: 6px; }
.bar-pct { width: 55px; font-family: monospace; font-size: 0.85em; color: #666; }
.empty { color: #999; text-align: center; padding: 40px 20px; }
"""

# UI: image input on the left, HTML emotion report on the right.
with gr.Blocks(title="Image Basic Emotions (Ekman 6)", css=_CUSTOM_CSS) as demo:
    gr.Markdown("## Image Basic Emotions (Ekman 6)\nUpload an image. BLIP describes it, then a model detects 6 basic emotions + neutral.")

    with gr.Row():
        img_input = gr.Image(type="pil", label="Upload an image")
        result = gr.HTML(
            value="<p class='empty'>Your emotion analysis will appear here.</p>",
        )

    # Re-run the analysis whenever the uploaded image changes.
    img_input.change(fn=analyze, inputs=img_input, outputs=result)

demo.launch()