osbm's picture
initial commit
raw history blame
No virus
2.43 kB
import gradio as gr
import timm
import tome
import torch
from PIL import Image
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
model_name = "vit_large_patch16_384"
print("Started Downloading:", model_name)
model = timm.create_model(model_name, pretrained=True)
print("Finished Downloading:", model_name)
# The model is only ever used for inference in this app, so switch it to
# eval mode (disables any dropout / stochastic-depth layers if configured).
model.eval()
# Patch the timm model with Token Merging. trace_source=True makes the model
# record model._tome_info["source"], which the visualization below consumes.
tome.patch.timm(model, trace_source=True)
# Spatial input size taken from the model config; presumably
# default_cfg["input_size"] is (C, H, W) and the input is square — confirm.
input_size = model.default_cfg["input_size"][1]
# Make sure the transform is correct for your model!
# Resize by the usual 256/224 ratio, then center-crop to the model's input
# size so the image lines up exactly with the model's 16x16 patch grid.
transform_list = [
    transforms.Resize(int((256 / 224) * input_size), interpolation=InterpolationMode.BICUBIC),
    transforms.CenterCrop(input_size),
]

# The visualization and model need different transforms: the visualization
# works on the cropped PIL image, while the model needs a normalized tensor
# (Normalize only accepts tensors, hence ToTensor first).
transform_vis = transforms.Compose(transform_list)
transform_norm = transforms.Compose(transform_list + [
    transforms.ToTensor(),
    transforms.Normalize(model.default_cfg["mean"], model.default_cfg["std"]),
])
def process_image(img, r=25, layers=1):
    """Run ToMe on one image and return a visualization of the merged tokens.

    Args:
        img: Input image as an array supporting ``astype`` (presumably the
            numpy array Gradio passes in — confirm against the interface).
        r: Number of tokens to merge per layer.
        layers: If 1, ``r`` is set on the model as a plain int, i.e. applied
            to all layers. Otherwise ``r`` is expanded to a list of length
            ``layers`` (``[r] * layers``), which ToMe treats as a per-layer
            schedule.

    Returns:
        The result of ``tome.make_visualization`` for the cropped image and
        the token-source map recorded during the forward pass.
    """
    # Let Pillow infer the mode from the array, then normalize to RGB; an
    # explicit mode="RGB" only matches 3-channel arrays, whereas this also
    # handles grayscale or RGBA input. RGB input is unaffected.
    img = Image.fromarray(img.astype("uint8")).convert("RGB")
    img_vis = transform_vis(img)
    img_norm = transform_norm(img)

    # Quoted from the ToMe authors' documentation of ``model.r``:
    # r can take the following forms:
    # - int: A constant number of tokens per layer.
    # - Tuple[int, float]: A pair of r, inflection.
    #   Inflection describes where the reduction / layer should trend
    #   upward (+1), downward (-1), or stay constant (0). A value of (r, 0)
    #   is the same as providing a constant r. (r, -1) is what we describe in
    #   the paper as "decreasing schedule". Any value between -1 and +1 is
    #   accepted.
    # - List[int]: A specific number of tokens per layer. For extreme granularity.
    if layers != 1:
        r = [r] * layers
    model.r = r

    # Only the token-merging trace is needed, not gradients, so skip the
    # autograd graph to save memory and time.
    with torch.no_grad():
        model(img_norm[None, ...])  # add a batch dimension of 1
    source = model._tome_info["source"]
    return tome.make_visualization(img_vis, source, patch_size=16, class_token=True)
iface = gr.Interface(
gr.inputs.Slider(0, 50, step=1, label="r value (the amount of reduction. See paper for details.)"),
gr.inputs.Slider(1, 50, step=1, label="layers (1 means r is applied to all layers)"),
["images/husky.png", 25, 1],
["images/husky.png", 25, 8],
["images/husky.png", 25, 16],
["images/husky.png", 25, 22],