osbm's picture
initial commit
3a50a96
import gradio as gr
import timm
import tome
import torch
from PIL import Image
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
# Pretrained timm ViT-L/16 (384px variant); pretrained=True makes timm fetch
# the weights, hence the download progress messages.
model_name = "vit_large_patch16_384"
print("Started Downloading:", model_name)
model = timm.create_model(model_name, pretrained=True)
print("Finished Downloading:", model_name)
# The model is only ever used for inference in this demo, so disable
# training-mode behaviour (dropout / drop-path) explicitly.
model.eval()
# Patch the ViT with Token Merging. trace_source=True keeps the token-source
# map that process_image reads back from model._tome_info["source"].
tome.patch.timm(model, trace_source=True)
# Spatial resolution the model expects: index 1 of default_cfg["input_size"]
# (presumably (C, H, W), i.e. the height).
input_size = model.default_cfg["input_size"][1]
# Make sure the transform is correct for your model! Instead of hard-coding
# the classic 224/256 crop ratio, use the model's own crop_pct from its
# default_cfg; models without one fall back to 224/256 (= 0.875), which yields
# exactly the same resize size as int((256 / 224) * input_size).
_crop_pct = model.default_cfg.get("crop_pct") or 224 / 256
_resize_size = int(input_size / _crop_pct)
transform_list = [
    transforms.Resize(_resize_size, interpolation=InterpolationMode.BICUBIC),
    transforms.CenterCrop(input_size),
]
# The visualization and model need different transforms: both share the
# geometric resize/crop, but only the model input is converted to a tensor
# and normalized with the model's mean/std.
transform_vis = transforms.Compose(transform_list)
transform_norm = transforms.Compose(transform_list + [
    transforms.ToTensor(),
    transforms.Normalize(model.default_cfg["mean"], model.default_cfg["std"]),
])
def process_image(img, r=25, layers=1):
    """Run ToMe on a single image and visualize which patches were merged.

    Args:
        img: Image array from the Gradio ``"image"`` input (a NumPy array,
            presumably H x W x C). ``None`` (no image supplied) yields
            ``None``.
        r: ToMe reduction amount -- the number of tokens merged per layer.
            Any form ToMe accepts (see the comment below) is passed through.
        layers: If 1, ``r`` is handed to ToMe unchanged. Otherwise ``r`` is
            expanded to a list of ``layers`` copies, i.e. an explicit
            per-layer schedule; presumably ToMe merges nothing in blocks
            beyond that list (verify against ``tome.utils.parse_r``).

    Returns:
        The result of ``tome.make_visualization`` for the display-cropped
        image and the final token-source map; Gradio renders it as an image.
    """
    if img is None:
        # Nothing to visualize; leave the output empty instead of crashing.
        return None

    # Slider values may arrive as floats (e.g. 25.0); ToMe's forms of r are
    # built from integer token counts, so normalise plain numbers to int.
    if isinstance(r, float):
        r = int(r)
    layers = int(layers)

    # convert("RGB") also copes with grayscale or RGBA arrays, which an
    # explicit mode="RGB" in Image.fromarray would reject or misread.
    pil_img = Image.fromarray(img.astype("uint8")).convert("RGB")
    img_vis = transform_vis(pil_img)    # cropped image for display
    img_norm = transform_norm(pil_img)  # normalized tensor for the model

    # Accepted forms of ToMe's r:
    # - int: A constant number of tokens merged per layer.
    # - Tuple[int, float]: A pair of (r, inflection). The inflection describes
    #   whether the reduction per layer should trend upward (+1), downward
    #   (-1), or stay constant (0). (r, 0) is the same as a constant r;
    #   (r, -1) is what the paper calls the "decreasing schedule". Any value
    #   between -1 and +1 is accepted.
    # - List[int]: A specific number of tokens per layer, for full control.
    if layers != 1:
        r = [r] * layers
    model.r = r

    # The forward pass only populates ToMe's bookkeeping; the logits are
    # discarded, so skip autograd tracking to avoid building a graph.
    with torch.no_grad():
        model(img_norm[None, ...])  # prepend a batch dimension of 1
    source = model._tome_info["source"]
    return tome.make_visualization(img_vis, source, patch_size=16, class_token=True)
# Gradio UI: an image plus the two ToMe knobs, passed to process_image as
# (img, r, layers). gr.Slider replaces gr.inputs.Slider, which was deprecated
# in Gradio 3 and removed in Gradio 4. The r slider starts at 25 to match
# process_image's default and the examples below.
iface = gr.Interface(
    fn=process_image,
    inputs=[
        "image",
        gr.Slider(
            minimum=0,
            maximum=50,
            value=25,
            step=1,
            label="r value (the amount of reduction. See paper for details.)",
        ),
        gr.Slider(
            minimum=1,
            maximum=50,
            value=1,
            step=1,
            label="layers (1 means r is applied to all layers)",
        ),
    ],
    outputs="image",
    examples=[
        ["images/husky.png", 25, 1],
        ["images/husky.png", 25, 8],
        ["images/husky.png", 25, 16],
        ["images/husky.png", 25, 22],
    ],
)
iface.launch()