File size: 11,251 Bytes
af8184f
 
 
 
 
 
 
 
 
0bce39f
1af19a0
af8184f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0bce39f
af8184f
 
 
 
f539f23
af8184f
f539f23
af8184f
 
 
 
 
 
 
 
727ef82
 
0bce39f
727ef82
f6f4cf2
df7a54d
0bce39f
3413b87
727ef82
0bce39f
3413b87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f6f4cf2
af8184f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0bce39f
af8184f
 
 
 
 
 
 
1af19a0
cc9e884
 
af8184f
727ef82
af8184f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0de29c6
af8184f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5bbd57a
f8836c0
1af19a0
f8836c0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
import os
import base64
import json
import requests
import torch
import numpy as np
import cv2
from PIL import Image, ImageFilter
from scipy.ndimage import binary_dilation
from omegaconf import OmegaConf

# -----------------------------
# 1) הגדרת המפתח API של Gemini כפרמטר
# -----------------------------

SYSTEM_INST = """\
You are given an image. You must return information about the main character in the image.
Do not write anything else beyond this!

**Guidelines for identifying a character in the image:**
1. **Male:**
   - Infant (0–2) → "baby boy"
   - Toddler (2–5) → "toddler boy"
   - Child (6–11) → "boy"
   - Teenager (12–17) → "teen boy"
   - Young adul (18–35) → "young man"
   - adul (36–59) → "man"
   - Elderly (60+) → "elderly man"

2. **Female:**
   - Infant (0–2) → "baby girl"
   - Toddler (2–5) → "toddler girl"
   - Child (6–11) → "girl"
   - Teenager (12–17) → "teen girl"
   - Young adul (18–35) → "young woman"
   - adul (36–59) → "woman"
   - Elderly (60+) → "elderly woman"

3. **Unclear identification:**
   - Ambiguous character → "unidentified"
   - Ambiguous infant/toddler → "baby" or "toddler"

4. **No character in the image:**
   - Respond: "no person"

5. **Multiple characters:**
   - Identify the most central or prominent character.

Notes:
- If data is insufficient to classify → "insufficient data".
"""

conversation = []  # נשמור כאן את השיחה הנוכחית

female_keywords = {
    "baby girl", "toddler girl", "girl",
    "teen girl", "young woman", "woman",
    "elderly woman"
}

def is_female_from_text(gemini_text: str) -> bool:
    """בודק האם התשובה מ-Gemini מצביעה על אישה לפי מילות המפתח שהוגדרו."""
    return gemini_text.lower().strip() in female_keywords


def encode_image_to_base64(image: Image.Image) -> str:
    import io
    buffer = io.BytesIO()
    image.save(buffer, format='JPEG')
    encoded_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
    return encoded_str

def add_user_text(message: str):
    conversation.append({
        "role": "user",
        "parts": [
            {"text": message}
        ]
    })


def add_user_image_from_pil(image: Image.Image, mime_type: str = "image/jpeg"):
    encoded_str = encode_image_to_base64(image)
    conversation.append({
        "role": "user",
        "parts": [
            {
                "inline_data": {
                    "mime_type": mime_type,
                    "data": encoded_str
                }
            }
        ]
    })


def send_and_receive(api_key: str) -> str:
    url = "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent"
    params = {"key": api_key}
    headers = {"Content-Type": "application/json"}

    payload = {
        "systemInstruction": {
            "role": "system",
            "parts": [
                {"text": SYSTEM_INST}
            ]
        },
        "contents": conversation
    }

    response = requests.post(url, params=params, headers=headers, json=payload)
    if response.status_code != 200:
        print(f"[Gemini] שגיאה בסטטוס קוד: {response.status_code}")
        return "NO_ANSWER"

    resp_json = response.json()
    candidates = resp_json.get("candidates", [])
    if not candidates:
        print("[Gemini] לא התקבלה תשובה.")
        return "NO_ANSWER"

    model_content = candidates[0].get("content", {})
    model_parts = model_content.get("parts", [])
    if not model_parts:
        print("[Gemini] לא נמצא תוכן בתשובת המודל.")
        return "NO_ANSWER"

    model_text = model_parts[0].get("text", "").strip()
    conversation.append({
        "role": "model",
        "parts": [
            {"text": model_text}
        ]
    })
    return model_text


