Spaces:
Running
Running
| import cv2 | |
| import numpy as np | |
| from PIL import Image | |
| from huggingface_hub import hf_hub_download | |
| from lang_sam import LangSAM | |
class AutoInpaintMaskGenerator:
    """Build binary inpainting masks from a text prompt using LangSAM.

    The regions LangSAM segments for the prompt are filtered by score,
    merged into a single uint8 mask (255 = region to inpaint, 0 = keep) and
    dilated so the inpainting model gets some margin around the object.
    """

    def __init__(
        self,
        langsam_model: LangSAM | None = None,
    ):
        """
        Args:
            langsam_model: an instance of LangSAM that is already loaded.
                When None, the ``sam2.1_hiera_large.pt`` checkpoint is fetched
                from the ``facebook/sam2.1-hiera-large`` Hugging Face Hub repo
                and a new LangSAM model is built from it.
        """
        if langsam_model is None:
            sam_path = hf_hub_download(
                repo_id="facebook/sam2.1-hiera-large",
                filename="sam2.1_hiera_large.pt",
            )
            langsam_model = LangSAM(
                "sam2.1_hiera_large",
                sam_path,
            )
        self.model = langsam_model

    def generate_mask(
        self,
        image: Image.Image,
        prompt: str,
        threshold: float = 0.3,
        dilation_kernel_size: int = 10,
    ) -> np.ndarray:
        """Generate a binary mask for inpainting.

        Every mask whose score is >= ``threshold`` is kept, and the kept
        masks are combined with a logical OR (union).

        Args:
            image: image to segment.
            prompt: text describing the region(s) to mask.
            threshold: minimum mask score required to keep a mask.
            dilation_kernel_size: side length in pixels of the square kernel
                used to dilate the final mask. A value <= 0 disables dilation.

        Returns:
            A 2D array (dtype=uint8), with 255 for masked regions and 0
            elsewhere.

        Raises:
            ValueError: if the model returns no masks, or if no mask's score
                reaches ``threshold``.
        """
        result = self.model.predict(
            texts_prompt=[prompt],
            images_pil=[image],
        )[0]

        # result["masks"] is not guaranteed to be an ndarray (e.g. it may be
        # an empty list when nothing matches the prompt -- verify against the
        # installed lang_sam version), so normalise it before reading .ndim.
        masks = np.asarray(result["masks"])  # expected (N, H, W) or (H, W)
        scores = np.atleast_1d(result["mask_scores"])  # always at least 1D
        if masks.size == 0:
            raise ValueError("No masks found.")

        # A single mask may come back as (H, W); make it (1, H, W).
        if masks.ndim == 2:
            masks = masks[np.newaxis, :, :]

        # Boolean selector with one entry per mask; it is never empty, so
        # test whether ANY mask passed rather than checking its length.
        valid = scores >= threshold
        if not valid.any():
            raise ValueError("No masks scored the required threshold.")

        # Union of all qualifying masks, converted to a 0/255 uint8 mask.
        combined_mask = np.any(masks[valid], axis=0)
        binary_mask = combined_mask.astype(np.uint8) * 255

        # Dilate to give the inpainting model more flexibility around edges.
        # An explicit guard is needed: OpenCV substitutes a default 3x3
        # element when handed an empty kernel.
        if dilation_kernel_size <= 0:
            return binary_mask
        kernel = np.ones((dilation_kernel_size, dilation_kernel_size), np.uint8)
        return cv2.dilate(binary_mask, kernel, iterations=1)