File size: 2,428 Bytes
3a50a96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import gradio as gr
import timm
import tome
import torch
from PIL import Image
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode


model_name = "vit_large_patch16_384"

print("Started Downloading:", model_name)
model = timm.create_model(model_name, pretrained=True)
print("Finished Downloading:", model_name)

# The demo only ever runs inference, so make sure any train-only behavior
# (dropout / drop-path, if the model has any configured) is disabled;
# nn.Module instances start in training mode unless told otherwise.
model.eval()

# Apply Token Merging. trace_source=True is what makes the forward pass
# populate model._tome_info["source"], which process_image() hands to
# tome.make_visualization().
tome.patch.timm(model, trace_source=True)

# Presumably input_size is (C, H, W) and square, so index 1 is the side
# length in pixels — TODO confirm for non-square models.
input_size = model.default_cfg["input_size"][1]

# Make sure the transform is correct for your model!
# Resize the shorter side by the 256/224 ratio, then center-crop to the
# model's input size.
transform_list = [
    transforms.Resize(int((256 / 224) * input_size), interpolation=InterpolationMode.BICUBIC),
    transforms.CenterCrop(input_size)
]

# The visualization and model need different transforms:
# the visualization keeps a PIL image, while the model gets a normalized tensor.
transform_vis  = transforms.Compose(transform_list)
transform_norm = transforms.Compose(transform_list + [
    transforms.ToTensor(),
    transforms.Normalize(model.default_cfg["mean"], model.default_cfg["std"]),
])


def process_image(img, r=25, layers=1):
    """Run the ToMe-patched model on one image and visualize the merged tokens.

    Args:
        img: Image array as delivered by the Gradio ``"image"`` input. It is
            cast to ``uint8`` and normalized to an RGB PIL image here.
        r: Amount of token reduction per layer. Numeric values (including
            floats coming from a Gradio slider) are truncated to ``int``; any
            other form (see the schedule notes below) is passed through as-is.
        layers: If 1, ``r`` is given to the model unchanged, so it applies to
            all layers. Otherwise ``r`` is expanded into a per-layer list of
            length ``layers`` (how ToMe treats layers beyond that list is
            decided inside ``tome`` — presumably no reduction; confirm).

    Returns:
        The result of ``tome.make_visualization`` for the center-cropped
        image and the token source map traced during the forward pass.
    """
    # Gradio sliders may deliver floats (e.g. 25.0). ToMe needs integer token
    # counts, and `[r] * layers` needs an int length.
    if isinstance(r, float):
        r = int(r)
    layers = int(layers)

    # Let PIL infer the mode from the array shape, then convert. Forcing 'RGB'
    # would misread grayscale (H, W) or RGBA (H, W, 4) arrays.
    img = Image.fromarray(img.astype("uint8")).convert("RGB")
    img_vis = transform_vis(img)
    img_norm = transform_norm(img)

    # From the ToMe authors' notes:
    # r can take the following forms:
    #  - int: A constant number of tokens per layer.
    #  - Tuple[int, float]: A pair of r, inflection.
    #    Inflection describes whether the reduction / layer should trend
    #    upward (+1), downward (-1), or stay constant (0). A value of (r, 0)
    #    is the same as providing a constant r. (r, -1) is what the paper
    #    describes as a "decreasing schedule". Any value between -1 and +1 is
    #    accepted.
    #  - List[int]: A specific number of tokens per layer. For extreme granularity.
    if layers != 1:
        r = [r] * layers

    print(r)
    model.r = r
    # Only the forward pass's side effect (the traced token source) is needed,
    # so skip autograd bookkeeping to save memory on this large model.
    with torch.no_grad():
        model(img_norm[None, ...])
    source = model._tome_info["source"]

    return tome.make_visualization(img_vis, source, patch_size=16, class_token=True)


# Slider controls, in the same order as process_image's (r, layers) arguments.
_r_slider = gr.inputs.Slider(0, 50, step=1, label="r value (the amount of reduction. See paper for details.)")
_layers_slider = gr.inputs.Slider(1, 50, step=1, label="layers (1 means r is applied to all layers)")

# Same husky image and r for every example; only the number of layers varies.
iface = gr.Interface(
    fn=process_image,
    inputs=["image", _r_slider, _layers_slider],
    outputs="image",
    examples=[["images/husky.png", 25, n_layers] for n_layers in (1, 8, 16, 22)],
)
iface.launch()