Update app.py
app.py
CHANGED
@@ -1,350 +1,294 @@
-import os, json,
-from
-from typing import List, Tuple, Dict
-
-import gradio as gr
-import numpy as np
 from PIL import Image

 import torch
-import
 from datasets import load_dataset
 from sklearn.neighbors import NearestNeighbors
-from
-
-#
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-# -----------------------------
-# Canonical labels + stress
-# -----------------------------
-CANON = {"anger","disgust","fear","happy","neutral","sad","surprise","contempt"}
-CANON_MAP = {
-    "angry": "anger", "happiness": "happy", "happy": "happy",
-    "sadness": "sad", "sad": "sad", "surprised": "surprise", "surprise": "surprise",
-    "content": "neutral", "calm": "neutral", "neutral": "neutral",
-    "contempt": "contempt", "fear": "fear", "disgust": "disgust", "anger": "anger",
-}
-STRESS_W = {"anger":0.95,"fear":0.90,"sad":0.80,"disgust":0.70,"contempt":0.65,"surprise":0.45,"neutral":0.20,"happy":0.05}
-def _bucket(pct: float) -> str: return "Low" if pct < 33 else ("Medium" if pct < 66 else "High")
-def stress_from_top3(res: List[Dict]) -> Tuple[float, str]:
-    probs = {}
-    for r in res:
-        lbl = CANON_MAP.get(str(r["emotion"]).lower(), str(r["emotion"]).lower())
-        if lbl not in CANON: continue
-        probs[lbl] = probs.get(lbl, 0.0) + float(r["confidence_pct"]) / 100.0
-    Z = sum(probs.values()) or 1.0
-    for k in list(probs): probs[k] /= Z
-    s01 = sum(probs.get(k, 0.0) * STRESS_W.get(k, 0.0) for k in probs)
-    s01 = max(0.0, min(1.0, s01))
-    pct = round(s01 * 100.0, 2)
-    return pct, _bucket(pct)
-
-# -----------------------------
-# Lazy globals
-# -----------------------------
-_openclip_model = None
-_preprocess = None
-_nn = None
-_X = None
-_labels_source = None
-_gen_pipe = None
-_dataset_for_labels = None
-
-# -----------------------------
-# Init / cache helpers
-# -----------------------------
-def _load_openclip():
-    global _openclip_model, _preprocess
-    if _openclip_model is not None: return _openclip_model, _preprocess
-    model, _, preprocess = open_clip.create_model_and_transforms(
-        model_name=EMB_MODEL_NAME, pretrained=EMB_PRETRAINED, device=DEVICE
     )
