import gradio as gr
import torch
import numpy as np
from transformers import AutoModel
import os
import torchvision.transforms.functional as F
from src.plot import plot_qualitative
from PIL import Image
from io import BytesIO
import base64
from pathlib import Path
import spaces  # REQUIRED for ZeroGPU

# --- Setup ---
os.environ["GRADIO_TEMP_DIR"] = "tmp"
os.makedirs(os.environ["GRADIO_TEMP_DIR"], exist_ok=True)
device = "cuda" if torch.cuda.is_available() else "cpu"

# --- Load Models (on CPU first; ZeroGPU will move to CUDA dynamically) ---
print("Loading models... (this may take a minute on first launch)")
model_B = AutoModel.from_pretrained("lorebianchi98/Talk2DINO-ViTB", trust_remote_code=True).to("cpu").eval()
model_L = AutoModel.from_pretrained("lorebianchi98/Talk2DINO-ViTL", trust_remote_code=True).to("cpu").eval()
model_B_v3 = AutoModel.from_pretrained("lorebianchi98/Talk2DINOv3-ViTB", trust_remote_code=True).to("cpu").eval()
model_L_v3 = AutoModel.from_pretrained("lorebianchi98/Talk2DINOv3-ViTL", trust_remote_code=True).to("cpu").eval()

MODELS = {
    "Talk2DINO-ViT-B": model_B,
    "Talk2DINO-ViT-L": model_L,
    "Talk2DINOv3-ViT-B": model_B_v3,
    "Talk2DINOv3-ViT-L": model_L_v3,
}
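# ZeroGPU note (an assumption based on the Hugging Face `spaces` package, not a
# statement from the original comments): keeping the checkpoints on CPU here is
# what lets the Space start without a GPU attached; a GPU is granted only while
# the @spaces.GPU-decorated handler below is executing, which is why the model
# is moved to the device per request inside that function.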

# --- Example Setup ---
EXAMPLE_IMAGES_DIR = Path("examples").resolve()
example_images = sorted([str(p) for p in EXAMPLE_IMAGES_DIR.glob("*.png")])
DEFAULT_CLASSES = {
    "0_pikachu.png": "pikachu,traffic_sign,forest,road,cap",
    "1_jurassic.png": "dinosaur,smoke,vegetation,person",
    "2_falcon.png": "millennium_falcon,space",
}
DEFAULT_BG_THRESH = 0.55
DEFAULT_BG_CLEAN = False

# --- Inference Function ---
@spaces.GPU  # ZeroGPU: a GPU is attached only for the duration of this call
def talk2dino_infer(input_image, class_text, selected_model="Talk2DINO-ViT-B",
                    apply_pamr=True, with_background=False, bg_thresh=0.55, apply_bg_clean=False):
    if input_image is None:
        raise gr.Error("No image detected. Please select or upload an image first.")
    if selected_model not in MODELS:
        raise gr.Error("Please select a valid model before running inference.")

    model = MODELS[selected_model].to(device)

    # Parse the comma-separated prompt into a clean list of class names.
    text = [t.strip() for t in class_text.replace("_", " ").split(",") if t.strip()]
    if len(text) == 0:
        raise gr.Error("Please provide at least one class name before generating segmentation.")

    # [1, 3, H, W] float tensor scaled back to the 0-255 range.
    img = F.to_tensor(input_image).unsqueeze(0).float().to(device) * 255.0
    # Generate color palette: fixed colors first, random colors for any extra classes.
    palette = [
        [255, 0, 0],
        [255, 255, 0],
        [0, 255, 0],
        [0, 255, 255],
        [0, 0, 255],
        [128, 128, 128],
    ]
    if len(text) > len(palette):
        for _ in range(len(text) - len(palette)):
            palette.append([np.random.randint(0, 255) for _ in range(3)])
    if with_background:
        palette.insert(0, [0, 0, 0])

    model.with_bg_clean = apply_bg_clean
    with torch.no_grad():
        text_emb = model.build_dataset_class_tokens("sub_imagenet_template", text)
        text_emb = model.build_text_embedding(text_emb)
        mask, _ = model.generate_masks(img, img_metas=None, text_emb=text_emb,
                                       classnames=text, apply_pamr=apply_pamr)
        if with_background:
            # Prepend a constant "background" score map so that argmax assigns a pixel
            # to background whenever all of its class scores fall below bg_thresh.
            background = torch.ones_like(mask[:, :1]) * bg_thresh
            mask = torch.cat([background, mask], dim=1)
        mask = mask.argmax(dim=1)

    if with_background:
        text = ["background"] + text
    img_out = plot_qualitative(
        img.cpu()[0].permute(1, 2, 0).int().numpy(),
        mask.cpu()[0].numpy(),
        palette,
        texts=text,
    )
    torch.cuda.empty_cache()
    return img_out
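
# Minimal sketch (not used by the app) of the background handling above, assuming
# `scores` is the [B, num_classes, H, W] map returned by generate_masks: a constant
# bg_thresh channel is prepended, so argmax labels a pixel as 0 ("background")
# whenever every class score at that pixel stays below bg_thresh.
def _background_argmax_sketch(scores: torch.Tensor, bg_thresh: float = DEFAULT_BG_THRESH) -> torch.Tensor:
    background = torch.ones_like(scores[:, :1]) * bg_thresh  # constant background score map
    return torch.cat([background, scores], dim=1).argmax(dim=1)  # [B, H, W] label map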

# --- Gradio Interface ---
with gr.Blocks(title="Talk2DINO / Talk2DINOv3 Demo") as demo:
    # Overview Section
    overview_img = Image.open("assets/overview.png").convert("RGB")
    overview_img = overview_img.resize((int(overview_img.width * 0.7), int(overview_img.height * 0.7)))
    buffered = BytesIO()
    overview_img.save(buffered, format="PNG")
    img_str = base64.b64encode(buffered.getvalue()).decode()

    gr.Markdown(f"""
# Talk2DINO & Talk2DINOv3 Demo



<div style="font-size: x-large; white-space: nowrap; display: flex; align-items: center; gap: 10px;">
  <a href="https://lorebianchi98.github.io/Talk2DINO/" target="_blank">Project page</a>
  <span>|</span>
  <a href="http://arxiv.org/abs/2411.19331" target="_blank">
    <img src="https://img.shields.io/badge/arXiv-2411.19331-b31b1b.svg" style="height:28px; vertical-align:middle;">
  </a>
  <span>|</span>
  <a href="https://huggingface.co/papers/2411.19331" target="_blank">
    <img src="https://img.shields.io/badge/HuggingFace-Paper-yellow.svg" style="height:28px; vertical-align:middle;">
  </a>
</div>

---
Perform **open-vocabulary semantic segmentation** using Talk2DINO or Talk2DINOv3.
Supports both **ViT-B** and **ViT-L** model sizes.

**Steps:**
1. Upload or select an example image.
2. Enter class names (e.g., `pikachu, forest, road`).
3. Choose a model variant (Talk2DINO or Talk2DINOv3).
4. Adjust options and click **Generate Segmentation**.
""")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Input Image", value=None)
            if example_images:
                example_gallery = gr.Gallery(
                    value=example_images,
                    label="Or select from example images",
                    show_label=True,
                    columns=3,
                    object_fit="contain",
                    height="auto",
                )
        with gr.Column():
            model_selector = gr.Dropdown(
                label="Select Model",
                choices=list(MODELS.keys()),
                value="Talk2DINO-ViT-B",
            )
            class_text = gr.Textbox(
                label="Comma-separated Classes",
                value="",
                placeholder="e.g. pikachu, road, tree",
            )
            apply_pamr = gr.Checkbox(label="Apply PAMR", value=True)
            with_background = gr.Checkbox(label="Include Background", value=False)
            bg_thresh = gr.Slider(
                label="Background Threshold",
                minimum=0.0,
                maximum=1.0,
                value=DEFAULT_BG_THRESH,
                step=0.01,
                interactive=False,
            )
            apply_bg_clean = gr.Checkbox(
                label="Apply Background Cleaning",
                value=False,
                interactive=False,
            )
            generate_button = gr.Button("Generate Segmentation", interactive=False)

    output_image = gr.Image(type="numpy", label="Segmentation Overlay")

    # --- Background Option Toggle ---
    def toggle_bg_options(with_bg):
        if with_bg:
            return gr.update(interactive=True, value=DEFAULT_BG_THRESH), gr.update(interactive=True, value=DEFAULT_BG_CLEAN)
        else:
            return gr.update(interactive=False, value=DEFAULT_BG_THRESH), gr.update(interactive=False, value=DEFAULT_BG_CLEAN)

    with_background.change(
        fn=toggle_bg_options,
        inputs=[with_background],
        outputs=[bg_thresh, apply_bg_clean],
    )

    # --- Enable Button Only When Classes Exist ---
    def enable_generate_button(text):
        return gr.update(interactive=bool(text.strip()))

    class_text.change(fn=enable_generate_button, inputs=[class_text], outputs=[generate_button])

    # --- Example Image Loader ---
    def load_example_image(evt: gr.SelectData):
        # Gallery selection payloads vary (path string, dict with a "path" key,
        # or a raw array); handle all three.
        selected = evt.value["image"]
        if isinstance(selected, str):
            img = Image.open(selected).convert("RGB")
            filename = Path(selected).name
        elif isinstance(selected, dict):
            img = Image.open(selected["path"]).convert("RGB")
            filename = Path(selected["path"]).name
        else:
            img = Image.fromarray(selected)
            filename = None
        class_val = DEFAULT_CLASSES.get(filename, "")
        return img, class_val, gr.update(interactive=bool(class_val.strip()))

    if example_images:
        example_gallery.select(
            fn=load_example_image,
            inputs=[],
            outputs=[input_image, class_text, generate_button],
        )

    # --- User Upload Reset ---
    def on_upload_image(img):
        if img is None:
            return None, "", gr.update(interactive=False)
        return img, "", gr.update(interactive=False)

    input_image.upload(
        fn=on_upload_image,
        inputs=[input_image],
        outputs=[input_image, class_text, generate_button],
    )

    # --- Generate Segmentation ---
    generate_button.click(
        talk2dino_infer,
        inputs=[input_image, class_text, model_selector, apply_pamr, with_background, bg_thresh, apply_bg_clean],
        outputs=output_image,
    )

demo.launch()
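
# Local usage notes (assumptions, not part of the original file):
#   pip install gradio spaces torch torchvision transformers pillow numpy
#   python app.py
# The src/plot.py helper, assets/overview.png, and the examples/ folder must sit
# next to this script. Without a CUDA device, inference falls back to CPU via the
# `device` variable above, and the @spaces.GPU decorator is expected to be a
# no-op when the app is not running on ZeroGPU hardware.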