# Hugging Face Spaces demo: ToMe (Token Merging) visualization on a timm ViT.
import gradio as gr
import timm
import tome
import torch
from PIL import Image
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
# Model to visualize; any timm ViT supported by ToMe works here.
model_name = "vit_large_patch16_384"

print("Started Downloading:", model_name)
model = timm.create_model(model_name, pretrained=True)
print("Finished Downloading:", model_name)

# Inference only — switch off train-mode behavior (dropout etc.).
model.eval()

# Patch the model with ToMe. trace_source=True records which input patches
# were merged into each surviving token, which the visualization needs.
tome.patch.timm(model, trace_source=True)

# Square input resolution the pretrained config expects (height == width).
input_size = model.default_cfg["input_size"][1]

# Make sure the transform is correct for your model!
# (256/224 resize followed by a center crop is the standard timm eval pipeline.)
transform_list = [
    transforms.Resize(int((256 / 224) * input_size), interpolation=InterpolationMode.BICUBIC),
    transforms.CenterCrop(input_size),
]

# The visualization and model need different transforms: the visualization
# keeps a PIL image, while the model needs a normalized tensor.
transform_vis = transforms.Compose(transform_list)
transform_norm = transforms.Compose(transform_list + [
    transforms.ToTensor(),
    transforms.Normalize(model.default_cfg["mean"], model.default_cfg["std"]),
])
def process_image(img, r=25, layers=1):
    """Run the ToMe-patched model on ``img`` and return a merge visualization.

    Args:
        img: Input image as an HxWx3 uint8-compatible array (gradio's
            ``"image"`` input passes a numpy array).
        r: Amount of token reduction per layer (see the ToMe paper).
        layers: If 1, ``r`` applies to all layers; otherwise ``r`` is expanded
            to a per-layer list ``[r] * layers`` so only the first ``layers``
            layers reduce tokens.

    Returns:
        A PIL image showing which input patches were merged together.
    """
    img = Image.fromarray(img.astype('uint8'), 'RGB')
    img_vis = transform_vis(img)
    img_norm = transform_norm(img)

    # from the paper, r can take the following forms:
    #  - int: a constant number of tokens per layer.
    #  - Tuple[int, float]: (r, inflection). Inflection describes whether the
    #    reduction per layer should trend upward (+1), downward (-1, the
    #    paper's "decreasing schedule"), or stay constant (0); any value in
    #    [-1, +1] is accepted. (r, 0) is the same as a constant r.
    #  - List[int]: a specific number of tokens per layer, for full control.
    if layers != 1:
        r = [r] * layers

    model.r = r
    # Visualization-only inference: no gradients needed.
    with torch.no_grad():
        _ = model(img_norm[None, ...])

    # Maps each remaining token back to the input patches it absorbed.
    source = model._tome_info["source"]

    return tome.make_visualization(img_vis, source, patch_size=16, class_token=True)
# Demo UI: an image plus two sliders in, the merge visualization out.
r_slider = gr.inputs.Slider(
    0, 50, step=1,
    label="r value (the amount of reduction. See paper for details.)",
)
layers_slider = gr.inputs.Slider(
    1, 50, step=1,
    label="layers (1 means r is applied to all layers)",
)

# Preset examples: the same husky image at increasingly aggressive schedules.
example_rows = [["images/husky.png", 25, n] for n in (1, 8, 16, 22)]

iface = gr.Interface(
    fn=process_image,
    inputs=["image", r_slider, layers_slider],
    outputs="image",
    examples=example_rows,
)
iface.launch()