-
-
-
-
-
-
-
-
-    global _nn, _X, _labels_source, _dataset_for_labels
-
-    if _nn is not None and _X is not None:
-        return
-
-    dataset = load_dataset(DATASET_ID, split="train")
-    if index_max:
-        dataset = dataset.select(range(min(index_max, len(dataset))))
-    _dataset_for_labels = dataset
-    N = len(dataset)
-
-    if EMB_MEMMAP_PATH.exists() and KNN_META_PATH.exists():
-        meta = json.load(open(KNN_META_PATH))
-        if int(meta.get("N", -1)) == N:
-            D = int(meta["D"])
-            X = np.memmap(EMB_MEMMAP_PATH, mode="r", dtype="float32", shape=(N, D))
-            labels = np.memmap(LABELS_MEMMAP_PATH, mode="r", dtype="U32", shape=(N,)) if LABELS_MEMMAP_PATH.exists() else None
-            _X = X; _labels_source = labels; _nn = _fit_knn(X)
-            return
-
-    model, preprocess = _load_openclip()
-    labels_mm = np.memmap(LABELS_MEMMAP_PATH, mode="w+", dtype="U32", shape=(N,))
-    X_w = None; D = None
-
-    with torch.no_grad():
-        for start in range(0, N, batch_size):
-            end = min(start + batch_size, N)
-            imgs = [dataset[i]["image"].convert("RGB") for i in range(start, end)]
-            x = torch.stack([preprocess(im) for im in imgs])
-            if DEVICE in ("cuda", "mps"): x = x.to(DEVICE)
-            v = model.encode_image(x).float()
-            v = v / v.norm(dim=-1, keepdim=True)
-            if X_w is None:
-                D = v.shape[1]
-                X_w = np.memmap(EMB_MEMMAP_PATH, mode="w+", dtype="float32", shape=(N, D))
-            X_w[start:end] = v.detach().cpu().numpy()
-            for i in range(start, end):
-                try: labels_mm[i] = str(dataset[i]["qa"][0]["answer"] or "")
-                except Exception: labels_mm[i] = ""
-            if progress: progress(((end)/N), desc=f"Building index {end}/{N}")
-
-    del X_w; gc.collect()
-    json.dump({"N": int(N), "D": int(D)}, open(KNN_META_PATH, "w"))
-    X = np.memmap(EMB_MEMMAP_PATH, mode="r", dtype="float32", shape=(N, D))
-    labels = np.memmap(LABELS_MEMMAP_PATH, mode="r", dtype="U32", shape=(N,))
-    _X = X; _labels_source = labels; _nn = _fit_knn(X)
-
-def _label_by_idx(i: int):
-    global _labels_source, _dataset_for_labels
-    if _labels_source is not None:
-        lab = str(_labels_source[i]); return lab if lab else None
-    try: return _dataset_for_labels[i]["qa"][0]["answer"]
-    except Exception: return None
-
-# -----------------------------
-# Embedding + inference utils
-# -----------------------------
 def embed_image(img: Image.Image) -> np.ndarray:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-# ----- Nearest neighbors images from dataset -----
-def _get_dataset_image(i: int) -> Image.Image:
-    return _dataset_for_labels[int(i)]["image"].convert("RGB")
-
-def nearest_k_images_from_dataset(q_emb: np.ndarray, k: int = 5):
-    dist, idx = _nn.kneighbors(q_emb.reshape(1, -1), n_neighbors=k)
-    dist, idx = dist[0], idx[0]
-    sims = (1.0 - dist).tolist()
     out = []
-    for
-
-
     return out

-#
-
-#
-
-
-
-    gen_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
-    pipe = StableDiffusionImageVariationPipeline.from_pretrained(
-        GEN_MODEL_ID, revision="v2.0", torch_dtype=gen_dtype
-    ).to(DEVICE)
-    _gen_pipe = pipe
-    return _gen_pipe
-
-def generate_synthetics(base_image: Image.Image, base_embed: np.ndarray, n_syn: int, steps: int, progress: gr.Progress):
-    pipe = _get_gen_pipe()
-    base_gen = torch.Generator(device="cpu").manual_seed(42)
-    records = []
-    for _ in progress.tqdm(range(n_syn), desc="Generating"):
-        seed = int(torch.randint(0, 2**31 - 1, (1,), generator=base_gen).item())
-        gs = random.choice(GUIDANCE_SCALES)
-        g = torch.Generator(device="cpu").manual_seed(seed)
-        out = pipe(image=base_image.convert("RGB"), guidance_scale=gs, num_inference_steps=steps, generator=g)
-        img = out.images[0]
-        emb = embed_image(img)
-        sim = float(np.dot(emb, base_embed))
-        top3_syn = _top3_emotions_weighted_from_embed(emb)
-        stress_pct, stress_lbl = stress_from_top3(top3_syn)
-        records.append({"image": img, "similarity": sim, "top3": top3_syn, "stress": f"{stress_pct}% ({stress_lbl})"})
-    records.sort(key=lambda r: r["similarity"], reverse=True)
-    return records[:NUM_SYN_TO_SHOW]
-
-# -----------------------------
-# UI
-# -----------------------------
-def _format_top3_for_table(top3: List[Dict]) -> List[List]:
-    return [[r["rank"], r["emotion"], r["confidence_pct"]] for r in top3]
-
-with gr.Blocks(title="Face Emotions + Stress (CPU Fast)") as demo:
     gr.Markdown(
-        "
         "- Embeddings: **laion/CLIP-ViT-H-14-laion2B-s32B-b79K** (open_clip)\n"
-        "-
-        "-
     )

     with gr.Row():
