Shani13524 committed on
Commit 0ae8ad0 · verified · 1 Parent(s): cf0f4c3

Update app.py

Files changed (1)
  1. app.py +267 -323
app.py CHANGED
@@ -1,350 +1,294 @@
1
- import os, json, gc, random
2
- from pathlib import Path
3
- from typing import List, Tuple, Dict
4
-
5
- import gradio as gr
6
- import numpy as np
7
  from PIL import Image
8
 
9
  import torch
10
- import open_clip
 
 
11
  from datasets import load_dataset
12
  from sklearn.neighbors import NearestNeighbors
13
- from diffusers import StableDiffusionImageVariationPipeline
14
-
15
- # -----------------------------
16
- # Config (CPU-friendly)
17
- # -----------------------------
18
- DATASET_ID = "tukey/human_face_emotions_roboflow"
19
- EMB_MODEL_NAME = "ViT-H-14" # open_clip model name
20
- EMB_PRETRAINED = "laion2b_s32b_b79k" # laion/CLIP-ViT-H-14-laion2B-s32B-b79K
21
- GEN_MODEL_ID = "lambdalabs/sd-image-variations-diffusers"
22
-
23
- CACHE_DIR = Path("./cache"); CACHE_DIR.mkdir(parents=True, exist_ok=True)
24
- EMB_MEMMAP_PATH = CACHE_DIR / "clip_vith14_laion2b.float32.memmap"
25
- LABELS_MEMMAP_PATH = CACHE_DIR / "labels.U32.memmap"
26
- KNN_META_PATH = CACHE_DIR / "knn_meta.json"
27
-
28
- # tiny index + light generation
29
- INDEX_MAX_DEFAULT = 80
30
- BATCH_SIZE_DEFAULT = 32
31
- N_SYN_DEFAULT = 3
32
- STEPS_DEFAULT = 12
33
- GUIDANCE_SCALES = [2.5, 3.0, 3.5, 4.0]
34
- NUM_SYN_TO_SHOW = 5
35
-
36
- DEVICE = "cuda" if torch.cuda.is_available() else ("mps" if hasattr(torch.backends, "mps") and torch.backends.mps.is_available() else "cpu")
37
-
38
- # -----------------------------
39
- # Canonical labels + stress
40
- # -----------------------------
41
- CANON = {"anger","disgust","fear","happy","neutral","sad","surprise","contempt"}
42
- CANON_MAP = {
43
- "angry": "anger", "happiness": "happy", "happy": "happy",
44
- "sadness": "sad", "sad": "sad", "surprised": "surprise", "surprise": "surprise",
45
- "content": "neutral", "calm": "neutral", "neutral": "neutral",
46
- "contempt": "contempt", "fear": "fear", "disgust": "disgust", "anger": "anger",
47
- }
48
- STRESS_W = {"anger":0.95,"fear":0.90,"sad":0.80,"disgust":0.70,"contempt":0.65,"surprise":0.45,"neutral":0.20,"happy":0.05}
49
- def _bucket(pct: float) -> str: return "Low" if pct < 33 else ("Medium" if pct < 66 else "High")
50
- def stress_from_top3(res: List[Dict]) -> Tuple[float, str]:
51
- probs = {}
52
- for r in res:
53
- lbl = CANON_MAP.get(str(r["emotion"]).lower(), str(r["emotion"]).lower())
54
- if lbl not in CANON: continue
55
- probs[lbl] = probs.get(lbl, 0.0) + float(r["confidence_pct"]) / 100.0
56
- Z = sum(probs.values()) or 1.0
57
- for k in list(probs): probs[k] /= Z
58
- s01 = sum(probs.get(k, 0.0) * STRESS_W.get(k, 0.0) for k in probs)
59
- s01 = max(0.0, min(1.0, s01))
60
- pct = round(s01 * 100.0, 2)
61
- return pct, _bucket(pct)
62
-
63
- # -----------------------------
64
- # Lazy globals
65
- # -----------------------------
66
- _openclip_model = None
67
- _preprocess = None
68
- _nn = None
69
- _X = None
70
- _labels_source = None
71
- _gen_pipe = None
72
- _dataset_for_labels = None
73
-
74
- # -----------------------------
75
- # Init / cache helpers
76
- # -----------------------------
77
- def _load_openclip():
78
- global _openclip_model, _preprocess
79
- if _openclip_model is not None: return _openclip_model, _preprocess
80
- model, _, preprocess = open_clip.create_model_and_transforms(
81
- model_name=EMB_MODEL_NAME, pretrained=EMB_PRETRAINED, device=DEVICE
82
  )
83
- model.eval()
84
- _openclip_model, _preprocess = model, preprocess
85
- return _openclip_model, _preprocess
86
-
87
- def _fit_knn(X): return NearestNeighbors(metric="cosine", algorithm="brute").fit(X)
88
-
89
- def _ensure_knn_index(index_max: int, batch_size: int, progress: gr.Progress | None = None):
90
- """Build (first run) or load a tiny memmap + KNN over a subset of the dataset."""
91
- global _nn, _X, _labels_source, _dataset_for_labels
92
-
93
- if _nn is not None and _X is not None:
94
- return
95
-
96
- dataset = load_dataset(DATASET_ID, split="train")
97
- if index_max:
98
- dataset = dataset.select(range(min(index_max, len(dataset))))
99
- _dataset_for_labels = dataset
100
- N = len(dataset)
101
-
102
- if EMB_MEMMAP_PATH.exists() and KNN_META_PATH.exists():
103
- meta = json.load(open(KNN_META_PATH))
104
- if int(meta.get("N", -1)) == N:
105
- D = int(meta["D"])
106
- X = np.memmap(EMB_MEMMAP_PATH, mode="r", dtype="float32", shape=(N, D))
107
- labels = np.memmap(LABELS_MEMMAP_PATH, mode="r", dtype="U32", shape=(N,)) if LABELS_MEMMAP_PATH.exists() else None
108
- _X = X; _labels_source = labels; _nn = _fit_knn(X)
109
- return
110
-
111
- model, preprocess = _load_openclip()
112
- labels_mm = np.memmap(LABELS_MEMMAP_PATH, mode="w+", dtype="U32", shape=(N,))
113
- X_w = None; D = None
114
-
115
- with torch.no_grad():
116
- for start in range(0, N, batch_size):
117
- end = min(start + batch_size, N)
118
- imgs = [dataset[i]["image"].convert("RGB") for i in range(start, end)]
119
- x = torch.stack([preprocess(im) for im in imgs])
120
- if DEVICE in ("cuda", "mps"): x = x.to(DEVICE)
121
- v = model.encode_image(x).float()
122
- v = v / v.norm(dim=-1, keepdim=True)
123
- if X_w is None:
124
- D = v.shape[1]
125
- X_w = np.memmap(EMB_MEMMAP_PATH, mode="w+", dtype="float32", shape=(N, D))
126
- X_w[start:end] = v.detach().cpu().numpy()
127
- for i in range(start, end):
128
- try: labels_mm[i] = str(dataset[i]["qa"][0]["answer"] or "")
129
- except Exception: labels_mm[i] = ""
130
- if progress: progress(((end)/N), desc=f"Building index {end}/{N}")
131
-
132
- del X_w; gc.collect()
133
- json.dump({"N": int(N), "D": int(D)}, open(KNN_META_PATH, "w"))
134
- X = np.memmap(EMB_MEMMAP_PATH, mode="r", dtype="float32", shape=(N, D))
135
- labels = np.memmap(LABELS_MEMMAP_PATH, mode="r", dtype="U32", shape=(N,))
136
- _X = X; _labels_source = labels; _nn = _fit_knn(X)
137
-
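For reference, the removed `_ensure_knn_index` cached embeddings on disk with `np.memmap` so later runs could reuse them instead of re-encoding the dataset. A minimal round-trip sketch of that pattern (the file name and sizes here are illustrative, not the ones used above):

```python
import numpy as np

N, D = 80, 1024                          # index size and embedding dimension (example values)
X_w = np.memmap("emb.float32.memmap", mode="w+", dtype="float32", shape=(N, D))
X_w[:8] = np.random.rand(8, D)           # the real code writes batches of CLIP embeddings here
X_w.flush()                              # persist written blocks to disk
del X_w

X_r = np.memmap("emb.float32.memmap", mode="r", dtype="float32", shape=(N, D))  # reopened read-only
```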
138
- def _label_by_idx(i: int):
139
- global _labels_source, _dataset_for_labels
140
- if _labels_source is not None:
141
- lab = str(_labels_source[i]); return lab if lab else None
142
- try: return _dataset_for_labels[i]["qa"][0]["answer"]
143
- except Exception: return None
144
-
145
- # -----------------------------
146
- # Embedding + inference utils
147
- # -----------------------------
148
  def embed_image(img: Image.Image) -> np.ndarray:
149
- model, preprocess = _load_openclip()
150
- with torch.no_grad():
151
- x = preprocess(img.convert("RGB")).unsqueeze(0)
152
- if DEVICE in ("cuda", "mps"): x = x.to(DEVICE)
153
- v = model.encode_image(x).float()
154
- v = v / v.norm(dim=-1, keepdim=True)
155
- return v.detach().cpu().numpy().squeeze()
156
-
157
- def _top3_emotions_weighted_from_embed(q: np.ndarray,
158
- start_k: int = 30, step: int = 30,
159
- method: str = "softmax", tau: float = 0.1):
160
- max_k = _X.shape[0]; k = min(start_k, max_k)
161
- while True:
162
- dist, idx = _nn.kneighbors(q.reshape(1, -1), n_neighbors=k)
163
- idx, dist = idx[0], dist[0]
164
- sims = np.clip(1.0 - dist, 0.0, None)
165
- w = np.exp(sims / tau) if method == "softmax" else sims
166
- scores: Dict[str, float] = {}
167
- for i, wi in zip(idx, w):
168
- lab = _label_by_idx(int(i))
169
- if lab is None: continue
170
- lab = CANON_MAP.get(str(lab).lower(), str(lab).lower())
171
- scores[lab] = scores.get(lab, 0.0) + float(wi)
172
- if len([k for k in scores if k in CANON]) >= 3 or k == max_k:
173
- break
174
- k = min(k + step, max_k)
175
- scores = {k: v for k, v in scores.items() if k in CANON and v > 0}
176
- if not scores: return []
177
- top_items = sorted(scores.items(), key=lambda x: x[1], reverse=True)[:3]
178
- vals = np.array([v for _, v in top_items], dtype=np.float32)
179
- pct = (vals / vals.sum()) * 100.0 if vals.sum() > 0 else np.zeros_like(vals)
180
- return [{"rank": i+1, "emotion": lab, "confidence_pct": int(round(p))}
181
- for i, ((lab, _), p) in enumerate(zip(top_items, pct))]
182
-
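To illustrate the weighting in the removed `_top3_emotions_weighted_from_embed`: with `method="softmax"` and `tau=0.1`, small gaps in cosine similarity become large gaps in vote weight, so the closest neighbors dominate the label tally. A toy calculation (similarity values are made up):

```python
import numpy as np

sims = np.array([0.95, 0.80, 0.60])   # cosine similarities of three neighbors (hypothetical)
w = np.exp(sims / 0.1)                # tau = 0.1 sharpens the distribution
w /= w.sum()
print(w.round(3))                     # ~[0.798, 0.178, 0.024]
```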
183
- def analyze_face(image: Image.Image):
184
- q = embed_image(image)
185
- top3 = _top3_emotions_weighted_from_embed(q)
186
- stress_pct, stress_lbl = stress_from_top3(top3)
187
- return top3, f"{stress_pct}% ({stress_lbl})", q
188
-
189
- # ----- Nearest neighbors images from dataset -----
190
- def _get_dataset_image(i: int) -> Image.Image:
191
- return _dataset_for_labels[int(i)]["image"].convert("RGB")
192
-
193
- def nearest_k_images_from_dataset(q_emb: np.ndarray, k: int = 5):
194
- dist, idx = _nn.kneighbors(q_emb.reshape(1, -1), n_neighbors=k)
195
- dist, idx = dist[0], idx[0]
196
- sims = (1.0 - dist).tolist()
197
  out = []
198
- for i, s in zip(idx, sims):
199
- img = _get_dataset_image(int(i))
200
- out.append((img, float(s), int(i)))
201
  return out
202
 
203
- # -----------------------------
204
- # Generator (optional)
205
- # -----------------------------
206
- def _get_gen_pipe():
207
- global _gen_pipe
208
- if _gen_pipe is not None: return _gen_pipe
209
- gen_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
210
- pipe = StableDiffusionImageVariationPipeline.from_pretrained(
211
- GEN_MODEL_ID, revision="v2.0", torch_dtype=gen_dtype
212
- ).to(DEVICE)
213
- _gen_pipe = pipe
214
- return _gen_pipe
215
-
216
- def generate_synthetics(base_image: Image.Image, base_embed: np.ndarray, n_syn: int, steps: int, progress: gr.Progress):
217
- pipe = _get_gen_pipe()
218
- base_gen = torch.Generator(device="cpu").manual_seed(42)
219
- records = []
220
- for _ in progress.tqdm(range(n_syn), desc="Generating"):
221
- seed = int(torch.randint(0, 2**31 - 1, (1,), generator=base_gen).item())
222
- gs = random.choice(GUIDANCE_SCALES)
223
- g = torch.Generator(device="cpu").manual_seed(seed)
224
- out = pipe(image=base_image.convert("RGB"), guidance_scale=gs, num_inference_steps=steps, generator=g)
225
- img = out.images[0]
226
- emb = embed_image(img)
227
- sim = float(np.dot(emb, base_embed))
228
- top3_syn = _top3_emotions_weighted_from_embed(emb)
229
- stress_pct, stress_lbl = stress_from_top3(top3_syn)
230
- records.append({"image": img, "similarity": sim, "top3": top3_syn, "stress": f"{stress_pct}% ({stress_lbl})"})
231
- records.sort(key=lambda r: r["similarity"], reverse=True)
232
- return records[:NUM_SYN_TO_SHOW]
233
-
234
- # -----------------------------
235
- # UI
236
- # -----------------------------
237
- def _format_top3_for_table(top3: List[Dict]) -> List[List]:
238
- return [[r["rank"], r["emotion"], r["confidence_pct"]] for r in top3]
239
-
240
- with gr.Blocks(title="Face Emotions + Stress (CPU Fast)") as demo:
241
  gr.Markdown(
242
- "## Face Emotion & Stress Analyzer — CPU-friendly\n"
243
  "- Embeddings: **laion/CLIP-ViT-H-14-laion2B-s32B-b79K** (open_clip)\n"
244
- "- Optional SD variations: **lambdalabs/sd-image-variations-diffusers**\n"
245
- "- Also shows **nearest 5 images from the dataset** for 1-click results.\n"
 
246
  )
247
 
 
248
  with gr.Row():
249
- inp = gr.Image(type="pil", label="Upload a face image", sources=["upload", "webcam"])
250
- idx_cap = gr.Slider(20, 200, value=INDEX_MAX_DEFAULT, step=10, label="Index size (smaller = faster)")
251
- bs = gr.Slider(8, 64, value=BATCH_SIZE_DEFAULT, step=8, label="Batch size (build)")
252
-
253
- analyze_btn = gr.Button("Analyze (no synthetics)")
254
-
255
- with gr.Row():
256
- with gr.Column():
257
- top3_tbl = gr.Dataframe(
258
  headers=["Rank", "Emotion", "Confidence (%)"],
259
  datatype=["number", "str", "number"],
260
- interactive=False, row_count=(3, "fixed"), col_count=(3, "fixed"),
261
- label="Top-3 emotions (original image)"
262
  )
263
- stress_txt = gr.Label(label="Stress index (original)")
264
- with gr.Column():
265
- # Nearest 5 from dataset (one-click examples)
266
- nn_gal = gr.Gallery(
267
- label="Nearest 5 from dataset (click one)",
268
- columns=[5], height=220, preview=True
269
  )
270
- nn_stress = gr.Label(label="Stress (nearest image)")
271
- nn_top3 = gr.JSON(label="Top-3 emotions (nearest image)")
272
-
273
- # Optional generator
274
- n_syn = gr.Slider(0, 5, value=N_SYN_DEFAULT, step=1, label="How many SD variations to generate")
275
- steps = gr.Slider(8, 30, value=STEPS_DEFAULT, step=2, label="Diffusion steps (higher = slower/better)")
276
- gen_btn = gr.Button("Generate variations (optional)")
277
- gal = gr.Gallery(label="Synthetic variations (click one)", columns=[5], height=220, preview=True)
278
- syn_stress = gr.Label(label="Stress (selected synthetic)")
279
- syn_top3 = gr.JSON(label="Top-3 emotions (selected synthetic)")
280
-
281
- status = gr.Markdown(visible=False)
282
-
283
- # State
284
- syn_state = gr.State([]) # generated variations
285
- q_state = gr.State(None) # embedding of original image
286
- img_state = gr.State(None) # original image
287
-
288
- # ---- Analyze ----
289
- def do_analyze(image: Image.Image, cap: int, batch: int, progress=gr.Progress(track_tqdm=True)):
290
- try:
291
- _ensure_knn_index(index_max=int(cap), batch_size=int(batch), progress=progress)
292
- top3, stress, q = analyze_face(image)
293
-
294
- # nearest 5 images from dataset
295
- neigh = nearest_k_images_from_dataset(np.array(q, dtype=np.float32), k=5)
296
- nn_items = [(im, f"sim={sim:.3f} • idx={idx}") for im, sim, idx in neigh]
297
-
298
- # return: top3, stress, nn gallery, (empty SD gallery), syn_state, q, img, status
299
- return (_format_top3_for_table(top3), stress,
300
- nn_items, [], [], q, image, gr.update(visible=False))
301
- except Exception as e:
302
- return None, None, [], [], [], None, None, gr.update(visible=True, value=f"**Error:** {e}")
303
-
304
- analyze_btn.click(
305
- do_analyze,
306
- inputs=[inp, idx_cap, bs],
307
- outputs=[top3_tbl, stress_txt, nn_gal, gal, syn_state, q_state, img_state, status]
308
- )
309
 
310
- # ---- One-click on a nearest image ----
311
- def on_nn_select(evt: gr.SelectData, q):
312
- if q is None:
313
- return gr.update(value="Analyze first"), None
314
- neigh = nearest_k_images_from_dataset(np.array(q, dtype=np.float32), k=5)
315
- i = max(0, min(int(evt.index), len(neigh)-1))
316
- img, _, _ = neigh[i]
317
- emb = embed_image(img)
318
- top3 = _top3_emotions_weighted_from_embed(emb)
319
- stress_pct, stress_lbl = stress_from_top3(top3)
320
- return f"{stress_pct}% ({stress_lbl})", top3
321
-
322
- nn_gal.select(fn=on_nn_select, inputs=[q_state], outputs=[nn_stress, nn_top3])
323
-
324
- # ---- Optional: generate SD variations ----
325
- def do_generate(n: int, s: int, q, img, progress=gr.Progress()):
326
- if q is None or img is None:
327
- return [], [], gr.update(visible=True, value="**Error:** Analyze first."), None
328
- try:
329
- recs = generate_synthetics(img, np.array(q, dtype=np.float32), n_syn=int(n), steps=int(s), progress=progress)
330
- items = [(r["image"], f"sim={r['similarity']:.3f}") for r in recs]
331
- return items, recs, gr.update(visible=False), None
332
- except Exception as e:
333
- return [], [], gr.update(visible=True, value=f"**Error:** {e}"), None
334
 
335
- gen_btn.click(
336
- do_generate,
337
- inputs=[n_syn, steps, q_state, img_state],
338
- outputs=[gal, syn_state, status, syn_top3]
339
  )
340
 
341
- # select from generated synthetics
342
- def on_gallery_select(evt: gr.SelectData, syn_records: List[Dict]):
343
- if not syn_records or evt is None: return gr.update(value=None), gr.update(value=None)
344
- i = int(evt.index); rec = syn_records[i]
345
- return gr.update(value=rec["stress"]), gr.update(value=rec["top3"])
346
 
347
- gal.select(fn=on_gallery_select, inputs=[syn_state], outputs=[syn_stress, syn_top3])
348
 
349
  if __name__ == "__main__":
350
  demo.launch()
 
1
+ import os, io, math, json, random, numpy as np
2
+ from typing import List, Tuple, Dict, Any
3
  from PIL import Image
4
 
5
  import torch
6
+ import torch.nn.functional as F
7
+
8
+ import gradio as gr
9
  from datasets import load_dataset
10
  from sklearn.neighbors import NearestNeighbors
11
+ from transformers import pipeline
12
+
13
+ # =============== CONFIG ===============
14
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
15
+
16
+ # Embeddings (your chosen model)
17
+ OPENCLIP_BACKBONE = "ViT-H-14"
18
+ OPENCLIP_PRETRAIN = "laion2B-s32B-b79K" # laion/CLIP-ViT-H-14-laion2B-s32B-b79K
19
+ INDEX_SIZE = int(os.getenv("INDEX_SIZE", 400)) # כמה תמונות מהדאטהסט לאינדוקס
20
+ TOPK_NEAREST = 5
21
+
22
+ # Emotion model (your chosen model)
23
+ EMO_MODEL = "prithivMLmods/Facial-Emotion-Detection-SigLIP2"
24
+
25
+ # Optional image-variation generator (your chosen model)
26
+ USE_SD_VARIATIONS = True
27
+ SD_MODEL = "lambdalabs/sd-image-variations-diffusers"
28
+ # =====================================
29
+
30
+ # ---------- Load OpenCLIP for image embeddings ----------
31
+ try:
32
+ import open_clip
33
+ _openclip_model, _, _openclip_preprocess = open_clip.create_model_and_transforms(
34
+ OPENCLIP_BACKBONE, pretrained=OPENCLIP_PRETRAIN
35
  )
36
+ _openclip_model = _openclip_model.to(DEVICE).eval()
37
+ except Exception as e:
38
+ raise RuntimeError(
39
+ f"Failed to load OpenCLIP ({OPENCLIP_BACKBONE} / {OPENCLIP_PRETRAIN}). "
40
+ f"Install 'open_clip_torch' and verify CUDA if available. Error: {e}"
41
+ )
42
+
43
+ @torch.inference_mode()
44
  def embed_image(img: Image.Image) -> np.ndarray:
45
+ img = img.convert("RGB")
46
+ tens = _openclip_preprocess(img).unsqueeze(0).to(DEVICE)
47
+ feats = _openclip_model.encode_image(tens)
48
+ feats = F.normalize(feats, dim=-1).squeeze(0).detach().cpu().numpy().astype(np.float32)
49
+ return feats # shape [D]
50
+
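Because `embed_image` L2-normalizes the features, cosine similarity between two images is just a dot product of the returned vectors. A small usage sketch (file names are placeholders):

```python
import numpy as np
from PIL import Image

a = embed_image(Image.open("face_a.jpg"))   # shape [D], unit norm
b = embed_image(Image.open("face_b.jpg"))
cosine_similarity = float(np.dot(a, b))     # in [-1, 1]; higher means more similar
```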
51
+ # ---------- Dataset + index ----------
52
+ DATASET_NAME = "tukey/human_face_emotions_roboflow"
53
+ DATASET_SPLIT = "train"
54
+
55
+ def _load_images_for_index(n: int) -> List[Image.Image]:
56
+ ds = load_dataset(DATASET_NAME, split=DATASET_SPLIT)
57
+ n = min(n, len(ds))
58
+ imgs = []
59
+ for i in range(n):
60
+ # column is usually "image"
61
+ im = ds[i].get("image")
62
+ if isinstance(im, Image.Image):
63
+ imgs.append(im.copy())
64
+ return imgs
65
+
66
+ def build_index(imgs: List[Image.Image]) -> Tuple[NearestNeighbors, np.ndarray]:
67
+ vecs = []
68
+ for im in imgs:
69
+ vecs.append(embed_image(im))
70
+ X = np.stack(vecs, axis=0)
71
+ nn = NearestNeighbors(metric="cosine", n_neighbors=min(TOPK_NEAREST, len(imgs)))
72
+ nn.fit(X)
73
+ return nn, X
74
+
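For scale: ViT-H-14 image embeddings are 1024-dimensional, so even the default `INDEX_SIZE` of 400 keeps the embedding matrix tiny; the PIL images held in memory for the gallery are the larger cost. Back-of-the-envelope:

```python
INDEX_SIZE, D = 400, 1024         # D = ViT-H-14 image embedding width
mb = INDEX_SIZE * D * 4 / 1e6     # float32 = 4 bytes per value
print(f"{mb:.1f} MB")             # ~1.6 MB for the embedding matrix
```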
75
+ print("Loading dataset & building index (first time only)...")
76
+ DATASET_IMAGES: List[Image.Image] = _load_images_for_index(INDEX_SIZE)
77
+ NN_MODEL, EMB_MATRIX = build_index(DATASET_IMAGES)
78
+ print(f"Index ready with {len(DATASET_IMAGES)} images.")
79
+
80
+ def nearest5(pil_img: Image.Image) -> List[Tuple[Image.Image, str]]:
81
+ q = embed_image(pil_img).reshape(1, -1)
82
+ dists, idxs = NN_MODEL.kneighbors(q, n_neighbors=min(5, len(DATASET_IMAGES)))
83
+ # cosine distance -> similarity = 1 - dist
84
  out = []
85
+ for rank, (dist, idx) in enumerate(zip(dists[0], idxs[0]), start=1):
86
+ sim = 1.0 - float(dist)
87
+ im = DATASET_IMAGES[int(idx)]
88
+ caption = f"#{rank} sim={sim:.3f} idx={int(idx)}"
89
+ out.append((im, caption))
90
+ return out # list of (PIL, caption)
91
+
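A note on the conversion used above: with `metric="cosine"`, scikit-learn's `NearestNeighbors.kneighbors` returns cosine distances, so the caption similarity is `1 - distance`. A self-contained sketch of the same query pattern on toy data:

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors

X = np.random.rand(10, 8).astype(np.float32)        # toy embedding matrix
nn = NearestNeighbors(metric="cosine").fit(X)
dists, idxs = nn.kneighbors(X[:1], n_neighbors=3)   # both shaped (1, 3)
sims = 1.0 - dists[0]                               # convert distance back to similarity
```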
92
+ # ---------- Emotion & Stress ----------
93
+ EMO_MAP = {
94
+ "anger": "anger", "angry": "anger",
95
+ "disgust": "disgust",
96
+ "fear": "fear",
97
+ "happy": "happy", "happiness": "happy",
98
+ "neutral": "neutral", "calm": "neutral",
99
+ "sad": "sad", "sadness": "sad",
100
+ "surprise": "surprise",
101
+ "contempt": "contempt",
102
+ }
103
+
104
+ # higher == more stressed
105
+ STRESS_WEIGHTS = {
106
+ "anger": 0.95,
107
+ "fear": 0.90,
108
+ "disgust": 0.70,
109
+ "sad": 0.80,
110
+ "surprise": 0.55,
111
+ "neutral": 0.30,
112
+ "contempt": 0.65,
113
+ "happy": 0.10,
114
+ }
115
+
116
+ def _bucket(p: float) -> str:
117
+ return "Low" if p < 33 else ("Medium" if p < 66 else "High")
118
+
119
+ emo_pipe = pipeline("image-classification", model=EMO_MODEL, device=0 if DEVICE == "cuda" else -1)
120
+
121
+ def _pipe_to_probs(res: List[Dict[str, Any]]) -> Dict[str, float]:
122
+ acc: Dict[str, float] = {}
123
+ for r in res:
124
+ label = (r.get("label") or r.get("emotion") or "").lower()
125
+ if not label:
126
+ continue
127
+ key = EMO_MAP.get(label, label)
128
+ score = float(r.get("score") or r.get("confidence") or r.get("confidence_pct", 0.0) / 100.0)
129
+ acc[key] = acc.get(key, 0.0) + score
130
+ Z = sum(acc.values()) or 1.0
131
+ for k in list(acc.keys()):
132
+ acc[k] = acc[k] / Z
133
+ return acc
134
+
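For context, the `transformers` image-classification pipeline returns a list of `{"label": ..., "score": ...}` dicts; `_pipe_to_probs` lower-cases the labels, folds synonyms through `EMO_MAP`, and renormalizes the scores. A toy input/output (scores are invented):

```python
res = [
    {"label": "Happiness", "score": 0.62},   # folded to "happy" via EMO_MAP
    {"label": "Neutral", "score": 0.25},
    {"label": "Sad", "score": 0.13},
]
print(_pipe_to_probs(res))                   # {'happy': 0.62, 'neutral': 0.25, 'sad': 0.13}
```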
135
+ def emotions_top3(pil_img: Image.Image) -> List[List[Any]]:
136
+ res = emo_pipe(pil_img)
137
+ probs = _pipe_to_probs(res)
138
+ items = sorted(probs.items(), key=lambda kv: kv[1], reverse=True)[:3]
139
+ table = []
140
+ for i, (emo, p) in enumerate(items, start=1):
141
+ table.append([i, emo, round(100.0 * p, 2)])
142
+ return table # [[rank, emotion, pct]]
143
+
144
+ def stress_index(pil_img: Image.Image) -> Tuple[str, float]:
145
+ res = emo_pipe(pil_img)
146
+ probs = _pipe_to_probs(res)
147
+ raw = 0.0
148
+ for k, v in probs.items():
149
+ w = STRESS_WEIGHTS.get(k, 0.5)
150
+ raw += v * w
151
+ pct = max(0.0, min(100.0, 100.0 * raw))
152
+ return f"{pct:.1f}% ({_bucket(pct)})", pct
153
+
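To make the stress score concrete: `stress_index` is a probability-weighted average of `STRESS_WEIGHTS`, scaled to a percentage and bucketed by `_bucket`. A worked example with invented probabilities:

```python
probs = {"anger": 0.5, "neutral": 0.3, "happy": 0.2}   # hypothetical classifier output
raw = 0.5 * 0.95 + 0.3 * 0.30 + 0.2 * 0.10             # = 0.585
pct = 100.0 * raw                                      # 58.5 -> "Medium" (33 <= pct < 66)
```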
154
+ # ---------- Optional: SD image variations (1 image only) ----------
155
+ sd_pipe = None
156
+ if USE_SD_VARIATIONS:
157
+ try:
158
+ from diffusers import StableDiffusionImageVariationPipeline
159
+ sd_pipe = StableDiffusionImageVariationPipeline.from_pretrained(
160
+ SD_MODEL, torch_dtype=torch.float32
161
+ )
162
+ sd_pipe = sd_pipe.to(DEVICE)
163
+ except Exception as e:
164
+ print(f"[WARN] Could not load {SD_MODEL}. Generation disabled. Error: {e}")
165
+ sd_pipe = None
166
+
167
+ def generate_one_variation(pil_img: Image.Image, steps: int) -> Image.Image:
168
+ if sd_pipe is None:
169
+ raise gr.Error("Image-variation pipeline is not available on this Space.")
170
+ pil_img = pil_img.convert("RGB")
171
+ out = sd_pipe(pil_img, guidance_scale=3.0, num_inference_steps=int(steps)).images[0]
172
  return out
173
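If the variation pipeline is too heavy for a CPU Space, one optional tweak (not part of this commit) is diffusers' attention slicing, which trades a little speed for lower peak memory:

```python
# optional memory saver; safe to call right after the pipeline is created
if sd_pipe is not None:
    sd_pipe.enable_attention_slicing()
```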
 
174
+ # ===================== GRADIO UI =====================
175
+ CSS = """
176
+ .box { border: 1px solid #e5e7eb; border-radius: 12px; padding: 10px; }
177
+ """
178
+
179
+ with gr.Blocks(title="Face Emotion & Stress Analyzer — CPU-friendly", css=CSS, fill_height=False) as demo:
180
  gr.Markdown(
181
+ "### Face Emotion & Stress Analyzer — CPU-friendly\n"
182
  "- Embeddings: **laion/CLIP-ViT-H-14-laion2B-s32B-b79K** (open_clip)\n"
183
+ "- Emotion model: **prithivMLmods/Facial-Emotion-Detection-SigLIP2**\n"
184
+ "- Optional SD variations: **lambdalabs/sd-image-variations-diffusers** (1 synthetic only)\n"
185
+ "- Right column shows nearest 5 images from the dataset (clickable)."
186
  )
187
 
188
+ # ---- Row 1: upload + (top3_emotion_original | stress_original) ----
189
  with gr.Row():
190
+ with gr.Column(scale=2):
191
+ upload_image = gr.Image(label="Upload face image", type="pil")
192
+ with gr.Column(scale=1):
193
+ top3_emotion_original = gr.Dataframe(
194
  headers=["Rank", "Emotion", "Confidence (%)"],
195
  datatype=["number", "str", "number"],
196
+ interactive=False, label="Top-3 emotions (original image)",
197
+ value=[]
198
  )
199
+ with gr.Column(scale=1):
200
+ stress_original = gr.Label(label="Stress index (original)")
201
+
202
+ gr.Markdown("#### Analyze (no synthetics)")
203
+
204
+ with gr.Row(equal_height=False):
205
+ # ---------- LEFT COLUMN ----------
206
+ with gr.Column(scale=1):
207
+ with gr.Group():
208
+ gr.Markdown("**gen_variations_control** — generate only **one** synthetic")
209
+ steps = gr.Slider(8, 40, value=12, step=1, label="Diffusion steps (higher=slower/better)")
210
+ gen_btn = gr.Button("Generate 1 synthetic", variant="primary")
211
+ picked_synth = gr.Image(label="Synthetic preview")
212
+ top3_emotion_synth = gr.Dataframe(
213
+ headers=["Rank", "Emotion", "Confidence (%)"],
214
+ datatype=["number", "str", "number"],
215
+ interactive=False, label="top3_emotion_synth",
216
+ value=[]
217
  )
218
+ stress_synth = gr.Label(label="stress_synth")
219
 
220
+ # ---------- RIGHT COLUMN ----------
221
+ with gr.Column(scale=1):
222
+ nearest_images_5 = gr.Gallery(
223
+ label="nearest_images_5 (1-click on 5 examples)",
224
+ columns=5, rows=1, height=200, allow_preview=False, show_label=True
225
+ )
226
+ tops_emotion_nearest = gr.Dataframe(
227
+ headers=["Rank", "Emotion", "Confidence (%)"],
228
+ datatype=["number", "str", "number"],
229
+ interactive=False, label="tops_emotion_nearest_image",
230
+ value=[]
231
+ )
232
+ stress_nearest = gr.Label(label="stress_nearest_image")
233
+
234
+ # --------- Hidden states ---------
235
+ gallery_images_state = gr.State([]) # store PILs
236
+ gallery_index_state = gr.State([]) # store dataset indexes (ints)
237
+
238
+ # ================= Callbacks =================
239
+ def on_upload(img: Image.Image):
240
+ if img is None:
241
+ return gr.update(), gr.update(value=""), [], [], []
242
+ # original
243
+ t3 = emotions_top3(img)
244
+ s_label, _ = stress_index(img)
245
+ # nearest gallery
246
+ gal = nearest5(img) # list[(PIL, caption)]
247
+ gal_imgs = [g[0] for g in gal]
248
+ gal_caps = [g[1] for g in gal]
249
+ # gr.Gallery accepts [(img, caption), ...]
250
+ gallery = [(im, cap) for im, cap in zip(gal_imgs, gal_caps)]
251
+ # return
252
+ return t3, s_label, gallery, gal_imgs, list(range(len(gal_imgs)))
253
+
254
+ upload_image.change(
255
+ fn=on_upload,
256
+ inputs=[upload_image],
257
+ outputs=[top3_emotion_original, stress_original, nearest_images_5, gallery_images_state, gallery_index_state]
258
+ )
259
 
260
+ def on_gallery_select(evt: gr.SelectData, imgs: List[Image.Image], idxs: List[int]):
261
+ # evt.index is the clicked cell
262
+ if imgs is None or not imgs:
263
+ return [], ""
264
+ i = int(evt.index) if evt is not None else 0
265
+ i = max(0, min(i, len(imgs)-1))
266
+ im = imgs[i]
267
+ t3 = emotions_top3(im)
268
+ s_label, _ = stress_index(im)
269
+ return t3, s_label
270
+
271
+ nearest_images_5.select(
272
+ fn=on_gallery_select,
273
+ inputs=[gallery_images_state, gallery_index_state],
274
+ outputs=[tops_emotion_nearest, stress_nearest]
275
  )
276
 
277
+ def on_generate(img: Image.Image, steps_val: int):
278
+ if img is None:
279
+ raise gr.Error("Upload an image first.")
280
+ if sd_pipe is None:
281
+ raise gr.Error("Synthetic generation is disabled on this Space.")
282
+ synth = generate_one_variation(img, steps_val)
283
+ t3 = emotions_top3(synth)
284
+ s_label, _ = stress_index(synth)
285
+ return synth, t3, s_label
286
 
287
+ gen_btn.click(
288
+ fn=on_generate,
289
+ inputs=[upload_image, steps],
290
+ outputs=[picked_synth, top3_emotion_synth, stress_synth]
291
+ )
292
 
293
  if __name__ == "__main__":
294
  demo.launch()