grounding-sam / gradio_demo.py
pablovela5620's picture
add cached examples
062711f
raw
history blame
4.01 kB
import functools
import uuid
from datetime import datetime
from pathlib import Path
from typing import List, Union

import cv2
import gradio as gr
import rerun as rr
import torch
import uvicorn
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from groundingdino.models import GroundingDINO
from segment_anything import SamPredictor
from segment_anything.modeling import Sam

from main import grounding_dino_detect
from models import (
    CONFIG_PATH,
    create_sam,
    get_downloaded_model_path,
    load_grounding_model,
    load_image,
    resize_img,
    run_segmentation,
)
# Open the rerun session that the segmentation logging helpers write into.
rr.init("GroundingSAM")

# Single FastAPI application: it serves the saved HTML recordings under
# /static and, further down, hosts the Gradio UI at the root path.
app = FastAPI()

# Folder that holds the per-request HTML recordings written by predict().
# It is created before the StaticFiles mount below, which points at it.
static_dir = Path("./static")
static_dir.mkdir(exist_ok=True, parents=True)
app.mount(path="/static", app=StaticFiles(directory=static_dir), name="static")
def log_video_segmentation(
    video_path: Union[str, Path],
    prompt: str,
    model: GroundingDINO,
    predictor: SamPredictor,
    device: str = "cpu",
    max_frame_idx: int = 20,
):
    """Detect and segment ``prompt`` objects in the leading frames of a video, logging to rerun.

    Each frame is converted from BGR to RGB, resized with ``resize_img(rgb, 512)``,
    logged as ``"image"`` on the ``"frame"`` timeline, and then passed through
    ``grounding_dino_detect`` followed by ``run_segmentation``.

    Args:
        video_path: Path (or path string) to a video file readable by OpenCV.
        prompt: Text prompt forwarded to Grounding DINO, e.g. ``"dog, woman"``.
        model: Grounding DINO model used for text-prompted detection.
        predictor: SAM predictor; ``set_image`` is called on it once per frame.
        device: Device string forwarded to ``grounding_dino_detect``.
        max_frame_idx: Index of the last frame to process (inclusive). The
            default of 20 processes at most 21 frames, matching the previous
            hard-coded limit.

    Raises:
        FileNotFoundError: If ``video_path`` does not exist.
    """
    video_path = Path(video_path)
    # Explicit check instead of `assert`, which is stripped when Python runs with -O.
    if not video_path.exists():
        raise FileNotFoundError(f"Video not found: {video_path}")

    cap = cv2.VideoCapture(str(video_path))
    try:
        idx = 0
        while cap.isOpened() and idx <= max_frame_idx:
            ret, bgr = cap.read()
            if not ret:
                # End of stream or read failure.
                break
            rr.set_time_sequence("frame", idx)
            rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
            rgb = resize_img(rgb, 512)
            rr.log_image("image", rgb)
            detections, phrases, id_from_phrase = grounding_dino_detect(
                model, device, rgb, prompt
            )
            predictor.set_image(rgb)
            run_segmentation(predictor, rgb, detections, phrases, id_from_phrase)
            idx += 1
    finally:
        # Release the capture handle even if detection/segmentation raises.
        cap.release()
def log_images_segmentation(
    images: list[Union[str, Path]],
    prompt: str,
    model: GroundingDINO,
    predictor: SamPredictor,
    device: str = "cpu",
):
    """Detect and segment ``prompt`` objects in each image, logging results to rerun.

    Image ``n`` is loaded with ``load_image``, logged as ``"image"`` at step ``n``
    of the ``"image"`` timeline, and then passed through ``grounding_dino_detect``
    followed by ``run_segmentation``.

    Args:
        images: Image locations accepted by ``load_image`` (paths or path strings).
        prompt: Text prompt forwarded to Grounding DINO, e.g. ``"tires, wheels"``.
        model: Grounding DINO model used for text-prompted detection.
        predictor: SAM predictor; ``set_image`` is called on it once per image.
        device: Device string forwarded to ``grounding_dino_detect``.
    """
    for n, image_uri in enumerate(images):
        rr.set_time_sequence("image", n)
        image = load_image(image_uri)
        rr.log_image("image", image)
        detections, phrases, id_from_phrase = grounding_dino_detect(
            model, device, image, prompt
        )
        # SAM needs the current image set before it can segment the detections.
        predictor.set_image(image)
        run_segmentation(predictor, image, detections, phrases, id_from_phrase)