-
-
-
-
-        analyze_btn = gr.Button("Analyze (no synthetics)")
-
-    with gr.Row():
-        with gr.Column():
-            top3_tbl = gr.Dataframe(
                 headers=["Rank", "Emotion", "Confidence (%)"],
                 datatype=["number", "str", "number"],
-                interactive=False,
-
             )
-
-
-
-
-
-
             )
-
-            nn_top3 = gr.JSON(label="Top-3 emotions (nearest image)")
-
-        # Optional generator
-        n_syn = gr.Slider(0, 5, value=N_SYN_DEFAULT, step=1, label="How many SD variations to generate")
-        steps = gr.Slider(8, 30, value=STEPS_DEFAULT, step=2, label="Diffusion steps (higher = slower/better)")
-        gen_btn = gr.Button("Generate variations (optional)")
-        gal = gr.Gallery(label="Synthetic variations (click one)", columns=[5], height=220, preview=True)
-        syn_stress = gr.Label(label="Stress (selected synthetic)")
-        syn_top3 = gr.JSON(label="Top-3 emotions (selected synthetic)")
-
-    status = gr.Markdown(visible=False)
-
-    # State
-    syn_state = gr.State([])  # generated variations
-    q_state = gr.State(None)  # embedding of original image
-    img_state = gr.State(None)  # original image
-
-    # ---- Analyze ----
-    def do_analyze(image: Image.Image, cap: int, batch: int, progress=gr.Progress(track_tqdm=True)):
-        try:
-            _ensure_knn_index(index_max=int(cap), batch_size=int(batch), progress=progress)
-            top3, stress, q = analyze_face(image)
-
-            # nearest 5 images from dataset
-            neigh = nearest_k_images_from_dataset(np.array(q, dtype=np.float32), k=5)
-            nn_items = [(im, f"sim={sim:.3f} • idx={idx}") for im, sim, idx in neigh]
-
-            # return: top3, stress, nn gallery, (empty SD gallery), syn_state, q, img, status
-            return (_format_top3_for_table(top3), stress,
-                    nn_items, [], [], q, image, gr.update(visible=False))
-        except Exception as e:
-            return None, None, [], [], [], None, None, gr.update(visible=True, value=f"**Error:** {e}")
-
-    analyze_btn.click(
-        do_analyze,
-        inputs=[inp, idx_cap, bs],
-        outputs=[top3_tbl, stress_txt, nn_gal, gal, syn_state, q_state, img_state, status]
-    )

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    #
-
-
-
-
-
-
-        return
-
-

-
-
-
-
     )

-
-
-
-
-

-

 if __name__ == "__main__":
     demo.launch()

+import os, io, math, json, random, numpy as np
+from typing import List, Tuple, Dict, Any
 from PIL import Image

 import torch
+import torch.nn.functional as F
+
+import gradio as gr
 from datasets import load_dataset
 from sklearn.neighbors import NearestNeighbors
+from transformers import pipeline
+
+# =============== CONFIG ===============
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+
+# Embeddings (your chosen model)
+OPENCLIP_BACKBONE = "ViT-H-14"
+OPENCLIP_PRETRAIN = "laion2B-s32B-b79K"  # laion/CLIP-ViT-H-14-laion2B-s32B-b79K
+INDEX_SIZE = int(os.getenv("INDEX_SIZE", 400))  # how many dataset images to index
+TOPK_NEAREST = 5
+
+# Emotion model (your chosen model)
+EMO_MODEL = "prithivMLmods/Facial-Emotion-Detection-SigLIP2"
+
+# Optional image-variation generator (your chosen model)
+USE_SD_VARIATIONS = True
+SD_MODEL = "lambdalabs/sd-image-variations-diffusers"
+# =====================================
+
+# ---------- Load OpenCLIP for image embeddings ----------
+try:
+    import open_clip
+    _openclip_model, _, _openclip_preprocess = open_clip.create_model_and_transforms(
+        OPENCLIP_BACKBONE, pretrained=OPENCLIP_PRETRAIN
     )
