Spaces:
Runtime error
Runtime error
# IMPORTS | |
import torch | |
import numpy as np | |
from PIL import Image | |
from lang_sam import LangSAM | |
import gradio as gr | |
def run_lang_sam(
    input_image,
    text_prompt,
    model,
    size=256,
    color=(255, 0, 0),
    alpha=0.5,
):
    """Segment ``text_prompt`` in ``input_image`` and return a tinted overlay.

    The image is converted to RGB and resized to a ``size`` x ``size`` square.
    It is passed to ``model.predict`` together with the text prompt. Every pixel
    covered by at least one predicted mask is alpha-blended with ``color``;
    pixels outside the masks keep their original values.

    Args:
        input_image: PIL image from the Gradio ``Image`` component
            (``type="pil"``), or ``None`` when nothing was uploaded.
        text_prompt: Free-text description of what to segment.
        model: Object exposing ``predict(image, text_prompt)``. It must return
            a 4-tuple whose first element is a stack of masks (a torch tensor).
        size: Side length in pixels of the square image fed to the model.
            Defaults to 256.
        color: RGB tint painted over the masked area.
        alpha: Opacity of the tint inside the mask, between 0 and 1.

    Returns:
        A ``size`` x ``size`` RGB PIL image. When the model returns no masks,
        the resized image is returned untinted.

    Raises:
        gr.Error: If no image was provided or the prompt is empty.
    """
    if input_image is None:
        raise gr.Error("Please upload an input image.")
    if not text_prompt or not text_prompt.strip():
        raise gr.Error("Please enter a text prompt describing what to segment.")

    image = input_image.convert("RGB").resize((size, size))

    masks, _, _, _ = model.predict(image, text_prompt)
    # NOTE(review): presumably LangSAM returns an empty stack when nothing
    # matches the prompt. Reducing over dim 0 of an empty tensor raises, so
    # bail out early with the plain image.
    if len(masks) == 0:
        return image

    # Union of all masks: a pixel is foreground if any mask covers it.
    unified_mask = masks.to(torch.uint8).max(dim=0).values.to(torch.bool)
    # Explicit host/numpy conversion so numpy boolean indexing works
    # regardless of the tensor's device.
    mask_np = unified_mask.cpu().numpy()

    # Solid layer: `color` inside the mask, black elsewhere.
    colored_mask = np.zeros((image.height, image.width, 3), dtype=np.uint8)
    colored_mask[mask_np] = color
    tinted = Image.blend(image, Image.fromarray(colored_mask), alpha=alpha)

    # Blending alone would also darken every unmasked pixel by (1 - alpha)
    # against the black background. Compositing with the mask keeps the
    # original pixels outside it, so only the masked area is tinted.
    mask_pil = Image.fromarray(mask_np.astype(np.uint8) * 255)
    return Image.composite(tinted, image, mask_pil)
def setup_gradio_interface(model):
    """Build the Gradio Blocks UI for interactive LangSAM segmentation.

    The layout has two columns:
    - Left: an image upload, a text prompt and a Run button.
    - Right: an output image.

    A cached example row is also added.

    Args:
        model: Loaded segmentation model. It is forwarded to ``run_lang_sam``
            for button clicks and for the example.

    Returns:
        The constructed ``gr.Blocks`` app. The caller is responsible for
        calling ``launch()`` on it.
    """

    def segment(image, prompt):
        # Single callback shared by the Run button and the examples, binding
        # the already-loaded model.
        return run_lang_sam(image, prompt, model)

    with gr.Blocks() as block:
        gr.Markdown("<h1><center>Lang SAM</center></h1>")
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(type="pil", label="Input Image")
                text_prompt = gr.Textbox(label="Enter what you want to segment")
                run_button = gr.Button(value="Run")
            with gr.Column():
                output_mask = gr.Image(type="numpy", label="Segmentation Mask")

        run_button.click(
            fn=segment,
            inputs=[input_image, text_prompt],
            outputs=[output_mask],
        )
        # With cache_examples=True Gradio precomputes this example's output,
        # so "bw-image.jpeg" presumably has to be present when the app starts.
        gr.Examples(
            examples=[["bw-image.jpeg", "road"]],
            inputs=[input_image, text_prompt],
            outputs=[output_mask],
            fn=segment,
            cache_examples=True,
            label="Try this example input!",
        )
    return block
if __name__ == "__main__":
    # Load LangSAM once; the same instance is reused by every UI request.
    model = LangSAM()
    # show_error=True asks Gradio to display callback errors in the browser.
    setup_gradio_interface(model).launch(
        show_error=True,
        show_api=False,
        share=False,
    )