# Gradio request handler and UI definition
@functools.lru_cache(maxsize=None)
def _load_models(device: str):
    """Build the Grounding DINO model and the "vit_b" SAM model once per device.

    The result is cached so repeated requests reuse the same model objects
    instead of rebuilding them each time. Callers must treat the returned
    models as shared and must not mutate them.

    Returns:
        ``(grounding_model, sam)``
    """
    grounded_checkpoint = get_downloaded_model_path("grounding")
    model = load_grounding_model(CONFIG_PATH, grounded_checkpoint, device=device)
    sam = create_sam("vit_b", device)
    return model, sam


def predict(prompt: str, image_path: str):
    """Segment the objects named in ``prompt`` in one image and embed the rerun result.

    The rerun recording is rendered to a uniquely named HTML file in
    ``static_dir``. The returned ``<iframe>`` points at it through the
    ``/static`` mount.

    Args:
        prompt: Comma-separated object names, e.g. ``"fan, mirror, sofa"``.
        image_path: File path of the uploaded image (Gradio ``type="filepath"``).

    Returns:
        HTML ``<iframe>`` snippet for the ``gr.HTML`` output component.

    Raises:
        gr.Error: If no image was provided.
    """
    if not image_path:
        # Gradio passes None when the user clicks Segment without uploading an image.
        raise gr.Error("Please upload an image before clicking Segment.")

    # Random name: a seconds-resolution timestamp let concurrent requests overwrite
    # each other's file, and strftime('%s') is a non-portable platform extension.
    file_name = f"{uuid.uuid4().hex}.html"
    file_path = static_dir / file_name
    # NOTE(review): rerun's recording is process-global; concurrent requests may
    # interleave their logs in one recording - confirm Gradio queues this handler.
    rec = rr.memory_recording()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, sam = _load_models(device)
    # set_image() stores per-image state on the predictor, so build a fresh
    # (cheap) predictor per request rather than sharing one across calls.
    predictor = SamPredictor(sam)

    # For video input, log_video_segmentation(Path(...), prompt, model, predictor,
    # device=device) can be used instead.
    log_images_segmentation(
        [image_path],
        prompt,
        model,
        predictor,
        device=device,
    )

    file_path.write_text(rec.as_html(), encoding="utf-8")
    return f"""<iframe src="/static/{file_name}" width="950" height="712"></iframe>"""
with gr.Blocks() as block:
    # Inputs: prompt and trigger button on the left, the image on the right.
    with gr.Row():
        with gr.Column():
            text_input = gr.Textbox(label="Prompt", value="tires, wheels")
            new_btn = gr.Button(value="Segment")
        with gr.Column():
            image = gr.Image(type="filepath", label="Image")

    # Output: predict() returns an <iframe> snippet that is rendered here.
    with gr.Row():
        html = gr.HTML(show_label=True, label="HTML preview")

    # Bundled example, cached so its output is precomputed by Gradio.
    with gr.Row():
        gr.Examples(
            examples=[["fan, mirror, sofa", "living_room.jpeg"]],
            inputs=[text_input, image],
            outputs=[html],
            fn=predict,
            cache_examples=True,
        )

    new_btn.click(predict, [text_input, image], [html])
# Serve the Gradio UI at the root of the same FastAPI app that serves /static.
app = gr.mount_gradio_app(app, block, path="/")


def main():
    """Serve the combined FastAPI + Gradio app with uvicorn on 0.0.0.0:7860."""
    uvicorn.run(app, host="0.0.0.0", port=7860)


if __name__ == "__main__":
    main()