# NOTE: page-capture residue from the hosting Hugging Face Space, kept as a comment.
# Space status at capture time: Runtime error
from pathlib import Path | |
from fastapi import FastAPI | |
from fastapi.staticfiles import StaticFiles | |
import uvicorn | |
import gradio as gr | |
from datetime import datetime | |
from typing import Union, Literal | |
from functools import partial | |
import cv2 | |
import torch | |
from main import grounding_dino_detect | |
from models import ( | |
run_segmentation, | |
resize_img, | |
load_image, | |
get_downloaded_model_path, | |
load_grounding_model, | |
create_sam, | |
CONFIG_PATH, | |
) | |
from segment_anything import SamPredictor | |
from segment_anything.modeling import Sam | |
from groundingdino.models import GroundingDINO | |
import rerun as rr | |
# Initialise the rerun SDK for this process under the application id "GroundingSAM".
rr.init("GroundingSAM")

# create a FastAPI app
app = FastAPI()

# create a static directory to store the static files
# (predict() writes one HTML file per request into this directory)
static_dir = Path("./static")
static_dir.mkdir(parents=True, exist_ok=True)

# mount FastAPI StaticFiles server
# The directory is created above before it is mounted here, so files written
# by predict() can be served under the /static URL prefix.
app.mount("/static", StaticFiles(directory=static_dir), name="static")
def log_video_segmentation(
    video_path: Path,
    prompt: str,
    model: GroundingDINO,
    predictor: SamPredictor,
    device: str = "cpu",
    max_frames: int = 21,
):
    """Detect and segment prompted objects on the leading frames of a video.

    Each processed frame is converted from BGR to RGB, passed through
    ``resize_img(rgb, 512)``, logged to rerun under the ``"image"`` entity on
    the ``"frame"`` timeline, run through ``grounding_dino_detect`` and then
    handed to ``run_segmentation`` together with the detections.

    Args:
        video_path: Path of the video file to read.
        prompt: Text prompt forwarded to ``grounding_dino_detect``.
        model: Grounding DINO model forwarded to ``grounding_dino_detect``.
        predictor: SAM predictor; ``set_image`` is called on it for every
            processed frame before segmentation.
        device: Device string forwarded to ``grounding_dino_detect``.
        max_frames: Upper bound on the number of frames processed. The
            default of 21 matches the previously hard-coded limit
            (frame indices 0 through 20).

    Raises:
        FileNotFoundError: If ``video_path`` does not exist.
    """
    # Explicit check instead of ``assert``: assertions are stripped under -O.
    if not video_path.exists():
        raise FileNotFoundError(f"Video file not found: {video_path}")

    cap = cv2.VideoCapture(str(video_path))
    try:
        idx = 0
        while idx < max_frames and cap.isOpened():
            ret, bgr = cap.read()
            if not ret:
                # End of stream or a decode failure: stop quietly.
                break

            rr.set_time_sequence("frame", idx)
            rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
            rgb = resize_img(rgb, 512)
            rr.log_image("image", rgb)

            detections, phrases, id_from_phrase = grounding_dino_detect(
                model, device, rgb, prompt
            )
            predictor.set_image(rgb)
            run_segmentation(predictor, rgb, detections, phrases, id_from_phrase)
            idx += 1
    finally:
        # Always free the capture handle, even if detection/segmentation raises.
        cap.release()
def log_images_segmentation(
    images: list[Union[str, Path]],
    prompt: str,
    model: GroundingDINO,
    predictor: SamPredictor,
    device: str = "cpu",
):
    """Detect and segment prompted objects on each image in ``images``.

    For every entry, in order: the rerun ``"image"`` timeline is set to the
    entry's index, the image is loaded with ``load_image`` and logged under
    the ``"image"`` entity, ``grounding_dino_detect`` is run on it, and the
    detections are handed to ``run_segmentation``.

    Args:
        images: Image locations accepted by ``load_image``.
        prompt: Text prompt forwarded to ``grounding_dino_detect``.
        model: Grounding DINO model forwarded to ``grounding_dino_detect``.
        predictor: SAM predictor; ``set_image`` is called on it for every
            image before segmentation.
        device: Device string forwarded to ``grounding_dino_detect``.
    """
    for n, image_uri in enumerate(images):
        rr.set_time_sequence("image", n)
        image = load_image(image_uri)
        rr.log_image("image", image)

        detections, phrases, id_from_phrase = grounding_dino_detect(
            model, device, image, prompt
        )
        predictor.set_image(image)
        run_segmentation(predictor, image, detections, phrases, id_from_phrase)
