# slimsam / app.py
# Author: merve (Hugging Face staff)
# Last commit: bb8da79 "Update app.py" (verified)
# Scraped page residue: raw | history | blame | "No virus" | 4.66 kB
import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import SamModel, SamProcessor
from gradio_image_prompter import ImagePrompter
# Prefer the GPU when CUDA is present; otherwise everything runs on the CPU.
_cuda_available = torch.cuda.is_available()
device = torch.device("cuda") if _cuda_available else torch.device("cpu")

# Hugging Face Hub checkpoints for the two models this demo compares:
# the full SAM baseline and SlimSAM, its pruned-distilled variant.
_SAM_CHECKPOINT = "facebook/sam-vit-huge"
_SLIMSAM_CHECKPOINT = "nielsr/slimsam-50-uniform"

# Models are moved to ``device``; each processor is loaded from the same
# checkpoint as the model it belongs to.
sam_model = SamModel.from_pretrained(_SAM_CHECKPOINT).to(device)
sam_processor = SamProcessor.from_pretrained(_SAM_CHECKPOINT)
slimsam_model = SamModel.from_pretrained(_SLIMSAM_CHECKPOINT).to(device)
slimsam_processor = SamProcessor.from_pretrained(_SLIMSAM_CHECKPOINT)
def sam_box_inference(image, model, x_min, y_min, x_max, y_max, processor=None):
    """Segment the region inside one bounding box with a SAM-family model.

    Args:
        image: The input image as a NumPy array (it is wrapped with
            ``Image.fromarray``); presumably the RGB array produced by
            ``ImagePrompter`` -- TODO confirm.
        model: A ``SamModel`` already moved to ``device``.
        x_min, y_min, x_max, y_max: Box corners in original-image pixels.
        processor: The ``SamProcessor`` matching *model*. When omitted, the
            processor loaded from the same checkpoint as *model* is used
            (``slimsam_processor`` for ``slimsam_model``, otherwise
            ``sam_processor``).

    Returns:
        A one-element list ``[(mask, "mask")]`` in the annotation format
        consumed by ``gr.AnnotatedImage``. ``mask`` is the first mask
        returned by ``post_process_masks`` for the box, with a leading axis
        of length 1 added.
    """
    if processor is None:
        # Previously the SAM processor was always used, even for SlimSAM,
        # leaving ``slimsam_processor`` unused.
        processor = slimsam_processor if model is slimsam_model else sam_processor

    inputs = processor(
        Image.fromarray(image),
        # Documented nesting is [image][box][coords]: one image, one box.
        input_boxes=[[[x_min, y_min, x_max, y_max]]],
        return_tensors="pt",
    ).to(device)

    with torch.no_grad():
        outputs = model(**inputs)

    # Resize the low-resolution predictions back to the original image size.
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )
    # Keep only the first mask of the first prompt of the first image.
    mask = masks[0][0][0].numpy()
    return [(mask[np.newaxis, ...], "mask")]
def sam_point_inference(image, model, x, y, processor=None):
    """Segment the object under a single point prompt with a SAM-family model.

    Args:
        image: The input image, passed directly to the processor (the
            caller in this file passes an RGB ``PIL.Image``).
        model: A ``SamModel`` already moved to ``device``; this is the model
            that is actually run.
        x, y: Point coordinates in original-image pixels.
        processor: The ``SamProcessor`` matching *model*. When omitted, the
            processor loaded from the same checkpoint as *model* is used
            (``slimsam_processor`` for ``slimsam_model``, otherwise
            ``sam_processor``).

    Returns:
        A one-element list ``[(mask, "mask")]`` in the annotation format
        consumed by ``gr.AnnotatedImage``. ``mask`` is the first mask
        returned by ``post_process_masks`` for the point, with a leading
        axis of length 1 added.
    """
    if processor is None:
        processor = slimsam_processor if model is slimsam_model else sam_processor

    inputs = processor(
        image,
        # Nesting is [image][point][coords]: one image, one point.
        input_points=[[[x, y]]],
        return_tensors="pt",
    ).to(device)

    with torch.no_grad():
        # Bug fix: run the model that was passed in. The previous code always
        # ran ``sam_model``, so the "SlimSAM" output was really SAM's.
        outputs = model(**inputs)

    # Resize the low-resolution predictions back to the original image size.
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )
    # Keep only the first mask of the first prompt of the first image.
    mask = masks[0][0][0].numpy()
    return [(mask[np.newaxis, ...], "mask")]
