# app.py - Talk2DINO / Talk2DINOv3 Gradio demo
import gradio as gr
import torch
import numpy as np
from transformers import AutoModel
import os
import torchvision.transforms.functional as F
from src.plot import plot_qualitative
from PIL import Image
from io import BytesIO
import base64
from pathlib import Path
import spaces  # 👈 REQUIRED for ZeroGPU
# --- Setup ---
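# Gradio writes uploaded and generated temporary files under GRADIO_TEMP_DIR;
# pointing it to a local "tmp" folder keeps them inside the Space's working directory.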
os.environ["GRADIO_TEMP_DIR"] = "tmp"
os.makedirs(os.environ["GRADIO_TEMP_DIR"], exist_ok=True)
device = "cuda" if torch.cuda.is_available() else "cpu"
# --- Load Models (on CPU first; ZeroGPU will move to CUDA dynamically) ---
print("πŸ” Loading models... (this may take a minute on first launch)")
model_B = AutoModel.from_pretrained("lorebianchi98/Talk2DINO-ViTB", trust_remote_code=True).to("cpu").eval()
model_L = AutoModel.from_pretrained("lorebianchi98/Talk2DINO-ViTL", trust_remote_code=True).to("cpu").eval()
model_B_v3 = AutoModel.from_pretrained("lorebianchi98/Talk2DINOv3-ViTB", trust_remote_code=True).to("cpu").eval()
model_L_v3 = AutoModel.from_pretrained("lorebianchi98/Talk2DINOv3-ViTL", trust_remote_code=True).to("cpu").eval()
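# Dropdown label -> loaded model; the keys below are shown verbatim in the UI selector.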
MODELS = {
    "Talk2DINO-ViT-B": model_B,
    "Talk2DINO-ViT-L": model_L,
    "Talk2DINOv3-ViT-B": model_B_v3,
    "Talk2DINOv3-ViT-L": model_L_v3,
}
# --- Example Setup ---
EXAMPLE_IMAGES_DIR = Path("examples").resolve()
example_images = sorted([str(p) for p in EXAMPLE_IMAGES_DIR.glob("*.png")])
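# Class prompts pre-filled when one of the bundled example images is selected
# from the gallery (filename -> comma-separated class list).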
DEFAULT_CLASSES = {
    "0_pikachu.png": "pikachu,traffic_sign,forest,road,cap",
    "1_jurassic.png": "dinosaur,smoke,vegetation,person",
    "2_falcon.png": "millenium_falcon,space",
}
DEFAULT_BG_THRESH = 0.55
DEFAULT_BG_CLEAN = False
# --- Inference Function ---
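# On a ZeroGPU Space, @spaces.GPU attaches a GPU only for the duration of the call;
# `duration` is the expected upper bound in seconds for one inference.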
@spaces.GPU(duration=120)
def talk2dino_infer(input_image, class_text, selected_model="Talk2DINO-ViT-B",
                    apply_pamr=True, with_background=False, bg_thresh=0.55, apply_bg_clean=False):
    if input_image is None:
        raise gr.Error("No image detected. Please select or upload an image first.")
    if selected_model not in MODELS:
        raise gr.Error("Please select a valid model before running inference.")

    model = MODELS[selected_model].to("cuda")

    text = [t.strip() for t in class_text.replace("_", " ").split(",") if t.strip()]
    if len(text) == 0:
        raise gr.Error("Please provide at least one class name before generating segmentation.")

    img = F.to_tensor(input_image).unsqueeze(0).float().to("cuda") * 255.0
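    # Note: underscores in the class list were replaced with spaces above, so gallery
    # defaults such as "traffic_sign" become the prompt "traffic sign"; the image is
    # kept as a float tensor in the 0-255 range.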
    # Generate color palette
    palette = [
        [255, 0, 0],
        [255, 255, 0],
        [0, 255, 0],
        [0, 255, 255],
        [0, 0, 255],
        [128, 128, 128],
    ]
    if len(text) > len(palette):
        for _ in range(len(text) - len(palette)):
            palette.append([np.random.randint(0, 255) for _ in range(3)])
    if with_background:
        palette.insert(0, [0, 0, 0])
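    # Random colors cover any classes beyond the fixed palette; black (index 0) is
    # reserved for the extra "background" class added below. The with_bg_clean
    # attribute set next toggles the model's optional background-cleaning step.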
    model.with_bg_clean = apply_bg_clean

    with torch.no_grad():
        text_emb = model.build_dataset_class_tokens("sub_imagenet_template", text)
        text_emb = model.build_text_embedding(text_emb)
        mask, _ = model.generate_masks(img, img_metas=None, text_emb=text_emb,
                                       classnames=text, apply_pamr=apply_pamr)
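        # A constant bg_thresh map is prepended as a pseudo "background" score:
        # pixels whose best class similarity stays below the threshold end up
        # assigned to background by the argmax below.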
        if with_background:
            background = torch.ones_like(mask[:, :1]) * bg_thresh
            mask = torch.cat([background, mask], dim=1)
        mask = mask.argmax(dim=1)

    if with_background:
        text = ["background"] + text
    img_out = plot_qualitative(
        img.cpu()[0].permute(1, 2, 0).int().numpy(),
        mask.cpu()[0].numpy(),
        palette,
        texts=text,
    )

    torch.cuda.empty_cache()
    return img_out
# --- Gradio Interface ---
with gr.Blocks(title="Talk2DINO / Talk2DINOv3 Demo") as demo:
    # Overview Section
    overview_img = Image.open("assets/overview.png").convert("RGB")
    overview_img = overview_img.resize((int(overview_img.width * 0.7), int(overview_img.height * 0.7)))
    buffered = BytesIO()
    overview_img.save(buffered, format="PNG")
    img_str = base64.b64encode(buffered.getvalue()).decode()
    gr.Markdown(f"""
# 🦖 Talk2DINO & Talk2DINOv3 Demo

![Overview](data:image/png;base64,{img_str})

<div style="font-size: x-large; white-space: nowrap; display: flex; align-items: center; gap: 10px;">
<a href="https://lorebianchi98.github.io/Talk2DINO/" target="_blank">Project page</a>
<span>|</span>
<a href="http://arxiv.org/abs/2411.19331" target="_blank">
<img src="https://img.shields.io/badge/arXiv-2411.19331-b31b1b.svg" style="height:28px; vertical-align:middle;">
</a>
<span>|</span>
<a href="https://huggingface.co/papers/2411.19331" target="_blank">
<img src="https://img.shields.io/badge/HuggingFace-Paper-yellow.svg" style="height:28px; vertical-align:middle;">
</a>
</div>

---

Perform **open-vocabulary semantic segmentation** using Talk2DINO or Talk2DINOv3.
Supports both **ViT-B** and **ViT-L** model sizes.

**Steps:**
1. Upload or select an example image.
2. Enter class names (e.g., `pikachu, forest, road`).
3. Choose a model variant (Talk2DINO or v3).
4. Adjust options and click **Generate Segmentation**.
    """)
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Input Image", value=None)

            if example_images:
                example_gallery = gr.Gallery(
                    value=example_images,
                    label="Or select from example images",
                    show_label=True,
                    columns=3,
                    object_fit="contain",
                    height="auto",
                )
        with gr.Column():
            model_selector = gr.Dropdown(
                label="Select Model",
                choices=list(MODELS.keys()),
                value="Talk2DINO-ViT-B",
            )
            class_text = gr.Textbox(
                label="Comma-separated Classes",
                value="",
                placeholder="e.g. pikachu, road, tree",
            )
            apply_pamr = gr.Checkbox(label="Apply PAMR", value=True)
            with_background = gr.Checkbox(label="Include Background", value=False)
            bg_thresh = gr.Slider(
                label="Background Threshold",
                minimum=0.0,
                maximum=1.0,
                value=DEFAULT_BG_THRESH,
                step=0.01,
                interactive=False,
            )
            apply_bg_clean = gr.Checkbox(
                label="Apply Background Cleaning",
                value=False,
                interactive=False,
            )
            generate_button = gr.Button("🚀 Generate Segmentation", interactive=False)
    output_image = gr.Image(type="numpy", label="Segmentation Overlay")
    # --- Background Option Toggle ---
    def toggle_bg_options(with_bg):
        if with_bg:
            return gr.update(interactive=True, value=DEFAULT_BG_THRESH), gr.update(interactive=True, value=DEFAULT_BG_CLEAN)
        else:
            return gr.update(interactive=False, value=DEFAULT_BG_THRESH), gr.update(interactive=False, value=DEFAULT_BG_CLEAN)
    with_background.change(
        fn=toggle_bg_options,
        inputs=[with_background],
        outputs=[bg_thresh, apply_bg_clean],
    )
    # --- Enable Button Only When Classes Exist ---
    def enable_generate_button(text):
        return gr.update(interactive=bool(text.strip()))

    class_text.change(fn=enable_generate_button, inputs=[class_text], outputs=[generate_button])
    # --- Example Image Loader ---
    def load_example_image(evt: gr.SelectData):
        selected = evt.value["image"]
        if isinstance(selected, str):
            img = Image.open(selected).convert("RGB")
            filename = Path(selected).name
        elif isinstance(selected, dict):
            img = Image.open(selected["path"]).convert("RGB")
            filename = Path(selected["path"]).name
        else:
            img = Image.fromarray(selected)
            filename = None
        class_val = DEFAULT_CLASSES.get(filename, "")
        return img, class_val, gr.update(interactive=bool(class_val.strip()))
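    # Depending on the Gradio version, the selected gallery item above may be a plain
    # file path or a dict with a "path" key; the filename keys into DEFAULT_CLASSES
    # to pre-fill the class textbox and enable the generate button.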
    if example_images:
        example_gallery.select(
            fn=load_example_image,
            inputs=[],
            outputs=[input_image, class_text, generate_button],
        )
    # --- User Upload Reset ---
    def on_upload_image(img):
        if img is None:
            return None, "", gr.update(interactive=False)
        return img, "", gr.update(interactive=False)

    input_image.upload(
        fn=on_upload_image,
        inputs=[input_image],
        outputs=[input_image, class_text, generate_button],
    )
    # --- Generate Segmentation ---
    generate_button.click(
        talk2dino_infer,
        inputs=[input_image, class_text, model_selector, apply_pamr, with_background, bg_thresh, apply_bg_clean],
        outputs=output_image,
    )
demo.launch()