# Gradio: prediction callback and UI builders
def predict(prompt: str, image_path: str, type: Literal["image", "video"]) -> str:
    """Run detection + segmentation on one upload and return an iframe to the result.

    The rerun data logged while processing is captured in a memory recording,
    written as an HTML file into ``static_dir``, and referenced from the
    returned ``<iframe>`` via the ``/static`` route.

    Args:
        prompt: Text prompt describing the objects to find.
        image_path: Filesystem path of the uploaded image or video.
        type: ``"video"`` routes to ``log_video_segmentation``; any other
            value routes to ``log_images_segmentation``. (The name shadows
            the builtin but is kept because callers bind it by keyword.)

    Returns:
        An HTML ``<iframe>`` snippet pointing at the generated file.

    Raises:
        ValueError: If ``image_path`` is empty or ``None``.
    """
    # Local import keeps this block self-contained without touching the
    # file-level import list.
    import uuid

    # Fail before the expensive model load when nothing was uploaded.
    if not image_path:
        raise ValueError("No input file provided; upload an image or video first.")

    # A random UUID replaces the previous strftime('%s') name: '%s' is a
    # non-portable platform extension, and a one-second timestamp lets
    # concurrent requests overwrite each other's output file.
    file_name = f"{uuid.uuid4().hex}.html"
    file_path = static_dir / file_name

    rec = rr.memory_recording()
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # load model
    # NOTE(review): both models are rebuilt on every call; consider caching
    # them per device if request latency matters.
    grounded_checkpoint = get_downloaded_model_path("grounding")
    model = load_grounding_model(CONFIG_PATH, grounded_checkpoint, device=device)
    sam = create_sam("vit_b", device)
    predictor = SamPredictor(sam)

    if type == "video":
        log_video_segmentation(
            Path(image_path),
            prompt,
            model,
            predictor,
            device=device,
        )
    else:
        log_images_segmentation(
            [image_path],
            prompt,
            model,
            predictor,
            device=device,
        )

    # Explicit encoding so the HTML is written identically on every platform.
    with open(file_path, "w", encoding="utf-8") as f:
        f.write(rec.as_html())

    iframe = f"""<iframe src="/static/{file_name}" width="950" height="712"></iframe>"""
    return iframe
def pred_img_demo():
    """Lay out the image tab: prompt + button, image upload, HTML viewer, examples.

    Must be called inside an active ``gr.Blocks`` context. Both the example
    row and the "Segment" button invoke ``predict`` with ``type="image"``.
    """
    segment_image = partial(predict, type="image")
    example_rows = [
        ["fan, mirror, sofa", "living_room.jpeg"],
    ]

    # Top row: prompt and trigger on the left, image upload on the right.
    with gr.Row():
        with gr.Column():
            prompt_box = gr.Textbox(value="tires, wheels", label="Prompt")
            segment_btn = gr.Button("Segment")
        with gr.Column():
            image_input = gr.Image(label="Image", type="filepath")

    # Middle row: pane that receives the iframe returned by predict().
    with gr.Row():
        viewer = gr.HTML(label="HTML preview", show_label=True)

    # Bottom row: clickable example wired to the same callback.
    with gr.Row():
        gr.Examples(
            example_rows,
            fn=segment_image,
            inputs=[prompt_box, image_input],
            outputs=[viewer],
            cache_examples=True,
        )

    segment_btn.click(
        fn=segment_image,
        inputs=[prompt_box, image_input],
        outputs=[viewer],
    )
def pred_vid_demo():
    """Lay out the video tab: prompt + button, video upload, HTML viewer, examples.

    Must be called inside an active ``gr.Blocks`` context. Both the example
    row and the "Segment" button invoke ``predict`` with ``type="video"``.
    """
    pred_fn = partial(predict, type="video")

    with gr.Row():
        with gr.Column():
            text_input = gr.Textbox(value="tires, wheels", label="Prompt")
            new_btn = gr.Button("Segment")
        with gr.Column():
            # ``type="filepath"`` was dropped: it is a gr.Image option, not a
            # gr.Video one. gr.Video already hands the callback a file path,
            # and Gradio versions whose components do not accept arbitrary
            # keyword arguments reject the unknown keyword.
            video = gr.Video(label="Video")

    with gr.Row():
        html = gr.HTML(label="HTML preview", show_label=True)

    with gr.Row():
        gr.Examples(
            [
                ["dog, woman", "dog_and_woman.mp4"],
            ],
            fn=pred_fn,
            inputs=[text_input, video],
            outputs=[html],
            cache_examples=True,
        )

    new_btn.click(fn=pred_fn, inputs=[text_input, video], outputs=[html])
# Assemble the UI: one Blocks container with an "Image" tab and a "Video" tab,
# each populated by its builder function.
with gr.Blocks() as block:
    with gr.Tab(label="Image"):
        pred_img_demo()
    with gr.Tab(label="Video"):
        pred_vid_demo()

# mount Gradio app to FastAPI app
# (rebinds the module-level ``app`` to the value returned by mount_gradio_app)
app = gr.mount_gradio_app(app, block, path="/")

# serve the app
# Only when executed as a script; listens on all interfaces, port 7860.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)