# grounding-sam / gradio_demo.py
# Last commit by pablovela5620 (c788434): add video segmentation example
from pathlib import Path
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
import uvicorn
import gradio as gr
from datetime import datetime
from typing import Union, Literal
from functools import partial
import cv2
import torch
from main import grounding_dino_detect
from models import (
run_segmentation,
resize_img,
load_image,
get_downloaded_model_path,
load_grounding_model,
create_sam,
CONFIG_PATH,
)
from segment_anything import SamPredictor
from segment_anything.modeling import Sam
from groundingdino.models import GroundingDINO
import rerun as rr
rr.init("GroundingSAM")

# FastAPI host application; the Gradio Blocks UI is mounted onto it later on.
app = FastAPI()

# Folder (relative to the current working directory) that receives the
# per-request Rerun HTML exports. It is served under the `/static` URL prefix
# so the <iframe> returned by `predict` can load those files.
static_dir = Path("static")
static_dir.mkdir(exist_ok=True, parents=True)
app.mount("/static", StaticFiles(directory=static_dir), name="static")
def log_video_segmentation(
    video_path: Union[str, Path],
    prompt: str,
    model: GroundingDINO,
    predictor: SamPredictor,
    device: str = "cpu",
    max_frames: int = 21,
):
    """Detect and segment prompted objects in the leading frames of a video.

    Each frame is converted BGR -> RGB, resized with ``resize_img(rgb, 512)``
    and logged to Rerun as ``"image"`` on the ``"frame"`` timeline. It is then
    passed through Grounding DINO (``grounding_dino_detect``) and SAM
    (``run_segmentation``).

    Args:
        video_path: Path to a video file readable by OpenCV.
        prompt: Text prompt forwarded to ``grounding_dino_detect``.
        model: Loaded Grounding DINO model.
        predictor: SAM predictor; ``set_image`` is called once per frame.
        device: Device string forwarded to ``grounding_dino_detect``.
        max_frames: Maximum number of frames to process. The default of 21
            matches the previous hard-coded limit (frames 0..20).

    Raises:
        FileNotFoundError: If ``video_path`` does not exist.
        ValueError: If OpenCV cannot open the file as a video.
    """
    video_path = Path(video_path)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not video_path.exists():
        raise FileNotFoundError(f"Video not found: {video_path}")

    cap = cv2.VideoCapture(str(video_path))
    if not cap.isOpened():
        # Previously this case silently produced an empty recording.
        raise ValueError(f"Could not open video: {video_path}")

    try:
        for idx in range(max_frames):
            ret, bgr = cap.read()
            if not ret:  # end of stream or decode failure
                break
            rr.set_time_sequence("frame", idx)
            rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
            rgb = resize_img(rgb, 512)
            rr.log_image("image", rgb)
            detections, phrases, id_from_phrase = grounding_dino_detect(
                model, device, rgb, prompt
            )
            predictor.set_image(rgb)
            run_segmentation(predictor, rgb, detections, phrases, id_from_phrase)
    finally:
        # Always free the decoder handle, even if detection/segmentation fails.
        cap.release()
def log_images_segmentation(
    images: list[Union[str, Path]],
    prompt: str,
    model: GroundingDINO,
    predictor: SamPredictor,
    device: str = "cpu",
):
    """Detect and segment prompted objects in each image of ``images``.

    Every image is loaded with ``load_image`` and logged to Rerun as
    ``"image"`` on the ``"image"`` timeline, using its position in ``images``
    as the sequence index. It is then passed through Grounding DINO
    (``grounding_dino_detect``) and SAM (``run_segmentation``).

    Args:
        images: Image locations accepted by ``load_image``.
        prompt: Text prompt forwarded to ``grounding_dino_detect``.
        model: Loaded Grounding DINO model.
        predictor: SAM predictor (not the raw ``Sam`` model); ``set_image``
            is called once per image.
        device: Device string forwarded to ``grounding_dino_detect``.
    """
    for n, image_uri in enumerate(images):
        rr.set_time_sequence("image", n)
        image = load_image(image_uri)
        rr.log_image("image", image)
        detections, phrases, id_from_phrase = grounding_dino_detect(
            model, device, image, prompt
        )
        predictor.set_image(image)
        run_segmentation(predictor, image, detections, phrases, id_from_phrase)
