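import logging
import os
import uuid

import cv2

from models.utils import save_mask_as_png
from database.crud import create_session

logger = logging.getLogger(__name__)


def resize_if_needed(image, max_side):
    """
    Downscale the image if its longer side exceeds max_side, keeping the aspect ratio.

    :param image: np.ndarray, loaded BGR image
    :param max_side: int, maximum allowed length (in pixels) of the longer side
    :return: the original image if it already fits, otherwise a resized copy
    """
    h, w = image.shape[:2]
    if max(h, w) <= max_side:
        logger.debug("Image size %dx%d is within max_side=%d; no resize needed", w, h, max_side)
        return image  # already fits

    scale = max_side / max(h, w)
    # round() instead of int() so float error cannot shave off a pixel;
    # max(1, ...) keeps very thin images from collapsing to zero width/height.
    new_w = max(1, round(w * scale))
    new_h = max(1, round(h * scale))
    # INTER_AREA is the usual OpenCV interpolation choice for downscaling.
    resized = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_AREA)
    logger.debug("Image resized from %dx%d to %dx%d", w, h, new_w, new_h)
    return resized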
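

# Illustrative sketch (hypothetical helper, not called by the pipeline): the
# aspect-ratio-preserving target size that resize_if_needed computes, isolated
# from OpenCV so the arithmetic can be checked without loading an image.
# Assumptions: results are rounded and clamped to at least 1 pixel, and the
# name _scaled_size is made up for this example.
def _scaled_size(h, w, max_side):
    """Return (new_w, new_h), i.e. the (width, height) order cv2.resize expects."""
    if max(h, w) <= max_side:
        return w, h  # already fits: keep the original size
    scale = max_side / max(h, w)  # the longer side becomes max_side
    return max(1, round(w * scale)), max(1, round(h * scale))

# Example: _scaled_size(2048, 3072, 1536) gives (1536, 1024), since scale = 0.5.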



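def process_session_image(
    session_id,
    image_path,
    prompt_text,
    sam_wrapper,
    dino_wrapper,
    save_root="outputs",
):
    """
    Full pipeline for one session: detect objects with DINO, segment each
    detected box with SAM, save the masks as PNG files, and record the
    session in the database.

    :param session_id: ID of the session (string); also used as the output subfolder name
    :param image_path: Path to the uploaded image (string)
    :param prompt_text: Text prompt from the user (string), passed to DINO
    :param sam_wrapper: Initialized SAM wrapper
    :param dino_wrapper: Initialized DINO wrapper
    :param save_root: Base output directory (default: "outputs")
    :return: List of saved PNG file paths, relative to the current working directory
    :raises ValueError: If the image cannot be loaded from image_path
    """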
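    image = cv2.imread(image_path)
    # cv2.imread returns None instead of raising on a missing or unreadable file,
    # so check before resize_if_needed accesses image.shape.
    if image is None:
        raise ValueError(f"Failed to load image from path: {image_path}")

    image = resize_if_needed(image, max_side=1536)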

    # 1. Run DINO detection
    boxes = dino_wrapper.detect(image, prompt_text)

    # 2. Create output folder for this session
    session_dir = os.path.join(save_root, session_id)
    os.makedirs(session_dir, exist_ok=True)

    saved_paths = []

    # 3. Run SAM on each box
    for i, box in enumerate(boxes):
        mask = sam_wrapper.predict_with_box(image, box)
        if mask is None:
            continue

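        # Keep only filename-safe characters from the prompt (e.g. no "/" path separators).
        safe_prompt = "".join(c if c.isalnum() or c in "-_" else "_" for c in prompt_text)
        filename = f"{uuid.uuid4().hex[:8]}_{i}_{safe_prompt}.png"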
        full_path = os.path.join(session_dir, filename)
        save_mask_as_png(image, mask, full_path)
        relative_path = os.path.relpath(full_path, start=".").replace("\\", "/")
        saved_paths.append(relative_path)

    # 4. Save session in database
    create_session(session_id=session_id, image_path=image_path, result_paths=saved_paths)

    return saved_paths
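

# Interface sketch (assumption, inferred only from how process_session_image
# calls the wrappers): dino_wrapper needs detect(image, prompt_text) returning an
# iterable of boxes, and sam_wrapper needs predict_with_box(image, box) returning
# a mask or None. The class names below are hypothetical; the pipeline does not
# use them and simply duck-types the wrappers.
from typing import Any, Iterable, Optional, Protocol, runtime_checkable


@runtime_checkable
class DetectorLike(Protocol):
    def detect(self, image: Any, prompt_text: str) -> Iterable[Any]:
        """Return boxes for objects matching prompt_text (box format treated as opaque)."""
        ...


@runtime_checkable
class BoxSegmenterLike(Protocol):
    def predict_with_box(self, image: Any, box: Any) -> Optional[Any]:
        """Return a mask for the region inside box, or None if nothing was segmented."""
        ...

# Usage: isinstance(dino_wrapper, DetectorLike) only checks that the method exists;
# runtime_checkable protocols do not validate signatures or return types.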