vibe-shopping / virtual_try /auto_masker.py
sitatech's picture
Remove flashinfer dependency
a8d3569
raw
history blame
2.26 kB
import cv2
import numpy as np
from PIL import Image
from huggingface_hub import hf_hub_download
from lang_sam import LangSAM
class AutoInpaintMaskGenerator:
    """Generate binary inpainting masks from a text prompt via a LangSAM model."""

    def __init__(
        self,
        langsam_model: LangSAM | None = None,
    ):
        """
        Store the segmentation model used by ``generate_mask``.

        Args:
            langsam_model: an already-loaded LangSAM instance. When None, the
                "sam2.1_hiera_large.pt" checkpoint is fetched from the
                "facebook/sam2.1-hiera-large" Hugging Face repo and a new
                LangSAM is built from the downloaded file.
        """
        if langsam_model is None:
            sam_path = hf_hub_download(
                repo_id="facebook/sam2.1-hiera-large",
                filename="sam2.1_hiera_large.pt",
            )
            langsam_model = LangSAM(
                "sam2.1_hiera_large",
                sam_path,
            )
        self.model = langsam_model

    def generate_mask(
        self,
        image: Image.Image,
        prompt: str,
        threshold: float = 0.3,
        *,
        dilation_size: int = 10,
    ) -> np.ndarray:
        """
        Generate a binary mask for inpainting.

        The union of every predicted mask whose score is at least
        ``threshold`` is taken, converted to 0/255 and optionally dilated.

        Args:
            image: the image to segment.
            prompt: text describing the region to mask.
            threshold: minimum mask score for a mask to be kept.
            dilation_size: side length of the square dilation kernel applied
                to the final mask; 0 disables dilation.

        Returns:
            A 2D NumPy array (dtype=uint8), with 255 for masked regions and
            0 elsewhere.

        Raises:
            ValueError: if ``dilation_size`` is negative, if the model
                returned no masks, or if no mask reaches ``threshold``.
        """
        # Validate up front so a bad argument fails before the model runs.
        if dilation_size < 0:
            raise ValueError("dilation_size must be non-negative.")

        result = self.model.predict(
            texts_prompt=[prompt],
            images_pil=[image],
        )[0]

        # Coerce to arrays so a list-like result (e.g. an empty list) is
        # handled the same way as an ndarray instead of failing on `.ndim`.
        masks = np.asarray(result["masks"])  # expected (N, H, W)
        scores = np.atleast_1d(result["mask_scores"])  # always at least 1D

        # If only one mask was returned, expand it to (1, H, W).
        if masks.ndim == 2:
            masks = masks[np.newaxis, :, :]

        if len(masks) == 0:
            raise ValueError("No masks found.")

        # Filter masks by score threshold. `valid` is a boolean vector with
        # one entry per mask, so emptiness must be tested with `.any()`:
        # its length is never zero once at least one mask exists.
        valid = scores >= threshold
        if not valid.any():
            raise ValueError("No masks scored the required threshold.")

        combined_mask = np.any(masks[valid], axis=0)

        # Convert to a uint8 binary mask for inpainting (0 or 255).
        binary_mask = combined_mask.astype(np.uint8) * 255

        if dilation_size == 0:
            return binary_mask

        # Apply dilation, to give more flexibility to the inpainting model.
        kernel = np.ones((dilation_size, dilation_size), np.uint8)
        return cv2.dilate(binary_mask, kernel, iterations=1)