import gradio as gr
import torch
import numpy as np
from transformers import AutoModel
import os
import torchvision.transforms.functional as F
from src.plot import plot_qualitative
from PIL import Image
from io import BytesIO
import base64
from pathlib import Path
import spaces  # 👈 REQUIRED for ZeroGPU

# --- Setup ---
os.environ["GRADIO_TEMP_DIR"] = "tmp"
os.makedirs(os.environ["GRADIO_TEMP_DIR"], exist_ok=True)

device = "cuda" if torch.cuda.is_available() else "cpu"

# --- Load Models (on CPU first; ZeroGPU will move to CUDA dynamically) ---
print("🔍 Loading models... (this may take a minute on first launch)")
model_B = AutoModel.from_pretrained("lorebianchi98/Talk2DINO-ViTB", trust_remote_code=True).to("cpu").eval()
model_L = AutoModel.from_pretrained("lorebianchi98/Talk2DINO-ViTL", trust_remote_code=True).to("cpu").eval()
model_B_v3 = AutoModel.from_pretrained("lorebianchi98/Talk2DINOv3-ViTB", trust_remote_code=True).to("cpu").eval()
model_L_v3 = AutoModel.from_pretrained("lorebianchi98/Talk2DINOv3-ViTL", trust_remote_code=True).to("cpu").eval()

MODELS = {
    "Talk2DINO-ViT-B": model_B,
    "Talk2DINO-ViT-L": model_L,
    "Talk2DINOv3-ViT-B": model_B_v3,
    "Talk2DINOv3-ViT-L": model_L_v3,
}

# --- Example Setup ---
EXAMPLE_IMAGES_DIR = Path("examples").resolve()
example_images = sorted([str(p) for p in EXAMPLE_IMAGES_DIR.glob("*.png")])

DEFAULT_CLASSES = {
    "0_pikachu.png": "pikachu,traffic_sign,forest,road,cap",
    "1_jurassic.png": "dinosaur,smoke,vegetation,person",
    "2_falcon.png": "millennium_falcon,space",
}

DEFAULT_BG_THRESH = 0.55
DEFAULT_BG_CLEAN = False


# --- Inference Function ---
@spaces.GPU(duration=120)
def talk2dino_infer(input_image, class_text, selected_model="Talk2DINO-ViT-B", apply_pamr=True,
                    with_background=False, bg_thresh=0.55, apply_bg_clean=False):
    if input_image is None:
        raise gr.Error("No image detected. Please select or upload an image first.")
    if selected_model not in MODELS:
        raise gr.Error("Please select a valid model before running inference.")

    model = MODELS[selected_model].to("cuda")

    text = [t.strip() for t in class_text.replace("_", " ").split(",") if t.strip()]
    if len(text) == 0:
        raise gr.Error("Please provide at least one class name before generating segmentation.")

    img = F.to_tensor(input_image).unsqueeze(0).float().to("cuda") * 255.0

    # Generate color palette
    palette = [
        [255, 0, 0],
        [255, 255, 0],
        [0, 255, 0],
        [0, 255, 255],
        [0, 0, 255],
        [128, 128, 128],
    ]
    if len(text) > len(palette):
        for _ in range(len(text) - len(palette)):
            palette.append([np.random.randint(0, 255) for _ in range(3)])
    if with_background:
        palette.insert(0, [0, 0, 0])

    model.with_bg_clean = apply_bg_clean

    with torch.no_grad():
        text_emb = model.build_dataset_class_tokens("sub_imagenet_template", text)
        text_emb = model.build_text_embedding(text_emb)

        mask, _ = model.generate_masks(img, img_metas=None, text_emb=text_emb, classnames=text, apply_pamr=apply_pamr)

        if with_background:
            background = torch.ones_like(mask[:, :1]) * bg_thresh
            mask = torch.cat([background, mask], dim=1)

        mask = mask.argmax(dim=1)

    if with_background:
        text = ["background"] + text

    img_out = plot_qualitative(
        img.cpu()[0].permute(1, 2, 0).int().numpy(),
        mask.cpu()[0].numpy(),
        palette,
        texts=text,
    )

    torch.cuda.empty_cache()
    return img_out


# --- Gradio Interface ---
with gr.Blocks(title="Talk2DINO / Talk2DINOv3 Demo") as demo:
    # Overview Section
    overview_img = Image.open("assets/overview.png").convert("RGB")
    overview_img = overview_img.resize((int(overview_img.width * 0.7), int(overview_img.height * 0.7)))
    buffered = BytesIO()
    overview_img.save(buffered, format="PNG")
    img_str = base64.b64encode(buffered.getvalue()).decode()

    gr.Markdown(f"""
# 🦖 Talk2DINO & Talk2DINOv3 Demo

![Overview](data:image/png;base64,{img_str})
Project page
---

Perform **open-vocabulary semantic segmentation** using Talk2DINO or Talk2DINOv3.
Supports both **ViT-B** and **ViT-L** model sizes.

**Steps:**
1. Upload or select an example image.
2. Enter class names (e.g., `pikachu, forest, road`).
3. Choose model variant (Talk2DINO or v3).
4. Adjust options and click **Generate Segmentation**.
""")

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Input Image", value=None)
            if example_images:
                example_gallery = gr.Gallery(
                    value=example_images,
                    label="Or select from example images",
                    show_label=True,
                    columns=3,
                    object_fit="contain",
                    height="auto",
                )
        with gr.Column():
            model_selector = gr.Dropdown(
                label="Select Model",
                choices=list(MODELS.keys()),
                value="Talk2DINO-ViT-B",
            )
            class_text = gr.Textbox(
                label="Comma-separated Classes",
                value="",
                placeholder="e.g. pikachu, road, tree",
            )
            apply_pamr = gr.Checkbox(label="Apply PAMR", value=True)
            with_background = gr.Checkbox(label="Include Background", value=False)
            bg_thresh = gr.Slider(
                label="Background Threshold",
                minimum=0.0,
                maximum=1.0,
                value=DEFAULT_BG_THRESH,
                step=0.01,
                interactive=False,
            )
            apply_bg_clean = gr.Checkbox(
                label="Apply Background Cleaning",
                value=False,
                interactive=False,
            )
            generate_button = gr.Button("🚀 Generate Segmentation", interactive=False)

    output_image = gr.Image(type="numpy", label="Segmentation Overlay")

    # --- Background Option Toggle ---
    def toggle_bg_options(with_bg):
        if with_bg:
            return gr.update(interactive=True, value=DEFAULT_BG_THRESH), gr.update(interactive=True, value=DEFAULT_BG_CLEAN)
        else:
            return gr.update(interactive=False, value=DEFAULT_BG_THRESH), gr.update(interactive=False, value=DEFAULT_BG_CLEAN)

    with_background.change(
        fn=toggle_bg_options,
        inputs=[with_background],
        outputs=[bg_thresh, apply_bg_clean],
    )

    # --- Enable Button Only When Classes Exist ---
    def enable_generate_button(text):
        return gr.update(interactive=bool(text.strip()))

    class_text.change(fn=enable_generate_button, inputs=[class_text], outputs=[generate_button])

    # --- Example Image Loader ---
    def load_example_image(evt: gr.SelectData):
        selected = evt.value["image"]
        if isinstance(selected, str):
            img = Image.open(selected).convert("RGB")
            filename = Path(selected).name
        elif isinstance(selected, dict):
            img = Image.open(selected["path"]).convert("RGB")
            filename = Path(selected["path"]).name
        else:
            img = Image.fromarray(selected)
            filename = None
        class_val = DEFAULT_CLASSES.get(filename, "")
        return img, class_val, gr.update(interactive=bool(class_val.strip()))

    if example_images:
        example_gallery.select(
            fn=load_example_image,
            inputs=[],
            outputs=[input_image, class_text, generate_button],
        )

    # --- User Upload Reset ---
    def on_upload_image(img):
        if img is None:
            return None, "", gr.update(interactive=False)
        return img, "", gr.update(interactive=False)

    input_image.upload(
        fn=on_upload_image,
        inputs=[input_image],
        outputs=[input_image, class_text, generate_button],
    )

    # --- Generate Segmentation ---
    generate_button.click(
        talk2dino_infer,
        inputs=[input_image, class_text, model_selector, apply_pamr, with_background, bg_thresh, apply_bg_clean],
        outputs=output_image,
    )

demo.launch()
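
# --- Expected local layout (a sketch inferred from the paths referenced above; anything
# outside this file is an assumption and may differ in your checkout) ---
#
#   app.py                # this script
#   assets/overview.png   # header figure embedded as base64 in the Markdown block
#   examples/*.png        # gallery images (e.g. 0_pikachu.png, 1_jurassic.png, 2_falcon.png)
#   src/plot.py           # must provide plot_qualitative(image, mask, palette, texts=...)
#   tmp/                  # GRADIO_TEMP_DIR, created at startup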