# Hugging Face Spaces status captured with this source: Build error
import re

import gradio as gr
import torch
from PIL import Image
from sam2.sam2_image_predictor import SAM2ImagePredictor
# Load the SAM2 model once at module import; generate_mask reuses this single
# predictor for every request.
# SAM2's model builder targets CUDA unless told otherwise, so choose the device
# explicitly to let the app also start on CPU-only hardware.
predictor = SAM2ImagePredictor.from_pretrained(
    "facebook/sam2.1-hiera-large",
    device="cuda" if torch.cuda.is_available() else "cpu",
)
# Function to predict masks from the image and prompts | |
def generate_mask(image, prompt): | |
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16): | |
predictor.set_image(image) | |
masks, _, _ = predictor.predict(prompt) | |
return masks[0] # Returning the first mask for simplicity | |
# Build the Gradio UI. Inside a Blocks context, components are laid out in the
# order they are created: title, image upload, prompt box, mask output, button.
with gr.Blocks() as demo:
    gr.Markdown(value="# Image Segmentation using SAM2")

    # Inputs: the picture to segment and the prompt text.
    image_input = gr.Image(type="pil", label="Upload Image")
    prompt_input = gr.Textbox(
        placeholder="Describe what you want to segment",
        label="Enter segmentation prompt",
    )

    # Output: where the mask returned by generate_mask is shown.
    output_mask = gr.Image(label="Generated Mask")

    # Clicking the button calls generate_mask(image, prompt) and shows its result.
    generate_button = gr.Button(value="Generate Mask")
    generate_button.click(
        fn=generate_mask,
        inputs=[image_input, prompt_input],
        outputs=output_mask,
    )
# Launch the Gradio app only when this file is executed as a script
# (e.g. `python app.py`); importing the module then defines the UI without
# starting a server as a side effect.
if __name__ == "__main__":
    demo.launch()