+    _openclip_model = _openclip_model.to(DEVICE).eval()
+except Exception as e:
+    raise RuntimeError(
+        f"Failed to load OpenCLIP ({OPENCLIP_BACKBONE} / {OPENCLIP_PRETRAIN}). "
+        f"Install 'open_clip_torch' and verify CUDA if available. Error: {e}"
+    )
+
+@torch.inference_mode()
 def embed_image(img: Image.Image) -> np.ndarray:
+    img = img.convert("RGB")
+    tens = _openclip_preprocess(img).unsqueeze(0).to(DEVICE)
+    feats = _openclip_model.encode_image(tens)
+    feats = F.normalize(feats, dim=-1).squeeze(0).detach().cpu().numpy().astype(np.float32)
+    return feats  # shape [D]
+
+# ---------- Dataset + index ----------
+DATASET_NAME = "tukey/human_face_emotions_roboflow"
+DATASET_SPLIT = "train"
+
+def _load_images_for_index(n: int) -> List[Image.Image]:
+    ds = load_dataset(DATASET_NAME, split=DATASET_SPLIT)
+    n = min(n, len(ds))
+    imgs = []
+    for i in range(n):
+        # column is usually "image"
+        im = ds[i].get("image")
+        if isinstance(im, Image.Image):
+            imgs.append(im.copy())
+    return imgs
+
+def build_index(imgs: List[Image.Image]) -> Tuple[NearestNeighbors, np.ndarray]:
+    vecs = []
+    for im in imgs:
+        vecs.append(embed_image(im))
+    X = np.stack(vecs, axis=0)
+    nn = NearestNeighbors(metric="cosine", n_neighbors=min(TOPK_NEAREST, len(imgs)))
+    nn.fit(X)
+    return nn, X
+
+print("Loading dataset & building index (first time only)...")
+DATASET_IMAGES: List[Image.Image] = _load_images_for_index(INDEX_SIZE)
+NN_MODEL, EMB_MATRIX = build_index(DATASET_IMAGES)
+print(f"Index ready with {len(DATASET_IMAGES)} images.")
+
+def nearest5(pil_img: Image.Image) -> List[Tuple[Image.Image, str]]:
+    q = embed_image(pil_img).reshape(1, -1)
+    dists, idxs = NN_MODEL.kneighbors(q, n_neighbors=min(5, len(DATASET_IMAGES)))
+    # cosine distance -> similarity = 1 - dist
     out = []
