"""FastAPI + Gradio front end for the GroundingSAM demo.

A request runs ``grounding_dino_detect`` and ``run_segmentation`` (both defined
elsewhere in the project) on one image or the first frames of a video, dumps
the rerun memory recording to an HTML file under ``./static`` and returns an
iframe that points at that file.
"""

from datetime import datetime, timezone
from functools import lru_cache, partial
from pathlib import Path
from typing import Literal, Union
from uuid import uuid4

import cv2
import gradio as gr
import rerun as rr
import torch
import uvicorn
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from groundingdino.models import GroundingDINO
from segment_anything import SamPredictor
from segment_anything.modeling import Sam

from main import grounding_dino_detect
from models import (
    CONFIG_PATH,
    create_sam,
    get_downloaded_model_path,
    load_grounding_model,
    load_image,
    resize_img,
    run_segmentation,
)

rr.init("GroundingSAM")

# create a FastAPI app
app = FastAPI()

# create a static directory to store the generated recording pages
static_dir = Path("./static")
static_dir.mkdir(parents=True, exist_ok=True)

# mount FastAPI StaticFiles server
app.mount("/static", StaticFiles(directory=static_dir), name="static")


def log_video_segmentation(
    video_path: Path,
    prompt: str,
    model: GroundingDINO,
    predictor: SamPredictor,
    device: str = "cpu",
    max_frames: int = 21,
):
    """Run detection + segmentation on the leading frames of a video.

    Each processed frame is converted from BGR to RGB, passed through
    ``resize_img(rgb, 512)``, logged to rerun as ``"image"`` on the
    ``"frame"`` sequence timeline, and then handed to
    ``grounding_dino_detect`` and ``run_segmentation``.

    Args:
        video_path: Location of the video file.
        prompt: Text prompt forwarded to ``grounding_dino_detect``.
        model: Grounding model forwarded to ``grounding_dino_detect``.
        predictor: SAM predictor; ``set_image`` is called once per frame.
        device: Device string forwarded to ``grounding_dino_detect``.
        max_frames: Upper bound on the number of frames processed. The
            default of 21 matches the previously hard-coded cutoff
            (frame indices 0 through 20).

    Raises:
        FileNotFoundError: If ``video_path`` does not exist.
    """
    # Explicit check instead of ``assert`` so it survives ``python -O``.
    if not video_path.exists():
        raise FileNotFoundError(f"Video not found: {video_path}")

    cap = cv2.VideoCapture(str(video_path))
    try:
        idx = 0
        while cap.isOpened() and idx < max_frames:
            ret, bgr = cap.read()
            if not ret:
                # End of stream or decode failure.
                break

            rr.set_time_sequence("frame", idx)
            rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
            rgb = resize_img(rgb, 512)
            rr.log_image("image", rgb)

            detections, phrases, id_from_phrase = grounding_dino_detect(
                model, device, rgb, prompt
            )

            predictor.set_image(rgb)
            run_segmentation(predictor, rgb, detections, phrases, id_from_phrase)

            idx += 1
    finally:
        # Always free the decoder handle, even if a model call raises.
        cap.release()


def log_images_segmentation(
    images: list[Union[str, Path]],
    prompt: str,
    model: GroundingDINO,
    predictor: SamPredictor,
    device: str = "cpu",
):
    """Run detection + segmentation on each image in ``images``.

    Every image is loaded with ``load_image``, logged to rerun as ``"image"``
    on the ``"image"`` sequence timeline (indexed by its position in the
    list), and then handed to ``grounding_dino_detect`` and
    ``run_segmentation``.

    Args:
        images: Image locations accepted by ``load_image``.
        prompt: Text prompt forwarded to ``grounding_dino_detect``.
        model: Grounding model forwarded to ``grounding_dino_detect``.
        predictor: SAM predictor; ``set_image`` is called once per image.
        device: Device string forwarded to ``grounding_dino_detect``.
    """
    for n, image_uri in enumerate(images):
        rr.set_time_sequence("image", n)
        image = load_image(image_uri)
        rr.log_image("image", image)

        detections, phrases, id_from_phrase = grounding_dino_detect(
            model, device, image, prompt
        )

        predictor.set_image(image)
        run_segmentation(predictor, image, detections, phrases, id_from_phrase)


