import tome
import timm
import gradio as gr
from PIL import Image
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode

model_name = "vit_large_patch16_384"
print("Started Downloading:", model_name)
model = timm.create_model(model_name, pretrained=True)
print("Finished Downloading:", model_name)

# Patch the timm model in place; trace_source=True makes ToMe record which
# input patches each merged token came from (needed for the visualization).
tome.patch.timm(model, trace_source=True)

input_size = model.default_cfg["input_size"][1]

# Make sure the transform is correct for your model!
transform_list = [
    transforms.Resize(int((256 / 224) * input_size), interpolation=InterpolationMode.BICUBIC),
    transforms.CenterCrop(input_size),
]

# The visualization and model need different transforms
transform_vis = transforms.Compose(transform_list)
transform_norm = transforms.Compose(
    transform_list
    + [
        transforms.ToTensor(),
        transforms.Normalize(model.default_cfg["mean"], model.default_cfg["std"]),
    ]
)


def process_image(img, r=25, layers=1):
    """Run the ToMe-patched model on ``img`` and visualize the token merges.

    Args:
        img: Input image as an H x W x 3 uint8-compatible numpy array
            (what gradio's "image" input provides).
        r: Amount of token reduction (tokens merged per layer).
        layers: Number of leading layers ``r`` is applied to; 1 means ``r``
            is applied uniformly to all layers.

    Returns:
        A PIL image showing which input patches were merged together.
    """
    img = Image.fromarray(img.astype('uint8'), 'RGB')
    img_vis = transform_vis(img)
    img_norm = transform_norm(img)

    # From the paper, r can take the following forms:
    #  - int: A constant number of tokens per layer.
    #  - Tuple[int, float]: A pair of (r, inflection). Inflection describes
    #    whether the reduction per layer should trend upward (+1), downward
    #    (-1), or stay constant (0). A value of (r, 0) is the same as
    #    providing a constant r; (r, -1) is what the paper calls the
    #    "decreasing schedule". Any value between -1 and +1 is accepted.
    #  - List[int]: A specific number of tokens per layer, for extreme
    #    granularity.
    if layers != 1:
        # Apply r only to the first `layers` layers (per-layer list form).
        r = [r] * layers
    model.r = r

    _ = model(img_norm[None, ...])
    # The patched model stores the patch-provenance map after each forward.
    source = model._tome_info["source"]

    return tome.make_visualization(img_vis, source, patch_size=16, class_token=True)


iface = gr.Interface(
    fn=process_image,
    inputs=[
        "image",
        gr.inputs.Slider(0, 50, step=1, label="r value (the amount of reduction. See paper for details.)"),
        gr.inputs.Slider(1, 50, step=1, label="layers (1 means r is applied to all layers)"),
    ],
    outputs="image",
    examples=[
        ["images/husky.png", 25, 1],
        ["images/husky.png", 25, 8],
        ["images/husky.png", 25, 16],
        ["images/husky.png", 25, 22],
    ],
)

if __name__ == "__main__":
    iface.launch()