Shani13524 committed on
Commit cf0f4c3 · verified · 1 Parent(s): 9dfa0c6

Update app.py

Files changed (1)
  1. app.py +67 -25
app.py CHANGED
@@ -13,7 +13,7 @@ from sklearn.neighbors import NearestNeighbors
 from diffusers import StableDiffusionImageVariationPipeline
 
 # -----------------------------
-# Config (CPU-friendly defaults)
+# Config (CPU-friendly)
 # -----------------------------
 DATASET_ID = "tukey/human_face_emotions_roboflow"
 EMB_MODEL_NAME = "ViT-H-14" # open_clip model name
@@ -25,18 +25,18 @@ EMB_MEMMAP_PATH = CACHE_DIR / "clip_vith14_laion2b.float32.memmap"
 LABELS_MEMMAP_PATH = CACHE_DIR / "labels.U32.memmap"
 KNN_META_PATH = CACHE_DIR / "knn_meta.json"
 
-# **tiny index** + light gen for CPU
-INDEX_MAX_DEFAULT = 80 # small subset → fast
+# tiny index + light generation
+INDEX_MAX_DEFAULT = 80
 BATCH_SIZE_DEFAULT = 32
 N_SYN_DEFAULT = 3
 STEPS_DEFAULT = 12
 GUIDANCE_SCALES = [2.5, 3.0, 3.5, 4.0]
-NUM_SYN_TO_SHOW = 5 # show up to 5 if you generate that many
+NUM_SYN_TO_SHOW = 5
 
 DEVICE = "cuda" if torch.cuda.is_available() else ("mps" if hasattr(torch.backends, "mps") and torch.backends.mps.is_available() else "cpu")
 
 # -----------------------------
-# Canonical label + stress mapping (your logic)
+# Canonical labels + stress
 # -----------------------------
 CANON = {"anger","disgust","fear","happy","neutral","sad","surprise","contempt"}
 CANON_MAP = {
@@ -87,7 +87,7 @@ def _load_openclip():
 def _fit_knn(X): return NearestNeighbors(metric="cosine", algorithm="brute").fit(X)
 
 def _ensure_knn_index(index_max: int, batch_size: int, progress: gr.Progress | None = None):
-    """Build (first run) or load a memmap + KNN over a tiny dataset subset."""
+    """Build (first run) or load a tiny memmap + KNN over a subset of the dataset."""
     global _nn, _X, _labels_source, _dataset_for_labels
 
     if _nn is not None and _X is not None:
@@ -99,7 +99,6 @@ def _ensure_knn_index(index_max: int, batch_size: int, progress: gr.Progress | None = None):
     _dataset_for_labels = dataset
     N = len(dataset)
 
-    # load cache if exists with same N
     if EMB_MEMMAP_PATH.exists() and KNN_META_PATH.exists():
         meta = json.load(open(KNN_META_PATH))
         if int(meta.get("N", -1)) == N:
@@ -109,12 +108,10 @@ def _ensure_knn_index(index_max: int, batch_size: int, progress: gr.Progress | None = None):
             _X = X; _labels_source = labels; _nn = _fit_knn(X)
             return
 
-    # build tiny embedding index (batched)
     model, preprocess = _load_openclip()
     labels_mm = np.memmap(LABELS_MEMMAP_PATH, mode="w+", dtype="U32", shape=(N,))
     X_w = None; D = None
 
-    step = 0
     with torch.no_grad():
         for start in range(0, N, batch_size):
             end = min(start + batch_size, N)
@@ -130,7 +127,6 @@ def _ensure_knn_index(index_max: int, batch_size: int, progress: gr.Progress | None = None):
             for i in range(start, end):
                 try: labels_mm[i] = str(dataset[i]["qa"][0]["answer"] or "")
                 except Exception: labels_mm[i] = ""
-            step += 1
             if progress: progress(((end)/N), desc=f"Building index {end}/{N}")
 
     del X_w; gc.collect()
@@ -190,8 +186,22 @@ def analyze_face(image: Image.Image):
     stress_pct, stress_lbl = stress_from_top3(top3)
     return top3, f"{stress_pct}% ({stress_lbl})", q
 
+# ----- Nearest neighbors images from dataset -----
+def _get_dataset_image(i: int) -> Image.Image:
+    return _dataset_for_labels[int(i)]["image"].convert("RGB")
+
+def nearest_k_images_from_dataset(q_emb: np.ndarray, k: int = 5):
+    dist, idx = _nn.kneighbors(q_emb.reshape(1, -1), n_neighbors=k)
+    dist, idx = dist[0], idx[0]
+    sims = (1.0 - dist).tolist()
+    out = []
+    for i, s in zip(idx, sims):
+        img = _get_dataset_image(int(i))
+        out.append((img, float(s), int(i)))
+    return out
+
 # -----------------------------
-# Generator (optional, on click)
+# Generator (optional)
 # -----------------------------
 def _get_gen_pipe():
     global _gen_pipe
@@ -207,7 +217,7 @@ def generate_synthetics(base_image: Image.Image, base_embed: np.ndarray, n_syn:
     pipe = _get_gen_pipe()
     base_gen = torch.Generator(device="cpu").manual_seed(42)
     records = []
-    for i in progress.tqdm(range(n_syn), desc="Generating"):
+    for _ in progress.tqdm(range(n_syn), desc="Generating"):
         seed = int(torch.randint(0, 2**31 - 1, (1,), generator=base_gen).item())
         gs = random.choice(GUIDANCE_SCALES)
         g = torch.Generator(device="cpu").manual_seed(seed)
@@ -222,7 +232,7 @@ def generate_synthetics(base_image: Image.Image, base_embed: np.ndarray, n_syn:
     return records[:NUM_SYN_TO_SHOW]
 
 # -----------------------------
-# Gradio app (two-step: Analyze → (optional) Generate)
+# UI
 # -----------------------------
 def _format_top3_for_table(top3: List[Dict]) -> List[List]:
     return [[r["rank"], r["emotion"], r["confidence_pct"]] for r in top3]
@@ -231,8 +241,8 @@ with gr.Blocks(title="Face Emotions + Stress (CPU Fast)") as demo:
     gr.Markdown(
         "## Face Emotion & Stress Analyzer — CPU-friendly\n"
         "- Embeddings: **laion/CLIP-ViT-H-14-laion2B-s32B-b79K** (open_clip)\n"
-        "- Synthetic variations (optional): **lambdalabs/sd-image-variations-diffusers**\n"
-        "- Uses a tiny cached index for speed on free CPU.\n"
+        "- Optional SD variations: **lambdalabs/sd-image-variations-diffusers**\n"
+        "- Also shows **nearest 5 images from the dataset** for 1-click results.\n"
    )
 
     with gr.Row():
@@ -252,35 +262,66 @@ with gr.Blocks(title="Face Emotions + Stress (CPU Fast)") as demo:
            )
            stress_txt = gr.Label(label="Stress index (original)")
        with gr.Column():
-            # generation controls (optional)
-            n_syn = gr.Slider(0, 5, value=N_SYN_DEFAULT, step=1, label="How many variations to generate")
+            # Nearest 5 from dataset (one-click examples)
+            nn_gal = gr.Gallery(
+                label="Nearest 5 from dataset (click one)",
+                columns=[5], height=220, preview=True
+            )
+            nn_stress = gr.Label(label="Stress (nearest image)")
+            nn_top3 = gr.JSON(label="Top-3 emotions (nearest image)")
+
+            # Optional generator
+            n_syn = gr.Slider(0, 5, value=N_SYN_DEFAULT, step=1, label="How many SD variations to generate")
            steps = gr.Slider(8, 30, value=STEPS_DEFAULT, step=2, label="Diffusion steps (higher = slower/better)")
            gen_btn = gr.Button("Generate variations (optional)")
            gal = gr.Gallery(label="Synthetic variations (click one)", columns=[5], height=220, preview=True)
-            syn_stress = gr.Label(label="Stress index (selected synthetic)")
+            syn_stress = gr.Label(label="Stress (selected synthetic)")
            syn_top3 = gr.JSON(label="Top-3 emotions (selected synthetic)")
 
    status = gr.Markdown(visible=False)
 
-    # state we pass between steps
-    syn_state = gr.State([]) # list of generated records
-    q_state = gr.State(None) # original embedding
-    img_state = gr.State(None) # original image (for gen step)
+    # State
+    syn_state = gr.State([]) # generated variations
+    q_state = gr.State(None) # embedding of original image
+    img_state = gr.State(None) # original image
 
+    # ---- Analyze ----
    def do_analyze(image: Image.Image, cap: int, batch: int, progress=gr.Progress(track_tqdm=True)):
        try:
            _ensure_knn_index(index_max=int(cap), batch_size=int(batch), progress=progress)
            top3, stress, q = analyze_face(image)
-            return _format_top3_for_table(top3), stress, [], [], q, image, gr.update(visible=False)
+
+            # nearest 5 images from dataset
+            neigh = nearest_k_images_from_dataset(np.array(q, dtype=np.float32), k=5)
+            nn_items = [(im, f"sim={sim:.3f} • idx={idx}") for im, sim, idx in neigh]
+
+            # return: top3, stress, nn gallery, (empty SD gallery), syn_state, q, img, status
+            return (_format_top3_for_table(top3), stress,
+                    nn_items, [], [], q, image, gr.update(visible=False))
        except Exception as e:
-            return None, None, [], [], None, None, gr.update(visible=True, value=f"**Error:** {e}")
+            return None, None, [], [], [], None, None, gr.update(visible=True, value=f"**Error:** {e}")
 
    analyze_btn.click(
        do_analyze,
        inputs=[inp, idx_cap, bs],
-        outputs=[top3_tbl, stress_txt, gal, syn_state, q_state, img_state, status]
+        outputs=[top3_tbl, stress_txt, nn_gal, gal, syn_state, q_state, img_state, status]
    )
 
+    # ---- One-click on a nearest image ----
+    def on_nn_select(evt: gr.SelectData, q):
+        if q is None:
+            return gr.update(value="Analyze first"), None
+        neigh = nearest_k_images_from_dataset(np.array(q, dtype=np.float32), k=5)
+        i = max(0, min(int(evt.index), len(neigh)-1))
+        img, _, _ = neigh[i]
+        emb = embed_image(img)
+        top3 = _top3_emotions_weighted_from_embed(emb)
+        stress_pct, stress_lbl = stress_from_top3(top3)
+        return f"{stress_pct}% ({stress_lbl})", top3
+
+    nn_gal.select(fn=on_nn_select, inputs=[q_state], outputs=[nn_stress, nn_top3])
+
+    # ---- Optional: generate SD variations ----
    def do_generate(n: int, s: int, q, img, progress=gr.Progress()):
        if q is None or img is None:
            return [], [], gr.update(visible=True, value="**Error:** Analyze first."), None
@@ -297,6 +338,7 @@ with gr.Blocks(title="Face Emotions + Stress (CPU Fast)") as demo:
        outputs=[gal, syn_state, status, syn_top3]
    )
 
+    # select from generated synthetics
    def on_gallery_select(evt: gr.SelectData, syn_records: List[Dict]):
        if not syn_records or evt is None: return gr.update(value=None), gr.update(value=None)
        i = int(evt.index); rec = syn_records[i]
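Note on the `sim=` captions in the new nearest-neighbor gallery: with `metric="cosine"`, scikit-learn's `kneighbors` returns cosine distance, i.e. `1 - cosine_similarity`, which is why `nearest_k_images_from_dataset` recovers similarity as `1.0 - dist`. A minimal self-contained sketch with random stand-in vectors (not the app's CLIP embeddings):

```python
# Sketch: sklearn's cosine "distance" vs. the similarity shown in the captions.
# X is random stand-in data, not CLIP output.
import numpy as np
from sklearn.neighbors import NearestNeighbors

X = np.random.default_rng(0).normal(size=(80, 16)).astype(np.float32)
nn = NearestNeighbors(metric="cosine", algorithm="brute").fit(X)

dist, idx = nn.kneighbors(X[0].reshape(1, -1), n_neighbors=5)
sims = 1.0 - dist[0]  # cosine distance = 1 - cosine similarity
print(list(zip(idx[0].tolist(), sims.round(3).tolist())))  # query itself comes back first, sim ~ 1.0
```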
 
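The cache fast-path in `_ensure_knn_index` only reuses the memmap when the `N` recorded in `knn_meta.json` matches the current dataset length. A minimal sketch of that memmap-plus-meta pattern, with hypothetical paths and random stand-in data (the real code writes CLIP embeddings, and presumably also persists the embedding dim `D` so the array can be reshaped on reload):

```python
# Sketch of the embedding-cache pattern: write a float32 memmap once, record
# its shape in a small JSON, and reload on later runs instead of re-embedding.
# Paths and shapes here are illustrative stand-ins.
import json
import numpy as np
from pathlib import Path

cache = Path("demo_cache"); cache.mkdir(exist_ok=True)
emb_path, meta_path = cache / "emb.float32.memmap", cache / "meta.json"

N, D = 80, 16  # stand-ins for dataset size and embedding dim
if not (emb_path.exists() and meta_path.exists()):
    X_w = np.memmap(emb_path, mode="w+", dtype=np.float32, shape=(N, D))
    X_w[:] = np.random.default_rng(0).normal(size=(N, D))  # real app: CLIP embeddings
    X_w.flush(); del X_w  # make sure bytes hit disk, then drop the writer
    json.dump({"N": N, "D": D}, open(meta_path, "w"))

meta = json.load(open(meta_path))
if int(meta.get("N", -1)) == N:  # the same consistency check the app uses
    X = np.memmap(emb_path, mode="r", dtype=np.float32, shape=(meta["N"], meta["D"]))
```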