Spaces:
Sleeping
Sleeping
import os | |
import uuid | |
import cv2 | |
from models.utils import save_mask_as_png | |
from database.crud import create_session | |
def resize_if_needed(image, max_side):
    """
    Downscale an image so that its longer side is at most ``max_side``.

    The aspect ratio is preserved. An image that already fits is returned
    unchanged (the very same object, not a copy).

    :param image: np.ndarray whose first two dimensions are (height, width),
        e.g. a BGR image loaded with cv2.imread
    :param max_side: int, maximum allowed size for width or height; must be > 0
    :return: the original image, or a resized copy produced by cv2.resize
    :raises ValueError: if max_side is not positive
    """
    if max_side <= 0:
        raise ValueError(f"max_side must be positive, got {max_side}")
    h, w = image.shape[:2]
    longest = max(h, w)
    if longest <= max_side:
        print("[DEBUG] The image has fine size")
        return image  # already fine
    scale = max_side / longest
    # Round instead of truncating so the long side lands exactly on max_side
    # despite floating-point error, and clamp to at least 1 so a very thin
    # image never collapses to a zero-sized dimension (cv2.resize rejects it).
    new_w = max(1, min(max_side, round(w * scale)))
    new_h = max(1, min(max_side, round(h * scale)))
    resized = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_AREA)
    print("[DEBUG] The image was resized")
    return resized
def _prompt_to_filename_part(prompt_text, max_len=64):
    """
    Turn a free-form user prompt into a fragment that is safe inside a filename.

    Alphanumeric characters, '-' and '_' are kept; every other character
    (spaces, path separators, dots, ...) becomes '_'. The result is truncated
    to ``max_len`` characters to stay well below filesystem name limits.

    :param prompt_text: Prompt from user (string; None is treated as empty)
    :param max_len: Maximum length of the returned fragment
    :return: Sanitized prompt fragment (may be empty)
    """
    text = prompt_text or ""
    safe = "".join(ch if ch.isalnum() or ch in "-_" else "_" for ch in text)
    return safe[:max_len]


def process_session_image(session_id, image_path, prompt_text, sam_wrapper, dino_wrapper, save_root="outputs"):
    """
    Full pipeline: detect + segment + save + write to DB (session-based).

    The image is loaded, downscaled so its longer side is at most 1536 px,
    passed to DINO for box detection, and each box is passed to SAM. Every
    non-None mask is saved as a PNG under ``<save_root>/<session_id>/``, and
    the list of saved paths is handed to ``create_session``.

    Note: detection, segmentation and saving all operate on the *resized*
    image, while ``image_path`` (the original file) is what gets recorded.

    :param session_id: ID of the session (string)
    :param image_path: Path to uploaded image (string)
    :param prompt_text: Prompt from user (string)
    :param sam_wrapper: Initialized SAM wrapper
    :param dino_wrapper: Initialized DINO wrapper
    :param save_root: Base output directory (default: "outputs")
    :return: List of saved PNG file paths, relative to the current working
        directory and using '/' separators
    :raises ValueError: if the image at ``image_path`` cannot be loaded
    """
    image = cv2.imread(image_path)
    # cv2.imread signals failure by returning None rather than raising, so
    # check before any use of the array; otherwise resize_if_needed fails
    # first with an unhelpful AttributeError on image.shape.
    if image is None:
        raise ValueError(f"Failed to load image from path: {image_path}")
    image = resize_if_needed(image, max_side=1536)

    # 1. Run DINO detection (boxes are in the resized image's coordinates)
    boxes = dino_wrapper.detect(image, prompt_text)

    # 2. Create output folder for this session
    # NOTE(review): session_id is joined into the path unsanitized; this
    # assumes it is generated server-side. Confirm it can never contain path
    # separators or '..'.
    session_dir = os.path.join(save_root, session_id)
    os.makedirs(session_dir, exist_ok=True)

    # The prompt is user input: sanitize it once so it cannot inject path
    # separators, '..' or an overlong name into the output filenames.
    prompt_part = _prompt_to_filename_part(prompt_text)

    saved_paths = []
    # 3. Run SAM on each box
    for i, box in enumerate(boxes):
        mask = sam_wrapper.predict_with_box(image, box)
        if mask is None:
            continue  # no mask produced for this box; skip it
        # Random prefix keeps names unique across repeated runs in one session.
        filename = f"{uuid.uuid4().hex[:8]}_{i}_{prompt_part}.png"
        full_path = os.path.join(session_dir, filename)
        save_mask_as_png(image, mask, full_path)
        # Normalize to forward slashes so stored paths look the same on Windows.
        relative_path = os.path.relpath(full_path, start=".").replace("\\", "/")
        saved_paths.append(relative_path)

    # 4. Save session in database
    create_session(session_id=session_id, image_path=image_path, result_paths=saved_paths)
    return saved_paths