osbm committed on
Commit
3a50a96
1 Parent(s): dce5d16

initial commit

Browse files
.gitattributes CHANGED
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
35
+ *.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ .venv/
2
+ *.pyc
README.md CHANGED
@@ -9,4 +9,17 @@ app_file: app.py
9
  pinned: false
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  pinned: false
10
  ---
11
 
12
+ # Token Merging: Your ViT but Faster
13
+
14
+ github: https://github.com/facebookresearch/tome
15
+ paper: https://arxiv.org/abs/2210.09461
16
+
17
+ # Citation
18
+ ```bibtex
19
+ @inproceedings{bolya2022tome,
20
+ title={Token Merging: Your {ViT} but Faster},
21
+ author={Bolya, Daniel and Fu, Cheng-Yang and Dai, Xiaoliang and Zhang, Peizhao and Feichtenhofer, Christoph and Hoffman, Judy},
22
+ booktitle={International Conference on Learning Representations},
23
+ year={2023}
24
+ }
25
+ ```
app.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tome
2
+ import timm
3
+ import gradio as gr
4
+ from PIL import Image
5
+ from torchvision import transforms
6
+ from torchvision.transforms.functional import InterpolationMode
7
+
8
+
9
+ model_name = "vit_large_patch16_384"
10
+
11
+ print("Started Downloading:", model_name)
12
+ model = timm.create_model(model_name, pretrained=True)
13
+ print("Finished Downloading:", model_name)
14
+
15
+ tome.patch.timm(model, trace_source=True)
16
+
17
+ input_size = model.default_cfg["input_size"][1]
18
+
19
+ # Make sure the transform is correct for your model!
20
+ transform_list = [
21
+ transforms.Resize(int((256 / 224) * input_size), interpolation=InterpolationMode.BICUBIC),
22
+ transforms.CenterCrop(input_size)
23
+ ]
24
+
25
+ # The visualization and model need different transforms
26
+ transform_vis = transforms.Compose(transform_list)
27
+ transform_norm = transforms.Compose(transform_list + [
28
+ transforms.ToTensor(),
29
+ transforms.Normalize(model.default_cfg["mean"], model.default_cfg["std"]),
30
+ ])
31
+
32
+
33
+ def process_image(img, r=25, layers=1):
34
+ img = Image.fromarray(img.astype('uint8'), 'RGB')
35
+ img_vis = transform_vis(img)
36
+ img_norm = transform_norm(img)
37
+
38
+ # from the paper:
39
+ # r can take the following forms:
40
+ # - int: A constant number of tokens per layer.
41
+ # - Tuple[int, float]: A pair of r, inflection.
42
+ # Inflection describes whether the reduction / layer should trend
43
+ # upward (+1), downward (-1), or stay constant (0). A value of (r, 0)
44
+ # is the same as providing a constant r. (r, -1) is what we describe in the paper
45
+ # as "decreasing schedule". Any value between -1 and +1 is accepted.
46
+ # - List[int]: A specific number of tokens per layer. For extreme granularity.
47
+
48
+ if layers != 1:
49
+ r = [r] * layers
50
+
51
+ print(r)
52
+ model.r = r
53
+ _ = model(img_norm[None, ...])
54
+ source = model._tome_info["source"]
55
+
56
+ # print(f"{source.shape[1]} tokens at the end")
57
+ return tome.make_visualization(img_vis, source, patch_size=16, class_token=True)
58
+
59
+
60
+ iface = gr.Interface(
61
+ fn=process_image,
62
+ inputs=[
63
+ "image",
64
+ gr.inputs.Slider(0, 50, step=1, label="r value (the amount of reduction. See paper for details.)"),
65
+ gr.inputs.Slider(1, 50, step=1, label="layers (1 means r is applied to all layers)"),
66
+ ],
67
+ outputs="image",
68
+ examples=[
69
+ ["images/husky.png", 25, 1],
70
+ ["images/husky.png", 25, 8],
71
+ ["images/husky.png", 25, 16],
72
+ ["images/husky.png", 25, 22],
73
+ ]
74
+ )
75
+ iface.launch()
images/concept_figure.png ADDED

Git LFS Details

  • SHA256: 535d645011be6021705eba4b8a2b48a43a1c5fad0afddb2ea76bafd31cbcd2b6
  • Pointer size: 131 Bytes
  • Size of remote file: 339 kB
images/husky.png ADDED

Git LFS Details

  • SHA256: 79699ff62ea6595a273ea54fb136a4a68edd580fad1a5225d54e15e67b613f4c
  • Pointer size: 131 Bytes
  • Size of remote file: 455 kB
images/image_vis.png ADDED

Git LFS Details

  • SHA256: 0bd8d344b2fac00867ff9e14cef362477cfa8f64aa8c4fc7354313d40dc69b6c
  • Pointer size: 132 Bytes
  • Size of remote file: 1.45 MB
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ gradio
2
+ timm==0.4.12
3
+ torchvision
4
+ torch
5
+ pillow
6
+ tqdm
7
+ git+https://github.com/facebookresearch/tome