# -----------------------------
# 3) טעינת מודל YOLO
# -----------------------------
from ultralytics import YOLO
YOLO_MODEL_PATH = '../../models/yolo11m.pt'

try:
    yolo_model = YOLO(YOLO_MODEL_PATH)
    yolo_model.to("cpu")
    print("[YOLO] מודל YOLO נטען בהצלחה.")
except Exception as e:
    print(f"[YOLO] לא מצליח לטעון את המודל בנתיב: {YOLO_MODEL_PATH}. שגיאה: {e}")
    yolo_model = None

TARGET_CLASS = "person"
CONF_THRESHOLD = 0.2

# -----------------------------
# 4) הכנה ל-SAM2
# -----------------------------
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

# נתיבים יחסיים ל-Space של Hugging Face
SAM2_CHECKPOINT = "checkpoints/sam2.1_hiera_tiny.pt"
MODEL_CFG = "configs/sam2.1/sam2.1_hiera_t.yaml"

sam2_predictor = None  # אתחול כ-None
device = "cuda" if torch.cuda.is_available() else "cpu"

def load_sam2_model():
    """טוען את מודל SAM2 באופן גלובלי."""
    global sam2_predictor
    try:
        # טעינת המודל
        sam2_model = build_sam2(MODEL_CFG, SAM2_CHECKPOINT, device=device)
        sam2_predictor = SAM2ImagePredictor(sam2_model)
        print("[SAM2] מודל SAM2 נטען בהצלחה.")

    except FileNotFoundError as e:
        print(f"[ERROR] קובץ SAM2 לא נמצא: {e}")
        print(f"  - ודא שקובץ המודל '{SAM2_CHECKPOINT}' וקובץ הקונפיג '{MODEL_CFG}' קיימים בנתיבים הנכונים בתוך ה-Space שלך.")

    except Exception as e:
        print(f"[ERROR] שגיאה כללית בטעינת SAM2: {e}")
        print(f"  - סוג השגיאה: {type(e).__name__}")
        print(f"  - הודעת השגיאה: {e}")
        import traceback
        print(f"  - Traceback:")
        traceback.print_exc()
        print(f"  - בדוק את התאימות בין גרסאות הספריות (torch, torchvision, sam2) ואת תקינות קובץ המודל.")

# טעינת המודל בעת טעינת המודול
load_sam2_model()


