import io
import os
import subprocess
import sys
from pathlib import Path

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import PIL.Image
import spaces
import torch
import torch.nn.functional as F
import torchvision.transforms as T

# Copy the environment for the pip subprocess
env = os.environ.copy()

try:
    import natten
except ImportError:
    print("NATTEN not found. Installing NATTEN...")
    print("Torch Version:", torch.__version__)
    print("CUDA Version:", torch.version.cuda)
    # Install NATTEN
    subprocess.run(
        "pip3 install natten==0.17.4+torch240cu121 -f https://shi-labs.com/natten/wheels/",
        shell=True,
        env=env,
        check=True,
    )
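    # NOTE: the "+torch240cu121" wheel tag assumes torch 2.4.0 with CUDA 12.1; if
    # the runtime differs, pick the wheel matching the versions printed above.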

# Add project root to path
sys.path.append(str(Path(__file__).parent))

from src.backbone.vit_wrapper import PretrainedViTWrapper
from utils.training import round_to_nearest_multiple
from utils.visualization import plot_feats

# Load NAF model
device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.hub.load("valeoai/NAF", "naf", pretrained=True, device=device)
model.eval()
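# torch.hub pulls the NAF code and pretrained weights from the valeoai/NAF GitHub
# repository and caches them locally (by default under ~/.cache/torch/hub).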

# Standard ImageNet normalization, applied to the image fed to the NAF upsampler
ups_norm = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# Sample images
SAMPLE_IMAGES = [
    "asset/Cartoon.png",
    "asset/Natural.png",
    "asset/Satellite.png",
    "asset/Medical.png",
    "asset/Ecosystems.png",
    "asset/Driving.jpg",
    "asset/Manufacturing.png",
]


def resize_with_aspect_ratio(img, max_size, patch_size):
    """Return a (width, height) that preserves aspect ratio, fits within max_size,
    and is a multiple of patch_size."""
    w, h = img.size

    # Scaling factor so the image fits within max_size
    scale = min(max_size / w, max_size / h)
    new_w = int(w * scale)
    new_h = int(h * scale)

    # Round to the nearest patch-size multiple
    new_w = round_to_nearest_multiple(new_w, patch_size)
    new_h = round_to_nearest_multiple(new_h, patch_size)

    # Ensure a minimum size of one patch
    new_w = max(new_w, patch_size)
    new_h = max(new_h, patch_size)

    return new_w, new_h
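
# `round_to_nearest_multiple` is imported from utils.training; a minimal sketch of
# the assumed behavior (an assumption, not the repo's exact implementation):
#
#     def round_to_nearest_multiple(value: int, multiple: int) -> int:
#         return int(round(value / multiple)) * multiple
#
# e.g. round_to_nearest_multiple(300, 16) -> 304, which keeps sizes aligned with
# the ViT patch grid.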


@spaces.GPU  # request a GPU for this handler on ZeroGPU hardware
def process_image(image, model_selection, custom_model, output_resolution):
    """Process the image with the selected model and target resolution."""
    try:
        # Determine which model to use
        if custom_model.strip():
            model_name = custom_model.strip()
        else:
            model_name = MODEL_MAPPING.get(model_selection, model_selection)

        # Load the backbone using vit_wrapper
        backbone = PretrainedViTWrapper(model_name, norm=True).to(device)
        backbone.eval()

        # Get the backbone's normalization stats and patch size
        mean = backbone.config["mean"]
        std = backbone.config["std"]
        patch_size = backbone.patch_size
        back_norm = T.Normalize(mean=mean, std=std)

        # Prepare the image at a resolution compatible with the backbone
        img = PIL.Image.fromarray(image).convert("RGB")
        new_w, new_h = resize_with_aspect_ratio(img, max_size=512, patch_size=patch_size)
        transform = T.Compose(
            [
                T.Resize((new_h, new_w)),
                T.ToTensor(),
            ]
        )
        img_tensor = transform(img).unsqueeze(0).to(device)

        # Normalize for backbone
        img_back = back_norm(img_tensor)
        lr_feats = backbone(img_back)

        # vit_wrapper already returns features in [B, C, H, W] format
        if not isinstance(lr_feats, torch.Tensor):
            raise ValueError(f"Unexpected feature type: {type(lr_feats)}")
        if lr_feats.ndim != 4:
            raise ValueError(f"Unexpected feature shape: {lr_feats.shape}. Expected [B, C, H, W].")

        # Normalize for upsampling
        img_ups = ups_norm(img_tensor)

        # Calculate output resolution maintaining aspect ratio
        _, _, h, w = lr_feats.shape
        aspect_ratio = w / h
        if aspect_ratio > 1:  # Width > Height
            out_h = round_to_nearest_multiple(int(output_resolution / aspect_ratio), patch_size)
            out_w = output_resolution
        else:  # Height >= Width
            out_h = output_resolution
            out_w = round_to_nearest_multiple(int(output_resolution * aspect_ratio), patch_size)
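        # Worked example (hypothetical numbers, assuming patch_size=16): a 512x384
        # input gives a 4:3 feature grid, so output_resolution=448 yields
        # out_w=448 and out_h=round_to_nearest_multiple(336, 16)=336.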

        upsampled_feats = model(img_ups, lr_feats, (out_h, out_w))

        # Create visualization using plot_feats
        plot_feats(
            img_tensor[0],
            lr_feats[0],
            [upsampled_feats[0]],
            legend=["Image", f"Low-Res: {h}x{w}", f"High-Res: {out_h}x{out_w}"],
            font_size=14,
        )

        # Convert matplotlib figure to PIL Image
        fig = plt.gcf()  # Get current figure
        buf = io.BytesIO()
        fig.savefig(buf, format="png", dpi=100, bbox_inches="tight")
        buf.seek(0)
        result_img = PIL.Image.open(buf)
        plt.close(fig)

        return result_img

    except Exception as e:
        print(f"Error processing image: {e}")
        import traceback

        traceback.print_exc()
        return None
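

# For reference, the NAF call above reduces to a single forward pass (a sketch
# using the torch.hub model loaded earlier):
#
#     hr_feats = model(ups_norm(img_tensor), lr_feats, (out_h, out_w))
#
# where img_tensor is a [B, 3, H, W] image in [0, 1], lr_feats is the backbone's
# [B, C, h, w] feature map, and (out_h, out_w) is the requested output grid.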


# Popular vision models with friendly names
MODEL_MAPPING = {
    "DINOv3-B": "vit_base_patch16_dinov3.lvd1689m",
    "RADIOv2.5-B": "radio_v2.5-b",
    "DINOv2-B": "vit_base_patch14_dinov2.lvd142m",
    "DINOv2-R-B": "vit_base_patch14_reg4_dinov2",
    "DINO-B": "vit_base_patch16_224.dino",
    "SigLIP2-B": "vit_base_patch16_siglip_512.v2_webli",
    "PE-Core-B": "vit_pe_core_base_patch16_224.fb",
    "CLIP-B": "vit_base_patch16_clip_224.openai",
}
FRIENDLY_MODEL_NAMES = list(MODEL_MAPPING.keys())
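
# Most of these identifiers are timm model names ("radio_v2.5-b" is presumably
# resolved separately by PretrainedViTWrapper); the custom-model textbox in the
# UI accepts any backbone name the wrapper understands.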


# Create Gradio interface
with gr.Blocks(title="NAF: Zero-Shot Feature Upsampling") as demo:
    gr.HTML(
        """
        <div style="text-align: center; margin-bottom: 2rem;">
            <h1 class="title-text" style="font-size: 3rem; margin-bottom: 0.5rem;">
                🎯 NAF: Zero-Shot Feature Upsampling
            </h1>
            <p style="font-size: 1.2rem; color: #666; margin-bottom: 0.5rem;">
                via Neighborhood Attention Filtering
            </p>
            <div style="margin-bottom: 1rem;">
                <a href="https://github.com/valeoai/NAF" target="_blank"
                   style="margin: 0 0.5rem; text-decoration: none; color: #667eea; font-weight: bold;">
                    📦 Code
                </a>
                <a href="https://arxiv.org/abs/2511.18452" target="_blank"
                   style="margin: 0 0.5rem; text-decoration: none; color: #667eea; font-weight: bold;">
                    📄 Paper
                </a>
            </div>
            <div class="info-box" style="max-width: 900px; margin: 0 auto;">
                <p style="font-size: 1.1rem; margin-bottom: 0.8rem;">
                    🚀 <strong>Upsample features from any Vision Foundation Model to any resolution using a single upsampler!</strong>
                </p>
                <p style="font-size: 0.95rem; margin: 0;">
                    Upload an image, select a model, choose your target resolution, and see NAF in action.
                </p>
            </div>
        </div>
        """
    )

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 📤 Input Configuration")
            image_input = gr.Image(label="Upload Your Image", type="numpy")

            # Sample images
            if any(Path(p).exists() for p in SAMPLE_IMAGES):
                gr.Examples(
                    examples=[[p] for p in SAMPLE_IMAGES if Path(p).exists()],
                    inputs=image_input,
                    label="🖼️ Try Sample Images",
                    examples_per_page=4,
                )

            gr.Markdown("### ⚙️ Model Settings")
            model_dropdown = gr.Dropdown(
                choices=FRIENDLY_MODEL_NAMES,
                value=FRIENDLY_MODEL_NAMES[0],
                label="🤖 Vision Foundation Model",
            )
            custom_model_input = gr.Textbox(
                label="✏️ Or Use Custom Model (timm reference name)",
                placeholder="e.g., vit_large_patch14_dinov2.lvd142m",
                value="",
            )
            resolution_slider = gr.Slider(
                minimum=64,
                maximum=512,
                step=64,
                value=448,
                label="📏 Output Resolution (max dimension)",
            )
            process_btn = gr.Button("✨ Upsample Features", variant="primary")

        with gr.Column(scale=2):
            gr.Markdown("### 🎨 Visualization Results")
            output_image = gr.Image(label="Feature Comparison", type="pil")
            gr.Markdown(
                """
                <div style="background: #f0f7ff; padding: 1rem; border-radius: 8px; border-left: 4px solid #667eea;">
                    <strong>📊 Visualization Guide:</strong>
                    <ul style="margin: 0.5rem 0;">
                        <li><strong>Left:</strong> Original input image</li>
                        <li><strong>Center:</strong> Low-resolution features (PCA visualization)</li>
                        <li><strong>Right:</strong> High-resolution features upsampled by NAF</li>
                    </ul>
                    <p style="margin-top: 0.5rem; font-size: 0.9rem; color: #555;">
                        <em>Note: Output features maintain the aspect ratio of the input image.</em>
                    </p>
                </div>
                """
            )

    process_btn.click(
        fn=process_image,
        inputs=[image_input, model_dropdown, custom_model_input, resolution_slider],
        outputs=output_image,
    )

    gr.Markdown(
        """
        ---
        <div style="text-align: center; padding: 2rem 0;">
            <h3 style="color: #667eea;">💡 About NAF</h3>
            <p style="max-width: 800px; margin: 1rem auto; font-size: 1.05rem; color: #555;">
                NAF enables <strong>zero-shot feature upsampling</strong> from any Vision Foundation Model
                to any resolution. It learns to filter and combine features using neighborhood attention,
                without requiring model-specific training.
            </p>
            <div style="margin-top: 1.5rem;">
                <a href="https://github.com/valeoai/NAF" target="_blank"
                   style="margin: 0 1rem; text-decoration: none; color: #667eea; font-weight: bold;">
                    📦 GitHub Repository
                </a>
                <a href="https://arxiv.org/abs/2511.18452" target="_blank"
                   style="margin: 0 1rem; text-decoration: none; color: #667eea; font-weight: bold;">
                    📄 Research Paper
                </a>
            </div>
        </div>
        """
    )


if __name__ == "__main__":
    demo.launch()