from einops import rearrange
import gradio as gr
import spaces
import torch
import torch.nn.functional as F
from PIL import Image, ImageOps
from transformers import AutoModel, CLIPImageProcessor

hf_repo = "nvidia/RADIO-L"

image_processor = CLIPImageProcessor.from_pretrained(hf_repo)
model = AutoModel.from_pretrained(hf_repo, trust_remote_code=True)
model.eval().cuda()


title = """RADIO: Reduce All Domains Into One"""
description = """
# RADIO

AM-RADIO is a framework to distill Large Vision Foundation models into a single one.
RADIO, a new vision foundation model, excels across visual domains, serving as a superior replacement for vision backbones.
Integrating CLIP variants, DINOv2, and SAM through distillation, it preserves unique features like text grounding and segmentation correspondence.
Outperforming teachers in ImageNet zero-shot (+6.8%), kNN (+2.39%), and linear probing segmentation (+3.8%) and vision-language models (LLaVa 1.5 up to 1.5%), it scales to any resolution, supports non-square images.

# Instructions

Simply paste an image or pick one from the gallery of examples and then click the "Submit" button.
"""

inputs = [
    gr.Image(type="pil")
]

examples = [
    "IMG_0996.jpeg",
    "IMG_1061.jpeg",
    "IMG_1338.jpeg",
    "IMG_4319.jpeg",
    "IMG_5104.jpeg",
    "IMG_5139.jpeg",
    "IMG_6225.jpeg",
    "IMG_6814.jpeg",
    "IMG_7459.jpeg",
    "IMG_7577.jpeg",
    "IMG_7687.jpeg",
    "IMG_9862.jpeg",
]

outputs = [
    gr.Textbox(label="Feature Shape"),
    gr.Image(),
]

def get_robust_pca(features: torch.Tensor, m: float = 2, remove_first_component=False):
    """Compute a 3-component PCA projection and robust (outlier-resistant) color ranges.

    features: (N, C) flattened feature vectors.
    m: number of median absolute deviations beyond which a value counts as an outlier.
    """
    assert len(features.shape) == 2, "features should be (N, C)"
    reduction_mat = torch.pca_lowrank(features, q=3, niter=20)[2]
    colors = features @ reduction_mat
    if remove_first_component:
        colors_min = colors.min(dim=0).values
        colors_max = colors.max(dim=0).values
        tmp_colors = (colors - colors_min) / (colors_max - colors_min)
        fg_mask = tmp_colors[..., 0] < 0.2
        reduction_mat = torch.pca_lowrank(features[fg_mask], q=3, niter=20)[2]
        colors = features @ reduction_mat
    else:
        fg_mask = torch.ones_like(colors[:, 0]).bool()
    # Median absolute deviation per channel, used to reject outliers before
    # computing the min/max ranges used for color normalization.
    d = torch.abs(colors[fg_mask] - torch.median(colors[fg_mask], dim=0).values)
    mdev = torch.median(d, dim=0).values
    s = d / mdev
    try:
        rins = colors[fg_mask][s[:, 0] < m, 0]
        gins = colors[fg_mask][s[:, 1] < m, 1]
        bins = colors[fg_mask][s[:, 2] < m, 2]
        rgb_min = torch.tensor([rins.min(), gins.min(), bins.min()])
        rgb_max = torch.tensor([rins.max(), gins.max(), bins.max()])
    except Exception:
        # Fall back to the unfiltered projections if any channel has no inliers.
        rins = colors
        gins = colors
        bins = colors
        rgb_min = torch.tensor([rins.min(), gins.min(), bins.min()])
        rgb_max = torch.tensor([rins.max(), gins.max(), bins.max()])

    return reduction_mat, rgb_min.to(reduction_mat), rgb_max.to(reduction_mat)
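
# Example (hypothetical shapes): for 1,024 patch features of dimension 1,024,
#   mat, lo, hi = get_robust_pca(torch.randn(1024, 1024))
# returns a (1024, 3) projection matrix together with robust per-channel
# (min, max) ranges used to map the 3-D projection into [0, 1] RGB colors.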


def get_pca_map(
    feature_map: torch.Tensor,
    img_size,
    interpolation="bicubic",
    return_pca_stats=False,
    pca_stats=None,
):
    """
    feature_map: (1, h, w, C) is the feature map of a single image.
    """
    if feature_map.shape[0] != 1:
        # make it (1, h, w, C)
        feature_map = feature_map[None]
    if pca_stats is None:
        reduct_mat, color_min, color_max = get_robust_pca(
            feature_map.reshape(-1, feature_map.shape[-1])
        )
    else:
        reduct_mat, color_min, color_max = pca_stats
    pca_color = feature_map @ reduct_mat
    pca_color = (pca_color - color_min) / (color_max - color_min)
    pca_color = pca_color.clamp(0, 1)
    pca_color = F.interpolate(
        pca_color.permute(0, 3, 1, 2),
        size=img_size,
        mode=interpolation,
    ).permute(0, 2, 3, 1)
    pca_color = pca_color.cpu().numpy().squeeze(0)
    if return_pca_stats:
        return pca_color, (reduct_mat, color_min, color_max)
    return pca_color
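
# Example (hypothetical shapes): a (1, 24, 32, 1024) feature map with
# img_size=(384, 512) yields a (384, 512, 3) NumPy array of PCA colors in [0, 1].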


def pad_image_to_multiple_of_16(image):
    # Calculate the new dimensions to make them multiples of 16
    width, height = image.size
    new_width = (width + 15) // 16 * 16
    new_height = (height + 15) // 16 * 16

    # Calculate the padding needed on each side
    pad_width = new_width - width
    pad_height = new_height - height

    left = pad_width // 2
    right = pad_width - left
    top = pad_height // 2
    bottom = pad_height - top

    # Apply the padding
    padded_image = ImageOps.expand(image, (left, top, right, bottom), fill='black')

    return padded_image
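
# For example, a 500x375 input becomes 512x384, with the added pixels split as
# evenly as possible between the left/right and top/bottom borders.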


@spaces.GPU
def infer_radio(image):
    """Run RADIO on the input image and return the feature shape and a PCA visualization."""
    image = pad_image_to_multiple_of_16(image)
    width, height = image.size
    pixel_values = image_processor(images=image, return_tensors='pt').pixel_values
    pixel_values = pixel_values.to(torch.bfloat16).cuda()

    # The model returns a (summary, spatial features) pair; only the spatial
    # features are needed for the per-patch PCA visualization.
    _, features = model(pixel_values)

    # One feature vector per patch, laid out on a (num_rows, num_cols) grid.
    num_rows = height // model.patch_size
    num_cols = width // model.patch_size

    features = features.detach()
    features = rearrange(features, 'b (h w) c -> b h w c', h=num_rows, w=num_cols).float()

    pca_viz = get_pca_map(features, (height, width), interpolation='bilinear')

    return f"{features.shape}", pca_viz


# Create the Gradio interface
demo = gr.Interface(
    fn=infer_radio,
    inputs=inputs,
    examples=examples,
    outputs=outputs,
    title=title,
    description=description
)
  
if __name__ == "__main__":
    demo.launch()