def infer_point(img):
    """Run SlimSAM and SAM on the centre of the user's brush stroke.

    Args:
        img: The value of the ``gr.ImageEditor(type="pil")`` component: a
            dict with ``"background"`` (the uploaded image), ``"layers"``
            (drawn layers; ``layers[0]`` holds the point prompt) and
            ``"composite"`` (unused here).

    Returns:
        Two ``(image, annotations)`` pairs -- SlimSAM's result first, then
        SAM's -- matching the two ``gr.AnnotatedImage`` outputs.

    Raises:
        gr.Error: If no image is uploaded or nothing is drawn on it.
    """
    # ``gr.Error`` only reaches the user when it is *raised*; the previous
    # code merely constructed it, so execution continued and crashed later.
    if img is None or img["background"] is None:
        raise gr.Error("Please upload an image and select a point.")

    layers = img["layers"]
    if not layers or layers[0] is None:
        raise gr.Error("Please select a point on top of the image.")

    image = img["background"].convert("RGB")
    prompt_arr = np.array(layers[0])

    # A pixel counts as drawn when any of its channels is non-zero, so every
    # drawn pixel contributes exactly once to the centroid.
    drawn = prompt_arr.any(axis=-1) if prompt_arr.ndim == 3 else prompt_arr != 0
    rows, cols = np.nonzero(drawn)
    if rows.size == 0:
        # Without this guard the mean of an empty array is NaN and int() fails.
        raise gr.Error("Please select a point on top of the image.")

    center_x = int(cols.mean())
    center_y = int(rows.mean())
    print("Point inference returned.")
    return (
        (image, sam_point_inference(image, slimsam_model, center_x, center_y)),
        (image, sam_point_inference(image, sam_model, center_x, center_y)),
    )
def infer_box(prompts):
    """Run SlimSAM and SAM on the first box drawn in an ``ImagePrompter``.

    Args:
        prompts: The ``ImagePrompter`` value: a dict with ``"image"`` (the
            uploaded image) and ``"points"`` (one entry per drawn prompt).
            Box corners are read as ``x = p[0], p[3]`` and ``y = p[1], p[4]``;
            presumably ``p[2]``/``p[5]`` are prompt-type labels -- TODO
            confirm against gradio_image_prompter.

    Returns:
        Two ``(image, annotations)`` pairs -- SlimSAM's result first, then
        SAM's -- matching the two ``gr.AnnotatedImage`` outputs.

    Raises:
        gr.Error: If no image is uploaded or no box has been drawn.
    """
    # ``gr.Error`` must be raised to stop the handler and show the message.
    if prompts is None or prompts["image"] is None:
        raise gr.Error("Please upload an image and draw a box before submitting")
    image = prompts["image"]

    # Check for an empty prompt list *before* indexing it; previously
    # ``[0]`` raised IndexError and the None check below it was dead code.
    all_points = prompts["points"]
    if not all_points or all_points[0] is None:
        raise gr.Error("Please draw a box before submitting.")
    points = all_points[0]
    print(points)

    # Order the corners so a box dragged in any direction still yields
    # (x_min, y_min, x_max, y_max).
    x_min, x_max = sorted((points[0], points[3]))
    y_min, y_max = sorted((points[1], points[4]))
    return (
        (image, sam_box_inference(image, slimsam_model, x_min, y_min, x_max, y_max)),
        (image, sam_box_inference(image, sam_model, x_min, y_min, x_max, y_max)),
    )
# Two tabs compare SlimSAM and SAM side by side: one driven by a box prompt
# (submitted with a button), one by a point prompt (run on every edit).
with gr.Blocks(title="SlimSAM") as demo:
    gr.Markdown("# SlimSAM")
    gr.Markdown("SlimSAM is the pruned-distilled version of SAM that is smaller.")
    gr.Markdown("In this demo, you can compare SlimSAM and SAM outputs in point and box prompts.")

    with gr.Tab("Box Prompt"):
        # Instructions.
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("To try box prompting, simply upload an image and draw a box on it.")
        # Input on the left, both model outputs on the right.
        with gr.Row():
            with gr.Column():
                box_prompter = ImagePrompter()
                btn = gr.Button("Submit")
            with gr.Column():
                output_box_slimsam = gr.AnnotatedImage(label="SlimSAM Output")
                output_box_sam = gr.AnnotatedImage(label="SAM Output")
        btn.click(infer_box, inputs=box_prompter, outputs=[output_box_slimsam, output_box_sam])

    with gr.Tab("Point Prompt"):
        # Instructions.
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("To try point prompting, simply upload an image and leave a dot on it.")
        # Input on the left, both model outputs on the right.
        with gr.Row():
            with gr.Column():
                point_editor = gr.ImageEditor(
                    type="pil",
                )
            with gr.Column():
                output_slimsam = gr.AnnotatedImage(label="SlimSAM Output")
                output_sam = gr.AnnotatedImage(label="SAM Output")
        # Inference re-runs whenever the editor content changes.
        point_editor.change(infer_point, inputs=point_editor, outputs=[output_slimsam, output_sam])

demo.launch(debug=True)