+    for rank, (dist, idx) in enumerate(zip(dists[0], idxs[0]), start=1):
+        sim = 1.0 - float(dist)
+        im = DATASET_IMAGES[int(idx)]
+        caption = f"#{rank} sim={sim:.3f} idx={int(idx)}"
+        out.append((im, caption))
+    return out  # list of (PIL, caption)
+
+# ---------- Emotion & Stress ----------
+EMO_MAP = {
+    "anger": "anger", "angry": "anger",
+    "disgust": "disgust",
+    "fear": "fear",
+    "happy": "happy", "happiness": "happy",
+    "neutral": "neutral", "calm": "neutral",
+    "sad": "sad", "sadness": "sad",
+    "surprise": "surprise",
+    "contempt": "contempt",
+}
+
+# higher == more stressed
+STRESS_WEIGHTS = {
+    "anger": 0.95,
+    "fear": 0.90,
+    "disgust": 0.70,
+    "sad": 0.80,
+    "surprise": 0.55,
+    "neutral": 0.30,
+    "contempt": 0.65,
+    "happy": 0.10,
+}
+
+def _bucket(p: float) -> str:
+    return "Low" if p < 33 else ("Medium" if p < 66 else "High")
+
+emo_pipe = pipeline("image-classification", model=EMO_MODEL, device=0 if DEVICE == "cuda" else -1)
+
+def _pipe_to_probs(res: List[Dict[str, Any]]) -> Dict[str, float]:
+    acc: Dict[str, float] = {}
+    for r in res:
+        label = (r.get("label") or r.get("emotion") or "").lower()
+        if not label:
+            continue
+        key = EMO_MAP.get(label, label)
+        score = float(r.get("score") or r.get("confidence") or r.get("confidence_pct", 0.0) / 100.0)
+        acc[key] = acc.get(key, 0.0) + score
+    Z = sum(acc.values()) or 1.0
+    for k in list(acc.keys()):
+        acc[k] = acc[k] / Z
+    return acc
+
+def emotions_top3(pil_img: Image.Image) -> List[List[Any]]:
+    res = emo_pipe(pil_img)
+    probs = _pipe_to_probs(res)
+    items = sorted(probs.items(), key=lambda kv: kv[1], reverse=True)[:3]
+    table = []
+    for i, (emo, p) in enumerate(items, start=1):
+        table.append([i, emo, round(100.0 * p, 2)])
+    return table  # [[rank, emotion, pct]]
+
+def stress_index(pil_img: Image.Image) -> Tuple[str, float]:
+    res = emo_pipe(pil_img)
+    probs = _pipe_to_probs(res)
+    raw = 0.0
+    for k, v in probs.items():
+        w = STRESS_WEIGHTS.get(k, 0.5)
+        raw += v * w
+    pct = max(0.0, min(100.0, 100.0 * raw))
+    return f"{pct:.1f}% ({_bucket(pct)})", pct
+
+# ---------- Optional: SD image variations (1 image only) ----------
+sd_pipe = None
+if USE_SD_VARIATIONS:
+    try:
+        from diffusers import StableDiffusionImageVariationPipeline
+        sd_pipe = StableDiffusionImageVariationPipeline.from_pretrained(
+            SD_MODEL, torch_dtype=torch.float32
+        )
+        sd_pipe = sd_pipe.to(DEVICE)
+    except Exception as e:
+        print(f"[WARN] Could not load {SD_MODEL}. Generation disabled. Error: {e}")
+        sd_pipe = None
+
+def generate_one_variation(pil_img: Image.Image, steps: int) -> Image.Image:
+    if sd_pipe is None:
+        raise gr.Error("Image-variation pipeline is not available on this Space.")
+    pil_img = pil_img.convert("RGB")
+    out = sd_pipe(pil_img, guidance_scale=3.0, num_inference_steps=int(steps)).images[0]
     return out

+# ===================== GRADIO UI =====================
+CSS = """
+.box { border: 1px solid #e5e7eb; border-radius: 12px; padding: 10px; }
+"""
+
+with gr.Blocks(title="Face Emotion & Stress Analyzer — CPU-friendly", css=CSS, fill_height=False) as demo:
     gr.Markdown(
+        "### Face Emotion & Stress Analyzer — CPU-friendly\n"
         "- Embeddings: **laion/CLIP-ViT-H-14-laion2B-s32B-b79K** (open_clip)\n"
+        "- Emotion model: **prithivMLmods/Facial-Emotion-Detection-SigLIP2**\n"
+        "- Optional SD variations: **lambdalabs/sd-image-variations-diffusers** (1 synthetic only)\n"
+        "- Right column shows nearest 5 images from the dataset (clickable)."
     )

+    # ---- Row 1: upload + (top3_emotion_original | stress_original) ----
     with gr.Row():
+        with gr.Column(scale=2):
+            upload_image = gr.Image(label="Upload face image", type="pil")
+        with gr.Column(scale=1):
+            top3_emotion_original = gr.Dataframe(
                 headers=["Rank", "Emotion", "Confidence (%)"],
                 datatype=["number", "str", "number"],
+                interactive=False, label="Top-3 emotions (original image)",
+                value=[]
             )
