import os
import numpy as np
from typing import List, Tuple, Dict, Any
from PIL import Image

import torch
import torch.nn.functional as F

import gradio as gr
from datasets import load_dataset
from sklearn.neighbors import NearestNeighbors

# =============== CONFIG ===============
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Embeddings backbone
OPENCLIP_BACKBONE = "ViT-H-14"
OPENCLIP_PRETRAIN = "laion2B-s32B-b79K"  # laion/CLIP-ViT-H-14-laion2B-s32B-b79K

# Dataset (THIS IS YOUR "MODEL" SOURCE NOW)
DATASET_NAME = "tukey/human_face_emotions_roboflow"
DATASET_SPLIT = "train"

INDEX_SIZE = int(os.getenv("INDEX_SIZE", 400))     # how many dataset examples to index
TOPK_NEAREST = 5                                   # neighbors shown in the gallery
KNN_K_FOR_CLASS = 25                               # neighbors used to weight emotion votes

# Optional SD variations
USE_SD_VARIATIONS = True
SD_MODEL = "lambdalabs/sd-image-variations-diffusers"
# =====================================

# ---------- Load OpenCLIP for image embeddings ----------
try:
    import open_clip
    _openclip_model, _, _openclip_preprocess = open_clip.create_model_and_transforms(
        OPENCLIP_BACKBONE, pretrained=OPENCLIP_PRETRAIN
    )
    _openclip_model = _openclip_model.to(DEVICE).eval()
except Exception as e:
    raise RuntimeError(
        f"Failed to load OpenCLIP ({OPENCLIP_BACKBONE} / {OPENCLIP_PRETRAIN}). "
        f"Install 'open_clip_torch' and verify CUDA if available. Error: {e}"
    )

@torch.inference_mode()
def embed_image(img: Image.Image) -> np.ndarray:
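    """Return an L2-normalized OpenCLIP image embedding as a float32 vector of shape [D]."""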
    img = img.convert("RGB")
    tens = _openclip_preprocess(img).unsqueeze(0).to(DEVICE)
    feats = _openclip_model.encode_image(tens)
    feats = F.normalize(feats, dim=-1).squeeze(0).detach().cpu().numpy().astype(np.float32)
    return feats  # shape [D]

# ---------- Labels & stress mapping ----------
EMO_MAP = {
    "anger": "anger", "angry": "anger",
    "disgust": "disgust",
    "fear": "fear",
    "happy": "happy", "happiness": "happy",
    "neutral": "neutral", "calm": "neutral",
    "sad": "sad", "sadness": "sad",
    "surprise": "surprise",
    "contempt": "contempt",
}
ALLOWED = set(EMO_MAP.values())  # strict whitelist of emotion labels

STRESS_WEIGHTS = {
    "anger": 0.95, "fear": 0.90, "disgust": 0.70, "sad": 0.80,
    "surprise": 0.55, "neutral": 0.30, "contempt": 0.65, "happy": 0.10,
}
def _bucket(p: float) -> str:
    return "Low" if p < 33 else ("Medium" if p < 66 else "High")

# ---------- Load dataset & build index ----------
def _extract_label(rec: Dict[str, Any]) -> str:
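    """Extract a raw (lowercased) emotion label from a dataset record, or return "" if none is found."""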
    # Handle the label fields that may appear in this dataset's records
    if "label" in rec and rec["label"]:
        raw = rec["label"]
        if isinstance(raw, (list, tuple)): raw = raw[0]
        return str(raw).strip().lower()
    if "labels" in rec and rec["labels"]:
        raw = rec["labels"][0]
        return str(raw).strip().lower()
    if "qa" in rec and rec["qa"] and isinstance(rec["qa"], list):
        qa0 = rec["qa"][0]
        if qa0 and "answer" in qa0:
            return str(qa0["answer"]).strip().lower()
    return ""

def _map_allowed(lbl: str) -> str:
    # Map to a canonical emotion name and filter out unrecognized labels
    mapped = EMO_MAP.get(lbl, lbl)
    return mapped if mapped in ALLOWED else ""  # "" => drop

def _load_images_labels_for_index(n: int) -> Tuple[List[Image.Image], List[str]]:
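    """Load up to n records from the dataset, keeping only PIL images whose label maps into ALLOWED."""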
    ds = load_dataset(DATASET_NAME, split=DATASET_SPLIT)
    imgs, labels = [], []
    n = min(n, len(ds))
    for i in range(n):
        rec = ds[i]
        im = rec.get("image")
        if not isinstance(im, Image.Image):
            continue
        raw_lbl = _extract_label(rec)
        mapped = _map_allowed(raw_lbl)
        if not mapped:
            continue  # drop empty or disallowed labels
        imgs.append(im.copy())
        labels.append(mapped)
    return imgs, labels

def build_index(imgs: List[Image.Image]) -> Tuple[NearestNeighbors, np.ndarray]:
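    """Embed every index image and fit a cosine-distance NearestNeighbors index over the embeddings."""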
    vecs = [embed_image(im) for im in imgs]
    X = np.stack(vecs, axis=0)
    nn = NearestNeighbors(metric="cosine", n_neighbors=min(max(TOPK_NEAREST, KNN_K_FOR_CLASS), len(imgs)))
    nn.fit(X)
    return nn, X

print("Loading dataset & building index (first time only)...")
DATASET_IMAGES, DATASET_LABELS = _load_images_labels_for_index(INDEX_SIZE)
if len(DATASET_IMAGES) == 0:
    raise RuntimeError("No images with allowed labels were loaded from the dataset.")
NN_MODEL, EMB_MATRIX = build_index(DATASET_IMAGES)
print(f"Index ready with {len(DATASET_IMAGES)} images (labels={sorted(set(DATASET_LABELS))}).")

# ---------- Nearest & KNN-based classification ----------
def nearest5(pil_img: Image.Image) -> List[Tuple[Image.Image, str]]:
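    """Return the 5 dataset images most similar to the query, each with a caption
    showing rank, cosine similarity, and dataset index."""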
    q = embed_image(pil_img).reshape(1, -1)
    n = min(5, len(DATASET_IMAGES))
    dists, idxs = NN_MODEL.kneighbors(q, n_neighbors=n)
    out = []
    for rank, (dist, idx) in enumerate(zip(dists[0], idxs[0]), start=1):
        sim = max(0.0, 1.0 - float(dist))  # cosine distance -> similarity
        im = DATASET_IMAGES[int(idx)]
        caption = f"#{rank}  sim={sim:.3f}  idx={int(idx)}"
        out.append((im, caption))
    return out

def knn_probs(pil_img: Image.Image, k: int = KNN_K_FOR_CLASS) -> Dict[str, float]:
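    """Similarity-weighted KNN voting: each of the k nearest neighbors votes for its
    label with weight equal to its cosine similarity, and the votes are normalized
    into a probability distribution over the allowed emotions."""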
    q = embed_image(pil_img).reshape(1, -1)
    k = min(k, len(DATASET_IMAGES))
    dists, idxs = NN_MODEL.kneighbors(q, n_neighbors=k)
    sims = 1.0 - dists[0]  # higher is better
    sims = np.maximum(sims, 0.0)
    votes: Dict[str, float] = {}
    for sim, idx in zip(sims, idxs[0]):
        lbl = DATASET_LABELS[int(idx)]
        if lbl in ALLOWED:
            votes[lbl] = votes.get(lbl, 0.0) + float(sim)
    Z = sum(votes.values()) or 1.0
    return {lbl: v / Z for lbl, v in votes.items()}

def emotions_top3(pil_img: Image.Image) -> List[List[Any]]:
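    """Return the top-3 emotions as [rank, emotion, confidence %] rows for the Dataframe."""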
    probs = knn_probs(pil_img)
    items = sorted(probs.items(), key=lambda kv: kv[1], reverse=True)[:3]
    table = []
    for i, (emo, p) in enumerate(items, start=1):
        table.append([i, emo, round(100.0 * p, 2)])
    # Pad to 3 rows if fewer than 3 emotions received votes
    seen = {r[1] for r in table}
    for fill in sorted(ALLOWED - seen):
        if len(table) >= 3:
            break
        table.append([len(table) + 1, fill, 0.0])
    return table

def stress_index(pil_img: Image.Image) -> Tuple[str, float]:
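    """Compute a stress percentage as 100 * sum(p(emotion) * STRESS_WEIGHTS[emotion]),
    clipped to [0, 100], and return a labeled string plus the raw percentage."""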
    probs = knn_probs(pil_img)
    raw = sum(probs.get(k, 0.0) * STRESS_WEIGHTS.get(k, 0.5) for k in ALLOWED)
    pct = max(0.0, min(100.0, 100.0 * raw))
    return f"{pct:.1f}%  ({_bucket(pct)})", pct

# ---------- Optional: SD image variations ----------
sd_pipe = None
if USE_SD_VARIATIONS:
    try:
        from diffusers import StableDiffusionImageVariationPipeline
        sd_pipe = StableDiffusionImageVariationPipeline.from_pretrained(
            SD_MODEL, torch_dtype=torch.float32
        )
        sd_pipe = sd_pipe.to(DEVICE)
    except Exception as e:
        print(f"[WARN] Could not load {SD_MODEL}. Generation disabled. Error: {e}")
        sd_pipe = None

def generate_one_variation(pil_img: Image.Image, steps: int) -> Image.Image:
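    """Generate a single Stable Diffusion image variation of the input, raising gr.Error if the pipeline is unavailable."""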
    if sd_pipe is None:
        raise gr.Error("Image-variation pipeline is not available on this Space.")
    pil_img = pil_img.convert("RGB")
    out = sd_pipe(pil_img, guidance_scale=3.0, num_inference_steps=int(steps)).images[0]
    return out

# =====================  GRADIO UI  =====================
CSS = ".box { border: 1px solid #e5e7eb; border-radius: 12px; padding: 10px; }"

with gr.Blocks(title="Face Emotion & Stress Analyzer — KNN over tukey dataset", css=CSS, fill_height=False) as demo:
    gr.Markdown(
        "### Face Emotion & Stress Analyzer — **KNN over `tukey/human_face_emotions_roboflow`**\n"
        "- Embeddings: **laion/CLIP-ViT-H-14-laion2B-s32B-b79K** (open_clip)\n"
        "- Emotion model: **KNN using labels from `tukey/human_face_emotions_roboflow`**\n"
        "- Optional SD variations: **lambdalabs/sd-image-variations-diffusers** (1 synthetic only)\n"
        "- Right column shows nearest 5 images from the dataset (clickable)."
    )

    # ---- Row 1: upload + (top3_emotion_original | stress_original) ----
    with gr.Row():
        with gr.Column(scale=2):
            upload_image = gr.Image(label="Upload face image", type="pil")
        with gr.Column(scale=1):
            top3_emotion_original = gr.Dataframe(
                headers=["Rank", "Emotion", "Confidence (%)"],
                datatype=["number", "str", "number"],
                interactive=False, label="Top-3 emotions (original image)",
                value=[]
            )
        with gr.Column(scale=1):
            stress_original = gr.Label(label="Stress index (original)")

    gr.Markdown("#### Analyze (no synthetics)")

    with gr.Row(equal_height=False):
        # ---------- LEFT COLUMN ----------
        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown("**gen_variations_control** — generate only **one** synthetic")
                steps = gr.Slider(8, 40, value=12, step=1, label="Diffusion steps (higher = slower but better quality)")
                gen_btn = gr.Button("Generate 1 synthetic", variant="primary")
                picked_synth = gr.Image(label="Synthetic preview")
            top3_emotion_synth = gr.Dataframe(
                headers=["Rank", "Emotion", "Confidence (%)"],
                datatype=["number", "str", "number"],
                interactive=False, label="top3_emotion_synth",
                value=[]
            )
            stress_synth = gr.Label(label="stress_synth")

        # ---------- RIGHT COLUMN ----------
        with gr.Column(scale=1):
            nearest_images_5 = gr.Gallery(
                label="nearest_images_5 (1-click on 5 examples)",
                columns=5, rows=1, height=200, allow_preview=False, show_label=True
            )
            tops_emotion_nearest = gr.Dataframe(
                headers=["Rank", "Emotion", "Confidence (%)"],
                datatype=["number", "str", "number"],
                interactive=False, label="tops_emotion_nearest_image",
                value=[]
            )
            stress_nearest = gr.Label(label="stress_nearest_image")

    # --------- Hidden states ---------
    gallery_images_state = gr.State([])   # store PILs
    gallery_index_state = gr.State([])    # store gallery positions (ints)

    # ================= Callbacks =================
    def on_upload(img: Image.Image):
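        # Analyze the uploaded image and populate the nearest-neighbor gallery and hidden states.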
        if img is None:
            return gr.update(), gr.update(value=""), [], [], []
        # original
        t3 = emotions_top3(img)
        s_label, _ = stress_index(img)
        # nearest gallery
        gal = nearest5(img)  # list[(PIL, caption)]
        gal_imgs = [g[0] for g in gal]
        gal_caps = [g[1] for g in gal]
        gallery = [(im, cap) for im, cap in zip(gal_imgs, gal_caps)]
        return t3, s_label, gallery, gal_imgs, list(range(len(gal_imgs)))

    upload_image.change(
        fn=on_upload,
        inputs=[upload_image],
        outputs=[top3_emotion_original, stress_original, nearest_images_5, gallery_images_state, gallery_index_state]
    )

    def on_gallery_select(evt: gr.SelectData, imgs: List[Image.Image], idxs: List[int]):
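        # Re-run emotion/stress analysis on whichever gallery image was clicked.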
        if imgs is None or not imgs:
            return [], ""
        i = int(evt.index) if evt is not None else 0
        i = max(0, min(i, len(imgs)-1))
        im = imgs[i]
        t3 = emotions_top3(im)
        s_label, _ = stress_index(im)
        return t3, s_label

    nearest_images_5.select(
        fn=on_gallery_select,
        inputs=[gallery_images_state, gallery_index_state],
        outputs=[tops_emotion_nearest, stress_nearest]
    )

    def on_generate(img: Image.Image, steps_val: int):
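        # Generate one synthetic variation and analyze it with the same KNN pipeline.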
        if img is None:
            raise gr.Error("Upload an image first.")
        if sd_pipe is None:
            raise gr.Error("Synthetic generation is disabled on this Space.")
        synth = generate_one_variation(img, steps_val)
        t3 = emotions_top3(synth)
        s_label, _ = stress_index(synth)
        return synth, t3, s_label

    gen_btn.click(
        fn=on_generate,
        inputs=[upload_image, steps],
        outputs=[picked_synth, top3_emotion_synth, stress_synth]
    )

if __name__ == "__main__":
    demo.launch()