# Gradio stuff
def predict(prompt: str, image_path: str, type: Literal["image", "video"]):
    """Segment an uploaded image or video and embed the Rerun recording.

    Models are loaded on the CUDA device when available (otherwise CPU). The
    selected pipeline logs into an in-memory Rerun recording, which is saved
    as an HTML file in ``static_dir``.

    Args:
        prompt: Text prompt forwarded to Grounding DINO.
        image_path: Filesystem path of the uploaded file, as given by Gradio.
        type: ``"video"`` runs ``log_video_segmentation``; any other value
            runs ``log_images_segmentation`` on the single file.

    Returns:
        An HTML ``<iframe>`` string pointing at the saved recording under
        ``/static/``.

    Raises:
        gr.Error: If no file was provided.
    """
    import uuid

    if not image_path:
        # Gradio passes None when "Segment" is clicked without an upload;
        # gr.Error shows a readable message in the UI instead of a traceback.
        raise gr.Error("Please upload an image or video before segmenting.")

    # A random name cannot collide between concurrent requests. The previous
    # `datetime.utcnow().strftime('%s')` name had 1-second resolution, relied
    # on a non-portable glibc `%s` directive, and interpreted the naive UTC
    # time as local time.
    file_name = f"{uuid.uuid4().hex}.html"
    file_path = static_dir / file_name
    rec = rr.memory_recording()
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # load model
    grounded_checkpoint = get_downloaded_model_path("grounding")
    model = load_grounding_model(CONFIG_PATH, grounded_checkpoint, device=device)
    sam = create_sam("vit_b", device)
    predictor = SamPredictor(sam)

    if type == "video":
        log_video_segmentation(
            Path(image_path),
            prompt,
            model,
            predictor,
            device=device,
        )
    else:
        log_images_segmentation(
            [image_path],
            prompt,
            model,
            predictor,
            device=device,
        )

    # Explicit UTF-8 so the export does not depend on the platform's locale.
    file_path.write_text(rec.as_html(), encoding="utf-8")
    return f"""<iframe src="/static/{file_name}" width="950" height="712"></iframe>"""
def pred_img_demo():
    """Lay out the image tab: prompt + image upload -> embedded Rerun viewer.

    Must be called inside an active ``gr.Blocks`` context. The "Segment"
    button and the cached example both call ``predict`` with
    ``type="image"``.
    """
    segment_image = partial(predict, type="image")
    example_rows = [["fan, mirror, sofa", "living_room.jpeg"]]

    with gr.Row():
        with gr.Column():
            prompt_box = gr.Textbox(value="tires, wheels", label="Prompt")
            segment_button = gr.Button("Segment")
        with gr.Column():
            image_input = gr.Image(label="Image", type="filepath")

    with gr.Row():
        viewer = gr.HTML(label="HTML preview", show_label=True)

    with gr.Row():
        gr.Examples(
            example_rows,
            fn=segment_image,
            inputs=[prompt_box, image_input],
            outputs=[viewer],
            cache_examples=True,
        )

    segment_button.click(
        fn=segment_image, inputs=[prompt_box, image_input], outputs=[viewer]
    )
def pred_vid_demo():
    """Lay out the video tab: prompt + video upload -> embedded Rerun viewer.

    Must be called inside an active ``gr.Blocks`` context. The "Segment"
    button and the cached example both call ``predict`` with
    ``type="video"``.
    """
    segment_video = partial(predict, type="video")
    example_rows = [["dog, woman", "dog_and_woman.mp4"]]

    with gr.Row():
        with gr.Column():
            prompt_box = gr.Textbox(value="tires, wheels", label="Prompt")
            segment_button = gr.Button("Segment")
        with gr.Column():
            video_input = gr.Video(label="Video", type="filepath")

    with gr.Row():
        viewer = gr.HTML(label="HTML preview", show_label=True)

    with gr.Row():
        gr.Examples(
            example_rows,
            fn=segment_video,
            inputs=[prompt_box, video_input],
            outputs=[viewer],
            cache_examples=True,
        )

    segment_button.click(
        fn=segment_video, inputs=[prompt_box, video_input], outputs=[viewer]
    )
# Assemble the UI. Tabs appear in the order they are created here: the
# "Image" tab is built by pred_img_demo and the "Video" tab by pred_vid_demo.
with gr.Blocks() as block:
    with gr.Tab(label="Image"):
        pred_img_demo()
    with gr.Tab(label="Video"):
        pred_vid_demo()
# mount Gradio app to FastAPI app
# Rebinds `app` to the value returned by gr.mount_gradio_app. The Gradio UI is
# mounted at "/" after "/static" was mounted on the same FastAPI app above.
app = gr.mount_gradio_app(app, block, path="/")
# serve the app
# Only when run as a script; listens on all interfaces, port 7860.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)