from pathlib import Path
from datetime import datetime, timezone
from typing import Union, List
import functools
import uuid

from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
import uvicorn
import gradio as gr
import cv2
import torch

from main import grounding_dino_detect
from models import (
    run_segmentation,
    resize_img,
    load_image,
    get_downloaded_model_path,
    load_grounding_model,
    create_sam,
    CONFIG_PATH,
)
from segment_anything import SamPredictor
from segment_anything.modeling import Sam
from groundingdino.models import GroundingDINO
import rerun as rr

rr.init("GroundingSAM")

# create a FastAPI app
app = FastAPI()

# create a static directory to store the static files (rerun HTML recordings)
static_dir = Path("./static")
static_dir.mkdir(parents=True, exist_ok=True)

# mount FastAPI StaticFiles server
app.mount("/static", StaticFiles(directory=static_dir), name="static")


def log_video_segmentation(
    video_path: Path,
    prompt: str,
    model: GroundingDINO,
    predictor: SamPredictor,
    device: str = "cpu",
    max_frames: int = 21,
):
    """Run GroundingDINO + SAM on the first frames of a video and log to rerun.

    Each frame is converted from OpenCV's BGR order to RGB, passed through
    ``resize_img(rgb, 512)``, logged as the ``"image"`` entity on the
    ``"frame"`` timeline, and then detected and segmented.

    Args:
        video_path: Path of the video file to read.
        prompt: Text prompt forwarded to ``grounding_dino_detect``.
        model: Loaded GroundingDINO model.
        predictor: SAM predictor; ``set_image`` is called once per frame.
        device: Device string forwarded to ``grounding_dino_detect``.
        max_frames: Maximum number of frames to process. The default of 21
            matches the previously hard-coded limit.

    Raises:
        FileNotFoundError: If ``video_path`` does not exist.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not video_path.exists():
        raise FileNotFoundError(video_path)

    cap = cv2.VideoCapture(str(video_path))
    try:
        idx = 0
        while cap.isOpened() and idx < max_frames:
            ret, bgr = cap.read()
            if not ret:
                break

            rr.set_time_sequence("frame", idx)

            rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
            # NOTE(review): exact meaning of 512 depends on resize_img
            # (presumably a target size in pixels) — confirm in models.py.
            rgb = resize_img(rgb, 512)
            rr.log_image("image", rgb)

            detections, phrases, id_from_phrase = grounding_dino_detect(
                model, device, rgb, prompt
            )

            predictor.set_image(rgb)
            run_segmentation(predictor, rgb, detections, phrases, id_from_phrase)

            idx += 1
    finally:
        # Always free the capture handle, even if detection/segmentation raises.
        cap.release()


def log_images_segmentation(
    images: list[Union[str, Path]],
    prompt: str,
    model: GroundingDINO,
    predictor: SamPredictor,
    device: str = "cpu",
):
    """Run GroundingDINO + SAM on each image and log the results to rerun.

    Image ``n`` is loaded via ``load_image``, logged as the ``"image"``
    entity at step ``n`` of the ``"image"`` timeline, then detected and
    segmented.

    Args:
        images: Image locations accepted by ``load_image``.
        prompt: Text prompt forwarded to ``grounding_dino_detect``.
        model: Loaded GroundingDINO model.
        predictor: SAM predictor; ``set_image`` is called once per image.
        device: Device string forwarded to ``grounding_dino_detect``.
    """
    for n, image_uri in enumerate(images):
        rr.set_time_sequence("image", n)
        image = load_image(image_uri)
        rr.log_image("image", image)
        detections, phrases, id_from_phrase = grounding_dino_detect(
            model, device, image, prompt
        )

        predictor.set_image(image)
        run_segmentation(predictor, image, detections, phrases, id_from_phrase)


@functools.cache
def _load_models(device: str) -> tuple[GroundingDINO, Sam]:
    """Load GroundingDINO and SAM (vit_b) once per device and reuse them."""
    grounded_checkpoint = get_downloaded_model_path("grounding")
    model = load_grounding_model(CONFIG_PATH, grounded_checkpoint, device=device)
    sam = create_sam("vit_b", device)
    return model, sam


# Gradio stuff
def predict(prompt: str, image_path: str) -> str:
    """Segment the objects named in ``prompt`` and return an HTML viewer.

    The rerun recording of the detection/segmentation run is saved as an
    HTML file under ``static_dir``. The function returns an ``<iframe>``
    pointing at that file through the ``/static`` mount.

    Args:
        prompt: Comma-separated object names, e.g. ``"tires, wheels"``.
        image_path: Filesystem path of the uploaded image (Gradio
            ``type="filepath"``).

    Returns:
        HTML snippet embedding the saved recording.

    Raises:
        gr.Error: If no image was provided.
    """
    if not image_path:
        raise gr.Error("Please upload an image before running segmentation.")

    # Correct UTC epoch seconds. strftime('%s') is glibc-only and treats a
    # naive utcnow() as local time. The random suffix keeps two requests in
    # the same second from overwriting each other's recording.
    timestamp = int(datetime.now(timezone.utc).timestamp())
    file_name = f"{timestamp}_{uuid.uuid4().hex[:8]}.html"
    file_path = static_dir / file_name

    # NOTE(review): rr.memory_recording() switches rerun's global recording;
    # concurrent requests may interleave their logs — confirm if that matters.
    rec = rr.memory_recording()

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Heavy models are cached across requests. The SamPredictor is per-request
    # so concurrent calls don't share whatever state set_image() leaves on it.
    model, sam = _load_models(device)
    predictor = SamPredictor(sam)

    # For video input, see log_video_segmentation.
    log_images_segmentation(
        [image_path],
        prompt,
        model,
        predictor,
        device=device,
    )

    file_path.write_text(rec.as_html(), encoding="utf-8")

    return (
        f'<iframe src="/static/{file_name}" width="100%" height="600" '
        'frameborder="0"></iframe>'
    )


with gr.Blocks() as block:
    with gr.Row():
        with gr.Column():
            text_input = gr.Textbox(value="tires, wheels", label="Prompt")
            new_btn = gr.Button("Segment")
        with gr.Column():
            image = gr.Image(label="Image", type="filepath")
    with gr.Row():
        html = gr.HTML(label="HTML preview", show_label=True)
    with gr.Row():
        gr.Examples(
            [
                ["fan, mirror, sofa", "living_room.jpeg"],
            ],
            fn=predict,
            inputs=[text_input, image],
            outputs=[html],
            cache_examples=True,
        )

    new_btn.click(fn=predict, inputs=[text_input, image], outputs=[html])

# mount Gradio app to FastAPI app
app = gr.mount_gradio_app(app, block, path="/")

# serve the app
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)