Shani13524 committed on
Commit 0e096ce · verified · 1 Parent(s): 52d90d8

Update app.py

Files changed (1):
  1. app.py +106 -185
app.py CHANGED
@@ -8,36 +8,35 @@ from PIL import Image
 
 import torch
 import open_clip
-from datasets import load_dataset, DatasetDict
 from sklearn.neighbors import NearestNeighbors
-
 from diffusers import StableDiffusionImageVariationPipeline
 
 # -----------------------------
 # Config
 # -----------------------------
 DATASET_ID = "tukey/human_face_emotions_roboflow"
-EMB_MODEL_NAME = "ViT-H-14"  # open_clip model name
-EMB_PRETRAINED = "laion2b_s32b_b79k"  # maps to laion/CLIP-ViT-H-14-laion2B-s32B-b79K
 GEN_MODEL_ID = "lambdalabs/sd-image-variations-diffusers"
 
-CACHE_DIR = Path("./cache")
-CACHE_DIR.mkdir(parents=True, exist_ok=True)
 EMB_MEMMAP_PATH = CACHE_DIR / "clip_vith14_laion2b.float32.memmap"
 LABELS_MEMMAP_PATH = CACHE_DIR / "labels.U32.memmap"
 KNN_META_PATH = CACHE_DIR / "knn_meta.json"
 
-# generation defaults
-N_SYN = 12  # generate more, then keep top-5
 NUM_SYN_TO_SHOW = 5
-STEPS = 35
 GUIDANCE_SCALES = [2.5, 3.0, 3.5, 4.0]
 
-# device selection
 DEVICE = "cuda" if torch.cuda.is_available() else ("mps" if hasattr(torch.backends, "mps") and torch.backends.mps.is_available() else "cpu")
 
 # -----------------------------
-# Canonical label + stress mapping (kept from your Colab logic)
 # -----------------------------
 CANON = {"anger","disgust","fear","happy","neutral","sad","surprise","contempt"}
 CANON_MAP = {
@@ -57,17 +56,14 @@ STRESS_W = {
 }
 def _bucket(pct: float) -> str:
     return "Low" if pct < 33 else ("Medium" if pct < 66 else "High")
-
 def stress_from_top3(res: List[Dict]) -> Tuple[float, str]:
     probs = {}
     for r in res:
         lbl = CANON_MAP.get(str(r["emotion"]).lower(), str(r["emotion"]).lower())
-        if lbl not in CANON:
-            continue
         probs[lbl] = probs.get(lbl, 0.0) + float(r["confidence_pct"]) / 100.0
     Z = sum(probs.values()) or 1.0
-    for k in list(probs):
-        probs[k] /= Z
     s01 = sum(probs.get(k, 0.0) * STRESS_W.get(k, 0.0) for k in probs)
     s01 = max(0.0, min(1.0, s01))
     pct = round(s01 * 100.0, 2)
@@ -89,108 +85,81 @@ _dataset_for_labels = None
 # -----------------------------
 def _load_openclip():
     global _openclip_model, _preprocess
-    if _openclip_model is not None and _preprocess is not None:
-        return _openclip_model, _preprocess
     model, _, preprocess = open_clip.create_model_and_transforms(
-        model_name=EMB_MODEL_NAME,
-        pretrained=EMB_PRETRAINED,
-        device=DEVICE
     )
     model.eval()
     _openclip_model, _preprocess = model, preprocess
     return _openclip_model, _preprocess
 
-def _ensure_knn_index():
-    """Build (first run) or load a memmap + KNN index over the dataset embeddings."""
     global _nn, _X, _labels_source, _dataset_for_labels
 
-    if _nn is not None and _X is not None:
         return
 
-    # Load dataset (train split; if missing, fallback to full)
     dataset = load_dataset(DATASET_ID, split="train")
     _dataset_for_labels = dataset
-
     N = len(dataset)
 
-    # If memmaps already exist and meta is present -> load
     if EMB_MEMMAP_PATH.exists() and KNN_META_PATH.exists():
-        meta = json.load(open(KNN_META_PATH, "r"))
-        N_meta, D = int(meta["N"]), int(meta["D"])
-        if N_meta == N:
             X = np.memmap(EMB_MEMMAP_PATH, mode="r", dtype="float32", shape=(N, D))
-            # labels memmap optional; if missing, we can fetch labels on the fly
-            labels = None
-            if LABELS_MEMMAP_PATH.exists():
-                labels = np.memmap(LABELS_MEMMAP_PATH, mode="r", dtype="U32", shape=(N,))
-            _fit_knn(X)
-            _X = X
-            _labels_source = labels
             return
 
-    # Build embeddings (first run)
     model, preprocess = _load_openclip()
-    D = None
-    X_w = None
 
     def _label_of(i):
-        try:
-            ans = dataset[i]["qa"][0]["answer"]
-            return str(ans) if ans is not None else ""
-        except Exception:
-            return ""
-
-    # write labels memmap
-    labels_mm = np.memmap(LABELS_MEMMAP_PATH, mode="w+", dtype="U32", shape=(N,))
 
     with torch.no_grad():
-        # do the first item to get D
-        x0 = preprocess(dataset[0]["image"]).unsqueeze(0)
-        if DEVICE in ("cuda", "mps"):
-            x0 = x0.to(DEVICE)
-        v0 = model.encode_image(x0).float()
-        v0 = v0 / v0.norm(dim=-1, keepdim=True)
-        D = v0.shape[1]
-        X_w = np.memmap(EMB_MEMMAP_PATH, mode="w+", dtype="float32", shape=(N, D))
-        X_w[0] = v0.detach().cpu().numpy().squeeze()
-        labels_mm[0] = _label_of(0)
-
-        # rest
-        for i in range(1, N):
-            xi = preprocess(dataset[i]["image"]).unsqueeze(0)
-            if DEVICE in ("cuda", "mps"):
-                xi = xi.to(DEVICE)
-            vi = model.encode_image(xi).float()
-            vi = vi / vi.norm(dim=-1, keepdim=True)
-            X_w[i] = vi.detach().cpu().numpy().squeeze()
-            labels_mm[i] = _label_of(i)
-
-    # flush to disk
-    del X_w
-    gc.collect()
-
-    # Save meta, reload read-only view, fit knn
     json.dump({"N": int(N), "D": int(D)}, open(KNN_META_PATH, "w"))
     X = np.memmap(EMB_MEMMAP_PATH, mode="r", dtype="float32", shape=(N, D))
     labels = np.memmap(LABELS_MEMMAP_PATH, mode="r", dtype="U32", shape=(N,))
-    _fit_knn(X)
-    _X = X
-    _labels_source = labels
-
-def _fit_knn(X):
-    global _nn
-    _nn = NearestNeighbors(metric="cosine", algorithm="brute").fit(X)
 
 def _label_by_idx(i: int):
     global _labels_source, _dataset_for_labels
     if _labels_source is not None:
-        lab = str(_labels_source[i])
-        return lab if lab else None
-    # fallback: live label read
-    try:
-        return _dataset_for_labels[i]["qa"][0]["answer"]
-    except Exception:
-        return None
 
 # -----------------------------
 # Embedding + inference utils
@@ -199,60 +168,42 @@ def embed_image(img: Image.Image) -> np.ndarray:
     model, preprocess = _load_openclip()
     with torch.no_grad():
         x = preprocess(img.convert("RGB")).unsqueeze(0)
-        if DEVICE in ("cuda", "mps"):
-            x = x.to(DEVICE)
         v = model.encode_image(x).float()
         v = v / v.norm(dim=-1, keepdim=True)
     return v.detach().cpu().numpy().squeeze()
 
 def _top3_emotions_weighted_from_embed(q: np.ndarray,
-                                       start_k: int = 30, step: int = 30, method: str = "softmax", tau: float = 0.1):
     _ensure_knn_index()
-    max_k = _X.shape[0]
-    k = min(start_k, max_k)
-    scores: Dict[str, float] = {}
-
     while True:
         dist, idx = _nn.kneighbors(q.reshape(1, -1), n_neighbors=k)
         idx, dist = idx[0], dist[0]
-        sims = 1.0 - dist
-        sims = np.clip(sims, 0.0, None)
         w = np.exp(sims / tau) if method == "softmax" else sims
 
-        scores.clear()
-        total_w = 0.0
         for i, wi in zip(idx, w):
             lab = _label_by_idx(int(i))
-            if lab is None:
-                continue
             lab = CANON_MAP.get(str(lab).lower(), str(lab).lower())
             scores[lab] = scores.get(lab, 0.0) + float(wi)
-            total_w += float(wi)
 
-        # stop when we have >= 3 unique emotions
-        if len([k for k in scores.keys() if k in CANON]) >= 3 or k == max_k:
             break
         k = min(k + step, max_k)
 
-    if not scores:
-        return []
-
-    # keep only canonical keys
     scores = {k: v for k, v in scores.items() if k in CANON and v > 0}
-    if not scores:
-        return []
-
     top_items = sorted(scores.items(), key=lambda x: x[1], reverse=True)[:3]
     vals = np.array([v for _, v in top_items], dtype=np.float32)
     pct = (vals / vals.sum()) * 100.0 if vals.sum() > 0 else np.zeros_like(vals)
-    return [
-        {"rank": i + 1, "emotion": lab, "confidence_pct": int(round(p))}
-        for i, ((lab, _), p) in enumerate(zip(top_items, pct))
-    ]
 
 def analyze_face(image: Image.Image):
-    """Return top-3 emotions + stress for the original image."""
-    _ensure_knn_index()
     q = embed_image(image)
     top3 = _top3_emotions_weighted_from_embed(q)
     stress_pct, stress_lbl = stress_from_top3(top3)
@@ -263,55 +214,31 @@ def analyze_face(image: Image.Image):
 # -----------------------------
 def _get_gen_pipe():
     global _gen_pipe
-    if _gen_pipe is not None:
-        return _gen_pipe
     gen_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
     pipe = StableDiffusionImageVariationPipeline.from_pretrained(
-        GEN_MODEL_ID,
-        revision="v2.0",
-        torch_dtype=gen_dtype
-    )
-    pipe = pipe.to(DEVICE)
     _gen_pipe = pipe
     return _gen_pipe
 
-def generate_synthetics(base_image: Image.Image, base_embed: np.ndarray):
-    """Generate N_SYN variations, compute similarity to original embedding, keep top-5."""
     pipe = _get_gen_pipe()
-
-    # Deterministic seed stream
     base_gen = torch.Generator(device="cpu").manual_seed(42)
-
     records = []
-    for i in range(N_SYN):
         seed = int(torch.randint(0, 2**31 - 1, (1,), generator=base_gen).item())
         gs = random.choice(GUIDANCE_SCALES)
         g = torch.Generator(device="cpu").manual_seed(seed)
-
-        out = pipe(
-            image=base_image.convert("RGB"),
-            guidance_scale=gs,
-            num_inference_steps=STEPS,
-            generator=g
-        )
         img = out.images[0]
-
-        # embed and compute similarity to the original
         emb = embed_image(img)
-        sim = float(np.dot(emb, base_embed))  # both normalized
-
-        # top3 + stress for each synthetic
         top3_syn = _top3_emotions_weighted_from_embed(emb)
         stress_pct, stress_lbl = stress_from_top3(top3_syn)
-
-        records.append({
-            "image": img,
-            "similarity": sim,
-            "top3": top3_syn,
-            "stress": f"{stress_pct}% ({stress_lbl})"
-        })
-
-    # keep best NUM_SYN_TO_SHOW by similarity
     records.sort(key=lambda r: r["similarity"], reverse=True)
     return records[:NUM_SYN_TO_SHOW]
 
@@ -319,21 +246,20 @@ def generate_synthetics(base_image: Image.Image, base_embed: np.ndarray):
 # -----------------------------
 # Gradio app
 # -----------------------------
 def _format_top3_for_table(top3: List[Dict]) -> List[List]:
-    rows = []
-    for r in top3:
-        rows.append([r["rank"], r["emotion"], r["confidence_pct"]])
-    return rows
 
-with gr.Blocks(title="Face Emotions + Stress (CLIP ViT-H-14 + SD Variations)") as demo:
     gr.Markdown(
-        "## Face Emotion & Stress Analyzer\n"
-        "- Embeddings: **laion/CLIP-ViT-H-14-laion2B-s32B-b79K** (via `open_clip`)\n"
         "- Synthetic variations: **lambdalabs/sd-image-variations-diffusers**\n"
         "- KNN labels from: **tukey/human_face_emotions_roboflow**\n"
     )
 
     with gr.Row():
         inp = gr.Image(type="pil", label="Upload a face image", sources=["upload", "webcam"])
 
     analyze_btn = gr.Button("Analyze & Generate Synthetics")
@@ -342,59 +268,54 @@ with gr.Blocks(title="Face Emotions + Stress (CLIP ViT-H-14 + SD Variations)") as demo:
         top3_tbl = gr.Dataframe(
             headers=["Rank", "Emotion", "Confidence (%)"],
             datatype=["number", "str", "number"],
-            interactive=False,
-            row_count=(3, "fixed"),
-            col_count=(3, "fixed"),
             label="Top-3 emotions (original image)"
         )
         stress_txt = gr.Label(label="Stress index (original)")
     with gr.Column():
         gal = gr.Gallery(
             label="Top 5 synthetic variations (click one)",
-            columns=[5], height=200, preview=True
         )
         syn_stress = gr.Label(label="Stress index (selected synthetic)")
         syn_top3 = gr.JSON(label="Top-3 emotions (selected synthetic)")
 
     status = gr.Markdown(visible=False)
 
-    # State to pass around generated records
-    syn_state = gr.State([])  # list of dicts: {image, similarity, top3, stress}
-
-    def run_pipeline(image: Image.Image):
        try:
-            # Step 1: top3 + stress (original)
-            top3, stress, q = analyze_face(image)
 
-            # Step 2: synthetics + pick top 5
-            syn = generate_synthetics(image, q)
 
-            # gallery expects a list of images or (image, caption)
            items = [(r["image"], f"sim={r['similarity']:.3f}") for r in syn]
-            top3_rows = _format_top3_for_table(top3)
-            return top3_rows, stress, items, syn, gr.update(visible=False), None
        except Exception as e:
            return None, None, None, [], gr.update(visible=True, value=f"**Error:** {e}"), None
 
     analyze_btn.click(
-        run_pipeline,
-        inputs=[inp],
        outputs=[top3_tbl, stress_txt, gal, syn_state, status, syn_top3]
     )
 
     def on_gallery_select(evt: gr.SelectData, syn_records: List[Dict]):
-        # evt.index is the clicked item index in gallery
-        if not syn_records or evt is None:
-            return gr.update(value=None), gr.update(value=None)
-        i = int(evt.index)
-        rec = syn_records[i]
        return gr.update(value=rec["stress"]), gr.update(value=rec["top3"])
 
-    gal.select(
-        fn=on_gallery_select,
-        inputs=[syn_state],
-        outputs=[syn_stress, syn_top3]
-    )
 
 if __name__ == "__main__":
     demo.launch()
 
@@ -8,36 +8,35 @@ from PIL import Image
 
 import torch
 import open_clip
+from datasets import load_dataset
 from sklearn.neighbors import NearestNeighbors
 from diffusers import StableDiffusionImageVariationPipeline
 
 # -----------------------------
 # Config
 # -----------------------------
 DATASET_ID = "tukey/human_face_emotions_roboflow"
+EMB_MODEL_NAME = "ViT-H-14"  # open_clip model name
+EMB_PRETRAINED = "laion2b_s32b_b79k"  # laion/CLIP-ViT-H-14-laion2B-s32B-b79K
 GEN_MODEL_ID = "lambdalabs/sd-image-variations-diffusers"
 
+CACHE_DIR = Path("./cache"); CACHE_DIR.mkdir(parents=True, exist_ok=True)
 EMB_MEMMAP_PATH = CACHE_DIR / "clip_vith14_laion2b.float32.memmap"
 LABELS_MEMMAP_PATH = CACHE_DIR / "labels.U32.memmap"
 KNN_META_PATH = CACHE_DIR / "knn_meta.json"
 
+# Default speed settings (can be overridden by Fast mode at runtime)
+INDEX_MAX = 1000   # cap number of dataset items used for the KNN index (first run only)
+BATCH_SIZE = 32    # batch size for embedding build
+N_SYN = 6          # how many variations to generate before picking top-5
 NUM_SYN_TO_SHOW = 5
+STEPS = 20         # diffusion steps
 GUIDANCE_SCALES = [2.5, 3.0, 3.5, 4.0]
 
 DEVICE = "cuda" if torch.cuda.is_available() else ("mps" if hasattr(torch.backends, "mps") and torch.backends.mps.is_available() else "cpu")
 
 # -----------------------------
+# Canonical label + stress mapping (from your Colab)
 # -----------------------------
 CANON = {"anger","disgust","fear","happy","neutral","sad","surprise","contempt"}
 CANON_MAP = {
 
@@ -57,17 +56,14 @@ STRESS_W = {
 }
 def _bucket(pct: float) -> str:
     return "Low" if pct < 33 else ("Medium" if pct < 66 else "High")
 def stress_from_top3(res: List[Dict]) -> Tuple[float, str]:
     probs = {}
     for r in res:
         lbl = CANON_MAP.get(str(r["emotion"]).lower(), str(r["emotion"]).lower())
+        if lbl not in CANON: continue
         probs[lbl] = probs.get(lbl, 0.0) + float(r["confidence_pct"]) / 100.0
     Z = sum(probs.values()) or 1.0
+    for k in list(probs): probs[k] /= Z
     s01 = sum(probs.get(k, 0.0) * STRESS_W.get(k, 0.0) for k in probs)
     s01 = max(0.0, min(1.0, s01))
     pct = round(s01 * 100.0, 2)
 
@@ -89,108 +85,81 @@ _dataset_for_labels = None
 # -----------------------------
 def _load_openclip():
     global _openclip_model, _preprocess
+    if _openclip_model is not None: return _openclip_model, _preprocess
     model, _, preprocess = open_clip.create_model_and_transforms(
+        model_name=EMB_MODEL_NAME, pretrained=EMB_PRETRAINED, device=DEVICE
     )
     model.eval()
     _openclip_model, _preprocess = model, preprocess
     return _openclip_model, _preprocess
 
+def _fit_knn(X):
+    return NearestNeighbors(metric="cosine", algorithm="brute").fit(X)
+
+def _ensure_knn_index(index_max: int | None = None, batch_size: int | None = None):
+    """Build (first run) or load a memmap + KNN over dataset embeddings."""
     global _nn, _X, _labels_source, _dataset_for_labels
 
+    if _nn is not None and _X is not None:  # already ready
         return
 
+    index_max = index_max or INDEX_MAX
+    batch_size = batch_size or BATCH_SIZE
+
     dataset = load_dataset(DATASET_ID, split="train")
+    if index_max:
+        dataset = dataset.select(range(min(index_max, len(dataset))))
     _dataset_for_labels = dataset
     N = len(dataset)
 
+    # try loading existing cache if it matches N
     if EMB_MEMMAP_PATH.exists() and KNN_META_PATH.exists():
+        meta = json.load(open(KNN_META_PATH))
+        if int(meta.get("N", -1)) == N:
+            D = int(meta["D"])
             X = np.memmap(EMB_MEMMAP_PATH, mode="r", dtype="float32", shape=(N, D))
+            labels = np.memmap(LABELS_MEMMAP_PATH, mode="r", dtype="U32", shape=(N,)) if LABELS_MEMMAP_PATH.exists() else None
+            _X = X; _labels_source = labels; _nn = _fit_knn(X)
             return
 
+    # build embeddings (batched)
     model, preprocess = _load_openclip()
+    labels_mm = np.memmap(LABELS_MEMMAP_PATH, mode="w+", dtype="U32", shape=(N,))
+    X_w = None; D = None
 
     def _label_of(i):
+        try: return str(dataset[i]["qa"][0]["answer"] or "")
+        except Exception: return ""
 
     with torch.no_grad():
+        for start in range(0, N, batch_size):
+            end = min(start + batch_size, N)
+            imgs = [dataset[i]["image"].convert("RGB") for i in range(start, end)]
+            x = torch.stack([preprocess(im) for im in imgs])
+            if DEVICE in ("cuda", "mps"): x = x.to(DEVICE)
+            v = model.encode_image(x).float()
+            v = v / v.norm(dim=-1, keepdim=True)
+
+            if X_w is None:
+                D = v.shape[1]
+                X_w = np.memmap(EMB_MEMMAP_PATH, mode="w+", dtype="float32", shape=(N, D))
+            X_w[start:end] = v.detach().cpu().numpy()
+
+            for i in range(start, end):
+                labels_mm[i] = _label_of(i)
+
+    del X_w; gc.collect()
     json.dump({"N": int(N), "D": int(D)}, open(KNN_META_PATH, "w"))
     X = np.memmap(EMB_MEMMAP_PATH, mode="r", dtype="float32", shape=(N, D))
     labels = np.memmap(LABELS_MEMMAP_PATH, mode="r", dtype="U32", shape=(N,))
+    _X = X; _labels_source = labels; _nn = _fit_knn(X)
 
 def _label_by_idx(i: int):
     global _labels_source, _dataset_for_labels
     if _labels_source is not None:
+        lab = str(_labels_source[i]); return lab if lab else None
+    try: return _dataset_for_labels[i]["qa"][0]["answer"]
+    except Exception: return None
 
 # -----------------------------
 # Embedding + inference utils
 
@@ -199,60 +168,42 @@ def embed_image(img: Image.Image) -> np.ndarray:
     model, preprocess = _load_openclip()
     with torch.no_grad():
         x = preprocess(img.convert("RGB")).unsqueeze(0)
+        if DEVICE in ("cuda", "mps"): x = x.to(DEVICE)
         v = model.encode_image(x).float()
         v = v / v.norm(dim=-1, keepdim=True)
     return v.detach().cpu().numpy().squeeze()
 
 def _top3_emotions_weighted_from_embed(q: np.ndarray,
+                                       start_k: int = 30, step: int = 30,
+                                       method: str = "softmax", tau: float = 0.1):
     _ensure_knn_index()
+    max_k = _X.shape[0]; k = min(start_k, max_k)
     while True:
         dist, idx = _nn.kneighbors(q.reshape(1, -1), n_neighbors=k)
         idx, dist = idx[0], dist[0]
+        sims = np.clip(1.0 - dist, 0.0, None)
         w = np.exp(sims / tau) if method == "softmax" else sims
 
+        scores: Dict[str, float] = {}
         for i, wi in zip(idx, w):
            lab = _label_by_idx(int(i))
+            if lab is None: continue
            lab = CANON_MAP.get(str(lab).lower(), str(lab).lower())
            scores[lab] = scores.get(lab, 0.0) + float(wi)
 
+        if len([k for k in scores if k in CANON]) >= 3 or k == max_k:
            break
        k = min(k + step, max_k)
 
    scores = {k: v for k, v in scores.items() if k in CANON and v > 0}
+    if not scores: return []
    top_items = sorted(scores.items(), key=lambda x: x[1], reverse=True)[:3]
    vals = np.array([v for _, v in top_items], dtype=np.float32)
    pct = (vals / vals.sum()) * 100.0 if vals.sum() > 0 else np.zeros_like(vals)
+    return [{"rank": i+1, "emotion": lab, "confidence_pct": int(round(p))}
+            for i, ((lab, _), p) in enumerate(zip(top_items, pct))]
 
 def analyze_face(image: Image.Image):
     q = embed_image(image)
     top3 = _top3_emotions_weighted_from_embed(q)
     stress_pct, stress_lbl = stress_from_top3(top3)
 
@@ -263,55 +214,31 @@ def analyze_face(image: Image.Image):
 # -----------------------------
 def _get_gen_pipe():
     global _gen_pipe
+    if _gen_pipe is not None: return _gen_pipe
     gen_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
     pipe = StableDiffusionImageVariationPipeline.from_pretrained(
+        GEN_MODEL_ID, revision="v2.0", torch_dtype=gen_dtype
+    ).to(DEVICE)
     _gen_pipe = pipe
     return _gen_pipe
 
+def generate_synthetics(base_image: Image.Image, base_embed: np.ndarray, n_syn: int, steps: int):
     pipe = _get_gen_pipe()
     base_gen = torch.Generator(device="cpu").manual_seed(42)
     records = []
+    for _ in range(n_syn):
         seed = int(torch.randint(0, 2**31 - 1, (1,), generator=base_gen).item())
         gs = random.choice(GUIDANCE_SCALES)
         g = torch.Generator(device="cpu").manual_seed(seed)
+        out = pipe(image=base_image.convert("RGB"),
+                   guidance_scale=gs, num_inference_steps=steps, generator=g)
         img = out.images[0]
         emb = embed_image(img)
+        sim = float(np.dot(emb, base_embed))
         top3_syn = _top3_emotions_weighted_from_embed(emb)
         stress_pct, stress_lbl = stress_from_top3(top3_syn)
+        records.append({"image": img, "similarity": sim, "top3": top3_syn,
+                        "stress": f"{stress_pct}% ({stress_lbl})"})
     records.sort(key=lambda r: r["similarity"], reverse=True)
     return records[:NUM_SYN_TO_SHOW]
 
246
  # Gradio app
247
  # -----------------------------
248
  def _format_top3_for_table(top3: List[Dict]) -> List[List]:
249
+ return [[r["rank"], r["emotion"], r["confidence_pct"]] for r in top3]
 
 
 
250
 
251
+ with gr.Blocks(title="Face Emotions + Stress (Fast)") as demo:
252
  gr.Markdown(
253
+ "## Face Emotion & Stress Analyzer (Fast)\n"
254
+ "- Embeddings: **laion/CLIP-ViT-H-14-laion2B-s32B-b79K** via `open_clip`\n"
255
  "- Synthetic variations: **lambdalabs/sd-image-variations-diffusers**\n"
256
  "- KNN labels from: **tukey/human_face_emotions_roboflow**\n"
257
+ "- First run builds a cached index (capped by `INDEX_MAX`).\n"
258
  )
259
 
260
  with gr.Row():
261
  inp = gr.Image(type="pil", label="Upload a face image", sources=["upload", "webcam"])
262
+ fast_mode = gr.Checkbox(value=True, label="Fast mode (smaller index & fewer synthetics)")
263
 
264
  analyze_btn = gr.Button("Analyze & Generate Synthetics")
265
 
 
@@ -342,59 +268,54 @@ with gr.Blocks(title="Face Emotions + Stress (CLIP ViT-H-14 + SD Variations)") as demo:
         top3_tbl = gr.Dataframe(
             headers=["Rank", "Emotion", "Confidence (%)"],
             datatype=["number", "str", "number"],
+            interactive=False, row_count=(3, "fixed"), col_count=(3, "fixed"),
             label="Top-3 emotions (original image)"
         )
         stress_txt = gr.Label(label="Stress index (original)")
     with gr.Column():
         gal = gr.Gallery(
             label="Top 5 synthetic variations (click one)",
+            columns=[5], height=200, preview=True  # no selectable kwarg
         )
         syn_stress = gr.Label(label="Stress index (selected synthetic)")
         syn_top3 = gr.JSON(label="Top-3 emotions (selected synthetic)")
 
     status = gr.Markdown(visible=False)
+    syn_state = gr.State([])
 
+    def run_pipeline(image: Image.Image, fast: bool):
        try:
+            # Tune runtime knobs
+            idx_max = 600 if fast else INDEX_MAX
+            bs = 32 if fast else BATCH_SIZE
+            n_syn = 4 if fast else N_SYN
+            steps = 16 if fast else STEPS
 
+            # Ensure (or build) index with chosen cap/batch
+            _ensure_knn_index(index_max=idx_max, batch_size=bs)
 
+            # Original image analysis
+            top3, stress, q = analyze_face(image)
+
+            # Synthetics
+            syn = generate_synthetics(image, q, n_syn=n_syn, steps=steps)
            items = [(r["image"], f"sim={r['similarity']:.3f}") for r in syn]
+
+            return _format_top3_for_table(top3), stress, items, syn, gr.update(visible=False), None
        except Exception as e:
            return None, None, None, [], gr.update(visible=True, value=f"**Error:** {e}"), None
 
     analyze_btn.click(
+        run_pipeline, inputs=[inp, fast_mode],
        outputs=[top3_tbl, stress_txt, gal, syn_state, status, syn_top3]
     )
 
     def on_gallery_select(evt: gr.SelectData, syn_records: List[Dict]):
+        if not syn_records or evt is None: return gr.update(value=None), gr.update(value=None)
+        i = int(evt.index); rec = syn_records[i]
        return gr.update(value=rec["stress"]), gr.update(value=rec["top3"])
 
+    gal.select(fn=on_gallery_select, inputs=[syn_state], outputs=[syn_stress, syn_top3])
 
 if __name__ == "__main__":
     demo.launch()
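
For reference, the neighbour-weighted emotion vote that this commit reworks can be exercised in isolation. The snippet below is a minimal sketch, not part of app.py: it assumes unit-normalised embeddings and uses hypothetical stand-ins (`toy_embeddings`, `toy_labels`, `query`) in place of the CLIP features and dataset labels the app builds.

# Illustrative sketch only: softmax-style weighted KNN vote over cosine
# similarities, mirroring the logic of _top3_emotions_weighted_from_embed.
import numpy as np
from sklearn.neighbors import NearestNeighbors

rng = np.random.default_rng(0)
toy_embeddings = rng.normal(size=(100, 8)).astype("float32")
toy_embeddings /= np.linalg.norm(toy_embeddings, axis=1, keepdims=True)  # unit-normalise
toy_labels = rng.choice(["happy", "sad", "anger"], size=100)             # hypothetical labels

nn = NearestNeighbors(metric="cosine", algorithm="brute").fit(toy_embeddings)
query = toy_embeddings[0]

dist, idx = nn.kneighbors(query.reshape(1, -1), n_neighbors=10)
sims = np.clip(1.0 - dist[0], 0.0, None)  # cosine similarity from cosine distance
weights = np.exp(sims / 0.1)              # exponential weighting with tau = 0.1, as in the app

scores = {}
for i, w in zip(idx[0], weights):
    lab = toy_labels[i]
    scores[lab] = scores.get(lab, 0.0) + float(w)

top3 = sorted(scores.items(), key=lambda kv: kv[1], reverse=True)[:3]
print(top3)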