+        with gr.Column(scale=1):
+            stress_original = gr.Label(label="Stress index (original)")
+
+    gr.Markdown("#### Analyze (no synthetics)")
+
+    with gr.Row(equal_height=False):
+        # ---------- LEFT COLUMN ----------
+        with gr.Column(scale=1):
+            with gr.Group():
+                gr.Markdown("**gen_variations_control** — generate only **one** synthetic")
+                steps = gr.Slider(8, 40, value=12, step=1, label="Diffusion steps (higher=slower/better)")
+                gen_btn = gr.Button("Generate 1 synthetic", variant="primary")
+            picked_synth = gr.Image(label="Synthetic preview")
+            top3_emotion_synth = gr.Dataframe(
+                headers=["Rank", "Emotion", "Confidence (%)"],
+                datatype=["number", "str", "number"],
+                interactive=False, label="top3_emotion_synth",
+                value=[]
             )
+            stress_synth = gr.Label(label="stress_synth")

+        # ---------- RIGHT COLUMN ----------
+        with gr.Column(scale=1):
+            nearest_images_5 = gr.Gallery(
+                label="nearest_images_5 (1-click on 5 examples)",
+                columns=5, rows=1, height=200, allow_preview=False, show_label=True
+            )
+            tops_emotion_nearest = gr.Dataframe(
+                headers=["Rank", "Emotion", "Confidence (%)"],
+                datatype=["number", "str", "number"],
+                interactive=False, label="tops_emotion_nearest_image",
+                value=[]
+            )
+            stress_nearest = gr.Label(label="stress_nearest_image")
+
+    # --------- Hidden states ---------
+    gallery_images_state = gr.State([])  # store PILs
+    gallery_index_state = gr.State([])  # store dataset indexes (ints)
+
+    # ================= Callbacks =================
+    def on_upload(img: Image.Image):
+        if img is None:
+            return gr.update(), gr.update(value=""), [], [], []
+        # original
+        t3 = emotions_top3(img)
+        s_label, _ = stress_index(img)
+        # nearest gallery
+        gal = nearest5(img)  # list[(PIL, caption)]
+        gal_imgs = [g[0] for g in gal]
+        gal_caps = [g[1] for g in gal]
+        # gr.Gallery accepts [(img, caption), ...]
+        gallery = [(im, cap) for im, cap in zip(gal_imgs, gal_caps)]
+        # return
+        return t3, s_label, gallery, gal_imgs, list(range(len(gal_imgs)))
+
+    upload_image.change(
+        fn=on_upload,
+        inputs=[upload_image],
+        outputs=[top3_emotion_original, stress_original, nearest_images_5, gallery_images_state, gallery_index_state]
+    )

+    def on_gallery_select(evt: gr.SelectData, imgs: List[Image.Image], idxs: List[int]):
+        # evt.index is the clicked cell
+        if imgs is None or not imgs:
+            return [], ""
+        i = int(evt.index) if evt is not None else 0
+        i = max(0, min(i, len(imgs)-1))
+        im = imgs[i]
+        t3 = emotions_top3(im)
+        s_label, _ = stress_index(im)
+        return t3, s_label
+
+    nearest_images_5.select(
+        fn=on_gallery_select,
+        inputs=[gallery_images_state, gallery_index_state],
+        outputs=[tops_emotion_nearest, stress_nearest]
    )

+    def on_generate(img: Image.Image, steps_val: int):
+        if img is None:
+            raise gr.Error("Upload an image first.")
+        if sd_pipe is None:
+            raise gr.Error("Synthetic generation is disabled on this Space.")
+        synth = generate_one_variation(img, steps_val)
+        t3 = emotions_top3(synth)
+        s_label, _ = stress_index(synth)
+        return synth, t3, s_label

+    gen_btn.click(
+        fn=on_generate,
+        inputs=[upload_image, steps],
+        outputs=[picked_synth, top3_emotion_synth, stress_synth]
+    )

 if __name__ == "__main__":
     demo.launch()