# language-sam / app.py
# Hugging Face Space by ariG23498 — commit e3fa94a ("overlay")
# IMPORTS
import torch
import numpy as np
from PIL import Image
from lang_sam import LangSAM
import gradio as gr
def run_lang_sam(input_image, text_prompt, model, image_size=256):
    """Segment ``input_image`` with ``model`` and overlay the mask in red.

    Args:
        input_image: PIL image to segment (any mode; converted to RGB).
        text_prompt: Natural-language description of what to segment.
        model: LangSAM-style model whose ``predict(image, prompt)`` returns
            ``(masks, boxes, phrases, logits)``; only ``masks`` is used here.
        image_size: Side length the image is resized to before prediction.
            Defaults to 256, preserving the original behavior.

    Returns:
        A PIL ``Image`` of size ``(image_size, image_size)`` with the unified
        mask blended on top at 50% opacity. If the model finds nothing, the
        resized image is returned unchanged.
    """
    image = input_image.convert("RGB").resize((image_size, image_size))
    # Get the per-instance masks from the model (other outputs unused).
    masks, _, _, _ = model.predict(image, text_prompt)
    # Guard: no detections means an empty tensor, and max(dim=0) would raise.
    if masks.numel() == 0:
        return image
    # Union all instance masks into a single boolean (H, W) mask.
    masks_int = masks.to(torch.uint8)
    masks_max, _ = masks_int.max(dim=0, keepdim=True)
    unified_mask = masks_max.squeeze(0).to(torch.bool)
    # Paint the masked region red on a black canvas. Move the mask to CPU
    # numpy explicitly so indexing works even if the model ran on GPU.
    color = (255, 0, 0)  # Overlay color (RGB)
    colored_mask = np.zeros((image_size, image_size, 3), dtype=np.uint8)
    colored_mask[unified_mask.cpu().numpy()] = color
    # Convert the colored mask to PIL for blending.
    colored_mask_pil = Image.fromarray(colored_mask)
    # Blend the colored mask with the original image.
    # Adjust alpha to change the transparency of the colored mask.
    alpha = 0.5  # Transparency factor (between 0 and 1)
    blended_image = Image.blend(image, colored_mask_pil, alpha=alpha)
    return blended_image
def setup_gradio_interface(model):
    """Build the Gradio UI for Lang SAM.

    Args:
        model: Segmentation model forwarded to ``run_lang_sam``.

    Returns:
        A ``gr.Blocks`` app wiring the image + prompt inputs to
        ``run_lang_sam`` via both a Run button and a cached example.
    """
    # One shared inference callback for the button and the examples,
    # instead of duplicating the same lambda in two places.
    def infer(image, prompt):
        return run_lang_sam(image, prompt, model)

    block = gr.Blocks()
    with block:
        # Fixed malformed HTML: the original left both tags unclosed.
        gr.Markdown("<h1><center>Lang SAM</center></h1>")
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(type="pil", label="Input Image")
                text_prompt = gr.Textbox(label="Enter what you want to segment")
                run_button = gr.Button(value="Run")
            with gr.Column():
                output_mask = gr.Image(type="numpy", label="Segmentation Mask")
        run_button.click(
            fn=infer,
            inputs=[input_image, text_prompt],
            outputs=[output_mask],
        )
        gr.Examples(
            examples=[["bw-image.jpeg", "road"]],
            inputs=[input_image, text_prompt],
            outputs=[output_mask],
            fn=infer,
            cache_examples=True,
            label="Try this example input!",
        )
    return block
if __name__ == "__main__":
    # Instantiate the model once and serve the UI on the local interface
    # (no public share link, no API page, surface errors in the UI).
    lang_sam_model = LangSAM()
    demo = setup_gradio_interface(lang_sam_model)
    demo.launch(share=False, show_api=False, show_error=True)