# -----------------------------
# 5) פונקציית טשטוש
# -----------------------------
def blur_regions_with_mask(
    image: Image.Image,
    mask: np.ndarray,
    blur_radius=20,
    pixel_size=20,
    expansion_pixels=1
):
    processed_image = image.copy()
    img_np = np.array(processed_image)

    structure = np.ones((expansion_pixels, expansion_pixels), dtype=bool)
    expanded_mask = binary_dilation(mask, structure=structure)

    blurred_whole = processed_image.filter(ImageFilter.GaussianBlur(radius=blur_radius))
    blurred_whole_np = np.array(blurred_whole)

    ys, xs = np.where(expanded_mask)
    if len(xs) == 0 or len(ys) == 0:
        return processed_image

    x_min, x_max = xs.min(), xs.max()
    y_min, y_max = ys.min(), ys.max()

    region = blurred_whole_np[y_min:y_max, x_min:x_max]

    from PIL import Image as PILImage
    small = PILImage.fromarray(region).resize(
        ((x_max - x_min) // pixel_size, (y_max - y_min) // pixel_size),
        resample=Image.BILINEAR
    )
    pixelated = small.resize((x_max - x_min, y_max - y_min), PILImage.NEAREST)
    pixelated_np = np.array(pixelated)

    combined = img_np.copy()
    mask_region = expanded_mask[y_min:y_max, x_min:x_max]
    combined[y_min:y_max, x_min:x_max][mask_region] = pixelated_np[mask_region]

    return Image.fromarray(combined)


# -----------------------------
# 6) הפונקציה המרכזית
# -----------------------------
def process_image(
    pil_image: Image.Image,
    gemini_api_key: str,
    progress_callback=None
) -> Image.Image:
    if not gemini_api_key:
        raise ValueError("מפתח API של Gemini אינו מוזן.")
    """
    פונקציה המקבלת תמונת PIL, מפתח API של Gemini, ומחזירה את התמונה לאחר טשטוש נשים.
    """

    if progress_callback is None:
        # אם לא הועברה פונקציה לעדכון התקדמות, ניצור פונקציה ריקה
        def progress_callback(x, desc=""):
            pass

    conversation.clear()
    add_user_text("Processing a new image (backend)!")

    # 1) שלב YOLO
    progress_callback(0.0, "מתחיל זיהוי אנשים (YOLO)...")
    if yolo_model is None:
        print("[process_image] מודל YOLO לא נטען כראוי.")
        return pil_image

    np_image = np.array(pil_image)
    results = yolo_model.predict(np_image)
    bboxes_person = []

    for result in results:
        boxes = result.boxes
        for box in boxes:
            cls_name = yolo_model.names[int(box.cls)]
            conf = box.conf.item()
            if cls_name == TARGET_CLASS and conf >= CONF_THRESHOLD:
                x1, y1, x2, y2 = box.xyxy[0]
                bboxes_person.append([int(x1), int(y1), int(x2), int(y2)])

    progress_callback(0.1, f"נמצאו {len(bboxes_person)} בוקסי 'person' ב-YOLO")

    # 2) שלב Gemini (עבור כל בוקס בנפרד)
    women_boxes = []
    n_bboxes = len(bboxes_person) if bboxes_person else 1
    for i, bbox in enumerate(bboxes_person, start=1):
        fraction = 0.1 + (0.5 * i / n_bboxes)  # נניח חצי מההתקדמות מוקצה ל-Gemini
        progress_callback(fraction, f"[Gemini] בודק בוקס #{i} מתוך {len(bboxes_person)}")

        x1, y1, x2, y2 = bbox
        cropped = pil_image.crop((x1, y1, x2, y2))

        add_user_image_from_pil(cropped)
        add_user_text("---")

        gemini_text = send_and_receive(gemini_api_key)
        if is_female_from_text(gemini_text):
            women_boxes.append(bbox)

    # 3) שלב SAM2 (עבור בוקסים של נשים)
    if sam2_predictor is None:
        print("[process_image] SAM2 לא זמין/נטען. מחזירים תמונה ללא טשטוש.")
        raise ValueError("SAM2 model is not loaded.")

    progress_callback(0.6, f"מתחיל פילוח SAM2 על {len(women_boxes)} נשים...")
    sam2_predictor.set_image(np.array(pil_image))

    women_masks = []
    n_women = len(women_boxes) if women_boxes else 1
    for j, bbox in enumerate(women_boxes, start=1):
        fraction = 0.6 + (0.3 * j / n_women)  # עדכון עד 90%
        progress_callback(fraction, f"[SAM2] מפלח בוקס #{j} מתוך {len(women_boxes)}")

        box_np = np.array([bbox])
        masks, scores, _ = sam2_predictor.predict(
            point_coords=None,
            point_labels=None,
            box=box_np,
            multimask_output=False,
        )

        if masks.ndim == 4 and masks.shape[1] == 1:
            mask = masks.squeeze(1)[0].astype(bool)
        elif masks.ndim == 3:
            mask = masks[0].astype(bool)
        else:
            raise ValueError(f"[SAM2] צורת masks לא צפויה: {masks.shape}")

        women_masks.append((bbox, mask))

    # 4) שלב טשטוש
    progress_callback(0.9, "מתחיל טשטוש האזורים המזוהים (Blur + פיקסול)...")
    final_image = pil_image.copy()
    for (bbox, mask) in women_masks:
        final_image = blur_regions_with_mask(final_image, mask)

    progress_callback(1.0, "סיימנו! מחזירים את התוצאה הסופית.")

    # המרת התמונה ל-base64
    encoded_image = encode_image_to_base64(final_image)
    return encoded_image