@lru_cache(maxsize=1)
def _load_models(device: str) -> tuple[GroundingDINO, Sam]:
    """Load the grounding model and the ``vit_b`` SAM model once per device.

    Cached so that repeated requests (and the example pre-computation done by
    ``gr.Examples(cache_examples=True)``) do not reload the checkpoints.
    """
    grounded_checkpoint = get_downloaded_model_path("grounding")
    model = load_grounding_model(CONFIG_PATH, grounded_checkpoint, device=device)
    sam = create_sam("vit_b", device)
    return model, sam


# Gradio stuff
def predict(prompt: str, image_path: str, type: Literal["image", "video"]):
    """Segment an image or video and return an iframe showing the recording.

    Args:
        prompt: Text prompt forwarded to the detection step.
        image_path: Path of the uploaded image or video file.
        type: ``"video"`` routes to ``log_video_segmentation``; any other
            value routes to ``log_images_segmentation``. (The name shadows
            the builtin but is kept because callers bind it by keyword.)

    Returns:
        An HTML ``<iframe>`` snippet whose ``src`` is the recording page
        written under the ``/static`` mount.

    Raises:
        gr.Error: If no file was supplied.
    """
    if not image_path:
        # Gradio passes None when the user has not uploaded anything.
        raise gr.Error("Please provide an input file before segmenting.")

    # Timestamp for readability plus a uuid so concurrent requests never
    # collide; avoids the platform-specific ``%s`` strftime directive.
    file_name = f"{datetime.now(timezone.utc):%Y%m%dT%H%M%S}_{uuid4().hex}.html"
    file_path = static_dir / file_name

    rec = rr.memory_recording()
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # load model (cached across calls); the predictor is created per call
    # because set_image is invoked on it for every frame/image
    model, sam = _load_models(device)
    predictor = SamPredictor(sam)

    if type == "video":
        log_video_segmentation(
            Path(image_path),
            prompt,
            model,
            predictor,
            device=device,
        )
    else:
        log_images_segmentation(
            [image_path],
            prompt,
            model,
            predictor,
            device=device,
        )

    file_path.write_text(rec.as_html(), encoding="utf-8")

    # The page is served by the "/static" mount; embed it so gr.HTML shows it.
    return f'<iframe src="/static/{file_name}" width="950" height="712"></iframe>'


def pred_img_demo():
    """Build the image tab: prompt box, image upload, HTML output, examples."""
    pred_fn = partial(predict, type="image")
    with gr.Row():
        with gr.Column():
            text_input = gr.Textbox(value="tires, wheels", label="Prompt")
            new_btn = gr.Button("Segment")
        with gr.Column():
            image = gr.Image(label="Image", type="filepath")
    with gr.Row():
        html = gr.HTML(label="HTML preview", show_label=True)
    with gr.Row():
        gr.Examples(
            [
                ["fan, mirror, sofa", "living_room.jpeg"],
            ],
            fn=pred_fn,
            inputs=[text_input, image],
            outputs=[html],
            cache_examples=True,
        )

    new_btn.click(fn=pred_fn, inputs=[text_input, image], outputs=[html])


def pred_vid_demo():
    """Build the video tab: prompt box, video upload, HTML output, examples."""
    pred_fn = partial(predict, type="video")
    with gr.Row():
        with gr.Column():
            text_input = gr.Textbox(value="tires, wheels", label="Prompt")
            new_btn = gr.Button("Segment")
        with gr.Column():
            video = gr.Video(label="Video", type="filepath")
    with gr.Row():
        html = gr.HTML(label="HTML preview", show_label=True)
    with gr.Row():
        gr.Examples(
            [
                ["dog, woman", "dog_and_woman.mp4"],
            ],
            fn=pred_fn,
            inputs=[text_input, video],
            outputs=[html],
            cache_examples=True,
        )

    new_btn.click(fn=pred_fn, inputs=[text_input, video], outputs=[html])


with gr.Blocks() as block:
    with gr.Tab(label="Image"):
        pred_img_demo()
    with gr.Tab(label="Video"):
        pred_vid_demo()

# mount Gradio app to FastAPI app
app = gr.mount_gradio_app(app, block, path="/")

# serve the app
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)