WalidBouss commited on
Commit
be1ec96
·
1 Parent(s): 55fd6e9

Initial commit :tada:

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.jpeg filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2023 Walid Bousselham, Felix Petersen, Vittorio Ferrari, Hilde Kuehne.
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
app.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ import numpy as np
3
+ import cv2 as cv2
4
+ import torch
5
+ import requests
6
+
7
+ import gradio as gr
8
+
9
+ import gem
10
+
11
+
12
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
+ # OpenCLIP
14
+ model_name = 'ViT-B-16-quickgelu'
15
+ pretrained = 'metaclip_400m'
16
+ preprocess = gem.get_gem_img_transform()
17
+ # global gem_model
18
+ gem_model = gem.create_gem_model(model_name=model_name, pretrained=pretrained, device=device)
19
+ image_source = "image"
20
+ _MODELS = {
21
+ "OpenAI": ('ViT-B-16', 'openai'),
22
+ "MetaCLIP": ('ViT-B-16-quickgelu', 'metaclip_400m'),
23
+ "OpenCLIP": ('ViT-B-16', 'laion400m_e32')
24
+ }
25
+
26
+ def change_weights(pretrained_weights):
27
+ """ Handle changing model's weights triggered by a Dropdown module change."""
28
+ curr_model = pretrained_weights
29
+ _new_model = _MODELS[pretrained_weights]
30
+ print(_new_model)
31
+ global gem_model
32
+ gem_model = gem.create_gem_model(model_name=_new_model[0], pretrained=_new_model[1], device=device)
33
+
34
+ def change_to_url(url):
35
+ img_pil = Image.open(requests.get(url, stream=True).raw).convert('RGB')
36
+ return img_pil
37
+
38
+ def viz_func(url, image, text, model_weights):
39
+ image_torch = preprocess(image).unsqueeze(0).to(device)
40
+ with torch.no_grad():
41
+ logits = gem_model(image_torch, [text])
42
+ logits = logits[0].detach().cpu().numpy()
43
+
44
+ img_cv = cv2.cvtColor(np.array(image.resize((448, 448))), cv2.COLOR_RGB2BGR)
45
+ logit_cs_viz = (logits * 255).astype('uint8')
46
+ heat_maps_cs = [cv2.applyColorMap(logit, cv2.COLORMAP_JET) for logit in logit_cs_viz]
47
+
48
+ vizs = [0.4 * img_cv + 0.6 * heat_map for heat_map in heat_maps_cs]
49
+ vizs = [cv2.cvtColor(viz.astype('uint8'), cv2.COLOR_BGR2RGB) for viz in vizs]
50
+ return vizs[0]
51
+
52
+ inputs = [
53
+ gr.Textbox(label="url to the image", ),
54
+ gr.Image(type="pil"),
55
+ gr.Textbox(label="Text Prompt"),
56
+ gr.Dropdown(["OpenAI", "MetaCLIP", "OpenCLIP"], label="Pretrained Weights", value="MetaCLIP",
57
+ info='It can take a few second for the model to be updated.'),
58
+ ]
59
+
60
+ with gr.Blocks() as demo:
61
+ inputs[-1].change(fn=change_weights, inputs=[inputs[-1]])
62
+ inputs[0].change(fn=change_to_url, outputs=inputs[1], inputs=inputs[0])
63
+
64
+ interact = gr.Interface(
65
+ title="GEM: Grounding Everything Module (link to paper/code)",
66
+ description="Grounding Everything: Emerging Localization Properties in Vision-Language Transformers",
67
+ fn=viz_func,
68
+ inputs=inputs,
69
+ outputs=["image"],
70
+ )
71
+
72
+ gr.Examples(
73
+ [
74
+ ["assets/cats_remote_control.jpeg", "cat"],
75
+ ["assets/cats_remote_control.jpeg", "remote control"],
76
+ ["assets/elon_jeff_mark.jpeg", "elon musk"],
77
+ ["assets/elon_jeff_mark.jpeg", "mark zuckerberg"],
78
+ ["assets/elon_jeff_mark.jpeg", "jeff bezos"],
79
+ ],
80
+ [inputs[1], inputs[2]]
81
+ )
82
+
83
+ # demo.launch(server_port=5152)
84
+ demo.launch()
assets/cats_remote_control.jpeg ADDED

Git LFS Details

  • SHA256: dea9e7ef97386345f7cff32f9055da4982da5471c48d575146c796ab4563b04e
  • Pointer size: 131 Bytes
  • Size of remote file: 173 kB
assets/elon_jeff_mark.jpeg ADDED

Git LFS Details

  • SHA256: 680a5638a2af9658bc7e9506f54fb1d984ae4282337f6ec65504f61a54e3317f
  • Pointer size: 132 Bytes
  • Size of remote file: 1.42 MB
gem/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .gem import *
gem/gem.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from typing import Any, Union, List, Optional, Tuple, Dict
3
+ import open_clip
4
+ from open_clip.constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
5
+
6
+ import torch
7
+ from torchvision import transforms
8
+ import matplotlib.pyplot as plt
9
+ from PIL import Image
10
+ import numpy as np
11
+ import cv2 as cv2
12
+
13
+ from .gem_wrapper import GEMWrapper
14
+
15
+
16
+ _MODELS = {
17
+ # B/32
18
+ "ViT-B/32": [
19
+ "openai",
20
+ "laion400m_e31",
21
+ "laion400m_e32",
22
+ "laion2b_e16",
23
+ "laion2b_s34b_b79k",
24
+ ],
25
+
26
+ "ViT-B/32-quickgelu": [
27
+ "metaclip_400m",
28
+ "metaclip_fullcc"
29
+ ],
30
+ # B/16
31
+ "ViT-B/16": [
32
+ "openai",
33
+ "laion400m_e31",
34
+ "laion400m_e32",
35
+ "laion2b_s34b_b88k",
36
+ ],
37
+ "ViT-B/16-quickgelu": [
38
+ "metaclip_400m",
39
+ "metaclip_fullcc",
40
+ ],
41
+ "ViT-B/16-plus-240": [
42
+ "laion400m_e31",
43
+ "laion400m_e32"
44
+ ],
45
+ # L/14
46
+ "ViT-L/14": [
47
+ "openai",
48
+ "laion400m_e31",
49
+ "laion400m_e32",
50
+ "laion2b_s32b_b82k",
51
+ ],
52
+ "ViT-L/14-quickgelu": [
53
+ "metaclip_400m",
54
+ "metaclip_fullcc"
55
+ ],
56
+ "ViT-L/14-336": [
57
+ "openai",
58
+ ]
59
+ }
60
+
61
+ def available_models() -> List[str]:
62
+ """Returns the names of available GEM-VL models"""
63
+ # _str = "".join([": ".join([key, value]) + "\n" for key, values in _MODELS2.items() for value in values])
64
+ _str = "".join([": ".join([key + " "*(20 - len(key)), value]) + "\n" for key, values in _MODELS.items() for value in values])
65
+ return _str
66
+
67
+ def get_tokenizer(
68
+ model_name: str = '',
69
+ context_length: Optional[int] = None,
70
+ **kwargs,
71
+ ):
72
+ """ Wrapper around openclip get_tokenizer function """
73
+ return open_clip.get_tokenizer(model_name=model_name, context_length=context_length, **kwargs)
74
+
75
+
76
+ def get_gem_img_transform(
77
+ img_size: Union[int, Tuple[int, int]] = (448, 448),
78
+ mean: Optional[Tuple[float, ...]] = None,
79
+ std: Optional[Tuple[float, ...]] = None,
80
+ ):
81
+ mean = mean or OPENAI_DATASET_MEAN
82
+ std = std or OPENAI_DATASET_STD
83
+ transform = transforms.Compose([
84
+ transforms.Resize(size=img_size, interpolation=transforms.InterpolationMode.BICUBIC),
85
+ transforms.ToTensor(),
86
+ transforms.Normalize(mean, std),
87
+ ])
88
+ return transform
89
+
90
+
91
+ def create_gem_model(
92
+ model_name: str,
93
+ pretrained: Optional[str] = None,
94
+ gem_depth: int = 7,
95
+ ss_attn_iter: int = 1,
96
+ ss_attn_temp: Optional[float] = None,
97
+ precision: str = 'fp32',
98
+ device: Union[str, torch.device] = 'cpu',
99
+ jit: bool = False,
100
+ force_quick_gelu: bool = False,
101
+ force_custom_text: bool = False,
102
+ force_patch_dropout: Optional[float] = None,
103
+ force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
104
+ force_preprocess_cfg: Optional[Dict[str, Any]] = None,
105
+ pretrained_image: bool = False,
106
+ pretrained_hf: bool = True,
107
+ cache_dir: Optional[str] = None,
108
+ output_dict: Optional[bool] = None,
109
+ require_pretrained: bool = False,
110
+ **model_kwargs,
111
+ ):
112
+ model_name = model_name.replace("/", "-")
113
+ logging.info(f'Loading pretrained {model_name} from pretrained weights {pretrained}...')
114
+ open_clip_model = open_clip.create_model(model_name, pretrained, precision, device, jit, force_quick_gelu, force_custom_text,
115
+ force_patch_dropout, force_image_size, force_preprocess_cfg, pretrained_image,
116
+ pretrained_hf, cache_dir, output_dict, require_pretrained, **model_kwargs)
117
+ tokenizer = open_clip.get_tokenizer(model_name=model_name)
118
+
119
+ gem_model = GEMWrapper(model=open_clip_model, tokenizer=tokenizer, depth=gem_depth,
120
+ ss_attn_iter=ss_attn_iter, ss_attn_temp=ss_attn_temp)
121
+ logging.info(f'Loaded GEM-{model_name} from pretrained weights {pretrained}!')
122
+ return gem_model
123
+
124
+ def create_model_and_transforms(
125
+ model_name: str,
126
+ pretrained: Optional[str] = None,
127
+ gem_depth: int = 7,
128
+ precision: str = 'fp32',
129
+ device: Union[str, torch.device] = 'cpu',
130
+ jit: bool = False,
131
+ force_quick_gelu: bool = False,
132
+ force_custom_text: bool = False,
133
+ force_patch_dropout: Optional[float] = None,
134
+ force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
135
+ force_preprocess_cfg: Optional[Dict[str, Any]] = None,
136
+ pretrained_image: bool = False,
137
+ pretrained_hf: bool = True,
138
+ cache_dir: Optional[str] = None,
139
+ output_dict: Optional[bool] = None,
140
+ require_pretrained: bool = False,
141
+ **model_kwargs,
142
+ ):
143
+ gem_model = create_gem_model(model_name, pretrained, gem_depth, precision, device, jit, force_quick_gelu, force_custom_text,
144
+ force_patch_dropout, force_image_size, force_preprocess_cfg, pretrained_image,
145
+ pretrained_hf, cache_dir, output_dict, require_pretrained, **model_kwargs)
146
+
147
+ transform = get_gem_img_transform(**model_kwargs)
148
+ return gem_model, transform
149
+
150
+ def visualize(image, text, logits, alpha=0.6, save_path=None):
151
+ W, H = logits.shape[-2:]
152
+ if isinstance(image, Image.Image):
153
+ image = image.resize((W, H))
154
+ elif isinstance(image, torch.Tensor):
155
+ if image.ndim > 3:
156
+ image = image.squeeze(0)
157
+ image_unormed = (image.detach().cpu() * torch.Tensor(OPENAI_DATASET_STD)[:, None, None]) \
158
+ + torch.Tensor(OPENAI_DATASET_MEAN)[:, None, None] # undo the normalization
159
+ image = Image.fromarray((image_unormed.permute(1, 2, 0).numpy() * 255).astype('uint8')) # convert to PIL
160
+ else:
161
+ raise f'image should be either of type PIL.Image.Image or torch.Tensor but found {type(image)}'
162
+
163
+ # plot image
164
+ plt.imshow(image)
165
+ plt.axis('off')
166
+ plt.tight_layout()
167
+ plt.show()
168
+
169
+ if logits.ndim > 3:
170
+ logits = logits.squeeze(0)
171
+ logits = logits.detach().cpu().numpy()
172
+
173
+
174
+ img_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
175
+ logits = (logits * 255).astype('uint8')
176
+ heat_maps = [cv2.applyColorMap(logit, cv2.COLORMAP_JET) for logit in logits]
177
+
178
+ vizs = [(1 - alpha) * img_cv + alpha * heat_map for heat_map in heat_maps]
179
+ for viz, cls_name in zip(vizs, text):
180
+
181
+ viz = cv2.cvtColor(viz.astype('uint8'), cv2.COLOR_BGR2RGB)
182
+ plt.imshow(viz)
183
+ plt.title(cls_name)
184
+ plt.axis('off')
185
+ plt.tight_layout()
186
+ plt.show()
187
+ if save_path is not None:
188
+ plt.savefig(f'heatmap_{cls_name}.png')
gem/gem_utils.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, List
2
+ import math
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ from open_clip.transformer import _expand_token, to_2tuple
9
+
10
+
11
+
12
+ def resample_abs_pos_embed(
13
+ posemb,
14
+ new_size: List[int],
15
+ old_size: Optional[List[int]] = None,
16
+ num_prefix_tokens: int = 1,
17
+ interpolation: str = 'bicubic',
18
+ antialias: bool = True
19
+ ):
20
+ # sort out sizes, assume square if old size not provided
21
+ new_size = to_2tuple(new_size)
22
+ new_ntok = new_size[0] * new_size[1]
23
+ if not old_size:
24
+ old_size = int(math.sqrt(posemb.shape[1] - num_prefix_tokens))
25
+ old_size = to_2tuple(old_size)
26
+ if new_size == old_size: # might not both be same container type
27
+ return posemb
28
+
29
+ if num_prefix_tokens:
30
+ posemb_prefix, posemb = posemb[:, :num_prefix_tokens], posemb[:, num_prefix_tokens:]
31
+ else:
32
+ posemb_prefix, posemb = None, posemb
33
+
34
+ # do the interpolation
35
+ posemb = posemb.reshape(1, old_size[0], old_size[1], -1).permute(0, 3, 1, 2)
36
+ posemb = F.interpolate(posemb, size=new_size, mode=interpolation, antialias=antialias)
37
+ posemb = posemb.permute(0, 2, 3, 1).reshape(1, new_ntok, -1)
38
+
39
+
40
+ # add back extra (class, etc) prefix tokens
41
+ if posemb_prefix is not None:
42
+ posemb = torch.cat([posemb_prefix, posemb], dim=1)
43
+ return posemb
44
+
45
+ class SelfSelfAttention(nn.Module):
46
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., ss_attn_iter=1,
47
+ ss_attn_temp=None):
48
+ super().__init__()
49
+ self.num_heads = num_heads
50
+ head_dim = dim // num_heads
51
+ self.scale = qk_scale or head_dim ** -0.5
52
+ self.ss_attn_iter = ss_attn_iter
53
+ self.ss_attn_temp = ss_attn_temp
54
+
55
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
56
+ self.attn_drop = nn.Dropout(attn_drop)
57
+ self.proj = nn.Linear(dim, dim)
58
+ self.proj_drop = nn.Dropout(proj_drop)
59
+
60
+ def forward(self, x, attn_bias=None, prev_attn=None):
61
+ x = x.transpose(0, 1)
62
+ B, N, C = x.shape
63
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
64
+ q, k, v = qkv[0], qkv[1], qkv[2]
65
+ self.v_values = v
66
+ # original self-attention for the original path
67
+ attn_ori_return = (q @ k.transpose(-2, -1)) * self.scale
68
+ attn_ori = attn_ori_return.softmax(dim=-1)
69
+ attn_ori = self.attn_drop(attn_ori)
70
+
71
+ x_ori = (attn_ori @ v).transpose(1, 2).reshape(B, N, C)
72
+ x_ori = self.proj_drop(self.proj(x_ori))
73
+
74
+ # GEM
75
+ xs1 = v
76
+ xs2 = k
77
+ xs3 = q
78
+
79
+ if self.ss_attn_temp is None:
80
+ pre_norm = torch.norm(x, dim=-1).mean(dim=-1, keepdim=True).unsqueeze(1).unsqueeze(-1)
81
+ inv_temp = pre_norm * self.scale
82
+ else:
83
+ inv_temp = self.ss_attn_temp
84
+
85
+ for it in range(self.ss_attn_iter):
86
+ xs1 = F.normalize(xs1, dim=-1)
87
+ xs2 = F.normalize(xs2, dim=-1)
88
+ xs3 = F.normalize(xs3, dim=-1)
89
+
90
+ attn_return1 = (xs1 @ xs1.transpose(-2, -1)) * inv_temp
91
+ attn_return2 = (xs2 @ xs2.transpose(-2, -1)) * inv_temp
92
+ attn_return3 = (xs3 @ xs3.transpose(-2, -1)) * inv_temp
93
+
94
+ attn1 = (attn_return1).softmax(dim=-1)
95
+ attn2 = (attn_return2).softmax(dim=-1)
96
+ attn3 = (attn_return3).softmax(dim=-1)
97
+
98
+ xs1 = attn1 @ xs1
99
+ xs2 = attn2 @ xs2
100
+ xs3 = attn3 @ xs3
101
+
102
+ # Assigment to V
103
+ xs1 = F.normalize(xs1, dim=-1)
104
+ xs2 = F.normalize(xs2, dim=-1)
105
+ xs3 = F.normalize(xs3, dim=-1)
106
+
107
+ attn_return1 = (xs1 @ xs1.transpose(-2, -1)) * inv_temp
108
+ attn_return2 = (xs2 @ xs2.transpose(-2, -1)) * inv_temp
109
+ attn_return3 = (xs3 @ xs3.transpose(-2, -1)) * inv_temp
110
+
111
+ attn1 = (attn_return1).softmax(dim=-1)
112
+ attn2 = (attn_return2).softmax(dim=-1)
113
+ attn3 = (attn_return3).softmax(dim=-1)
114
+
115
+ xs1 = attn1 @ v
116
+ xs2 = attn2 @ v
117
+ xs3 = attn3 @ v
118
+ xs = (xs1 + xs2 + xs3) / 3
119
+
120
+ x = xs.transpose(1, 2).reshape(B, N, C)
121
+ x = self.proj_drop(self.proj(x))
122
+
123
+ return [x.transpose(0, 1), x_ori.transpose(0, 1)]
124
+
125
+
126
+ class GEMResidualBlock(nn.Module):
127
+ def __init__(self, res_block):
128
+ super(GEMResidualBlock, self).__init__()
129
+ self.res_block = res_block
130
+
131
+ def forward(self,
132
+ q_x: torch.Tensor,
133
+ k_x: Optional[torch.Tensor] = None,
134
+ v_x: Optional[torch.Tensor] = None,
135
+ attn_mask: Optional[torch.Tensor] = None,
136
+ ):
137
+ if isinstance(q_x, list):
138
+ x_gem, q_x = q_x
139
+ else:
140
+ x_gem = q_x
141
+
142
+ x_gem_res, x_ori_res = self.res_block.attn(x=self.res_block.ln_1(q_x))
143
+ x_gem_res, x_ori_res = self.res_block.ls_1(x_gem_res), self.res_block.ls_1(x_ori_res)
144
+ # Original
145
+ x_ori = q_x + x_ori_res
146
+ x_ori = x_ori + self.res_block.ls_2(self.res_block.mlp(self.res_block.ln_2(x_ori)))
147
+ # GEM
148
+ x_gem = x_gem + x_gem_res
149
+ return [x_gem, x_ori]
150
+
151
+ class GEMViT(nn.Module):
152
+ def __init__(self, vit):
153
+ self.vit = vit
154
+
155
+ def modified_vit_forward(self, x: torch.Tensor):
156
+ x = self.conv1(x) # shape = [*, width, grid, grid]
157
+ grid_h, grid_w = x.shape[2:]
158
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
159
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
160
+
161
+ # class embeddings and positional embeddings
162
+ x = torch.cat([_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x], dim=1)
163
+ # shape = [*, grid ** 2 + 1, width]
164
+
165
+ if x.shape[1] != self.positional_embedding.shape[1]:
166
+ pos_emb = resample_abs_pos_embed(self.positional_embedding.unsqueeze(0),
167
+ new_size=[grid_h, grid_w],
168
+ # old_size=list(self.grid_size),
169
+ num_prefix_tokens=1,
170
+ interpolation='bicubic',
171
+ antialias=True)
172
+
173
+ else:
174
+ pos_emb = self.positional_embedding
175
+
176
+ x = x + pos_emb.to(x.dtype)
177
+ # x = x + self.positional_embedding.to(x.dtype)
178
+
179
+ x = self.patch_dropout(x)
180
+ x = self.ln_pre(x)
181
+
182
+ x = x.permute(1, 0, 2) # NLD -> LND
183
+ x_gem, x = self.transformer(x)
184
+ x = x.permute(1, 0, 2) # LND -> NLD
185
+ x_gem = x_gem.permute(1, 0, 2) # LND -> NLD
186
+
187
+ # Apply proj
188
+ x = self.ln_post(x)
189
+ x_gem = self.ln_post(x_gem)
190
+ if self.proj is not None:
191
+ x = x @ self.proj
192
+ x_gem = x_gem @ self.proj
193
+
194
+ return [x_gem, x]
gem/gem_wrapper.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from einops import rearrange
6
+ from open_clip.transformer import VisionTransformer
7
+
8
+ from .gem_utils import SelfSelfAttention, GEMResidualBlock, modified_vit_forward
9
+
10
+
11
+ class GEMWrapper(nn.Module):
12
+ def __init__(self, model, tokenizer, depth=7, ss_attn_iter=1, ss_attn_temp=None):
13
+ super(GEMWrapper, self).__init__()
14
+ self.model = model
15
+ self.tokenizer = tokenizer
16
+ self.depth = depth
17
+ self.ss_attn_iter = ss_attn_iter
18
+ self.ss_attn_temp = ss_attn_temp
19
+ self.patch_size = self.model.visual.patch_size[0]
20
+ self.apply_gem()
21
+
22
+ def apply_gem(self):
23
+ for i in range(1, self.depth):
24
+ # Extract info from the original ViT
25
+ num_heads = self.model.visual.transformer.resblocks[-i].attn.num_heads
26
+ dim = int(self.model.visual.transformer.resblocks[-i].attn.head_dim * num_heads)
27
+ qkv_bias = True
28
+ # Init the self-self attention layer
29
+ ss_attn = SelfSelfAttention(dim=dim, num_heads=num_heads, qkv_bias=qkv_bias,
30
+ ss_attn_iter=self.ss_attn_iter, ss_attn_temp=self.ss_attn_temp)
31
+ # Copy necessary weights
32
+ ss_attn.qkv.weight.data = self.model.visual.transformer.resblocks[-i].attn.in_proj_weight.clone()
33
+ ss_attn.qkv.bias.data = self.model.visual.transformer.resblocks[-i].attn.in_proj_bias.clone()
34
+ ss_attn.proj.weight.data = self.model.visual.transformer.resblocks[-i].attn.out_proj.weight.clone()
35
+ ss_attn.proj.bias.data = self.model.visual.transformer.resblocks[-i].attn.out_proj.bias.clone()
36
+ # Swap the original Attention with our SelfSelfAttention
37
+ self.model.visual.transformer.resblocks[-i].attn = ss_attn
38
+ # Wrap Residual block to handle SelfSelfAttention outputs
39
+ self.model.visual.transformer.resblocks[-i] = GEMResidualBlock(self.model.visual.transformer.resblocks[-i])
40
+ # Modify ViT's forward function
41
+ self.model.visual.forward = modified_vit_forward.__get__(self.model.visual, VisionTransformer)
42
+ return
43
+
44
+ def encode_text(self, text: list):
45
+ prompts = [f'a photo of a {cls}.' for cls in text]
46
+ tokenized_prompts = self.tokenizer(prompts).to(self.model.visual.proj.device)
47
+ text_embedding = self.model.encode_text(tokenized_prompts)
48
+ text_embedding = F.normalize(text_embedding, dim=-1)
49
+ return text_embedding.unsqueeze(0)
50
+
51
+ def min_max(self, logits):
52
+ B, num_prompt = logits.shape[:2]
53
+ logits_min = logits.reshape(B, num_prompt, -1).min(dim=-1, keepdim=True)[0].unsqueeze(-1)
54
+ logits_max = logits.reshape(B, num_prompt, -1).max(dim=-1, keepdim=True)[0].unsqueeze(-1)
55
+ logits = (logits - logits_min) / (logits_max - logits_min)
56
+ return logits
57
+
58
+ def forward(self, image: torch.Tensor, text: list, normalize: bool = True, return_ori: bool =False):
59
+ """
60
+ :param image: torch.Tensor [1, 3, H, W]
61
+ :param text: list[]
62
+ :param normalize: bool - if True performs min-max normalization
63
+ :param return_ori: bool - if True uses the features from the original visual encoder
64
+ """
65
+ # Image
66
+ W, H = image.shape[-2:]
67
+ feat_gem, feat_ori = self.model.visual(image)
68
+ image_feat = feat_ori if return_ori else feat_gem
69
+ image_feat = F.normalize(image_feat, dim=-1) # [1, N, dim]
70
+
71
+ # Text
72
+ text_embeddings = self.encode_text(text) # [1, num_prompt, dim]
73
+
74
+ # Image-Text matching
75
+ img_txt_matching = image_feat[:, 1:] @ text_embeddings.transpose(-1, -2) # [1, N, num_prompt]
76
+ img_txt_matching = rearrange(img_txt_matching, 'b (w h) c -> b c w h',
77
+ w=W//self.patch_size, h=H//self.patch_size) # [1, num_prompt, w, h]
78
+
79
+ # Interpolate
80
+ img_txt_matching = F.interpolate(img_txt_matching, size=(W, H), mode='bilinear') # [1, num_prompt, W, H]
81
+
82
+ # Heat Maps
83
+ if normalize:
84
+ img_txt_matching = self.min_max(img_txt_matching)
85
+ return img_txt_matching
86
+
87
+ def batched_forward(self, image: torch.Tensor, text: list, normalize: bool = True, return_ori: bool =False):
88
+ """
89
+ :param image: torch.Tensor [B, 3, H, W]
90
+ :param text: list[list[]]
91
+ :param normalize: bool - if True performs min-max normalization
92
+ :param return_ori: bool - if True uses the features from the original visual encoder
93
+ """
94
+ L = len(text)
95
+ cumm_idx = np.cumsum([len(t) for t in text]).tolist()
96
+ B, _, W, H = image.shape
97
+ assert B == L, f'Number of prompts L: {L} should be the same as number of images B: {B}.'
98
+
99
+ # Image
100
+ feat_gem, feat_ori = self.model.visual(image)
101
+ image_feat = feat_ori if return_ori else feat_gem
102
+ image_feat = F.normalize(image_feat, dim=-1) # [B, N, dim]
103
+
104
+ # Text
105
+ flatten_text = [t for sub_text in text for t in sub_text]
106
+ text_embeddings = self.encode_text(flatten_text) # [B, num_prompt, dim]
107
+
108
+ # Image-Text matching
109
+ img_txt_matching = 100 * image_feat[:, 1:] @ text_embeddings.transpose(-1, -2) # [B, N, num_prompt]
110
+ img_txt_matching = rearrange(img_txt_matching, 'b (w h) c -> b c w h',
111
+ w=W // self.patch_size, h=H // self.patch_size) # [B, num_prompt, w, h]
112
+
113
+ # Interpolate
114
+ img_txt_matching = F.interpolate(img_txt_matching, size=(W, H), mode='bilinear') # [B,num_prompt, W, H]
115
+
116
+ # Heat Maps
117
+ if normalize:
118
+ img_txt_matching = self.min_max(img_txt_matching) # [B,num_prompt, W, H]
119
+
120
+ # unflatten
121
+ img_txt_matching = torch.tensor_split(img_txt_matching, cumm_idx[:-1], dim=1)
122
+ img_txt_matching = [itm[i] for i, itm in enumerate(img_txt_matching)]
123
+ return img_txt_matching
requirements.txt ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch>=1.9.0
2
+ torchvision
3
+ regex
4
+ ftfy
5
+ tqdm
6
+ huggingface_hub
7
+ sentencepiece
8
+ protobuf
9
+ timm
10
+ einops
11
+ open_clip_torch
12
+ opencv-python
13
+ matplotlib
14
+ numpy
15
+ requests
setup.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Setup
2
+ Adapted from https://github.com/mlfoundations/open_clip
3
+ """
4
+ from setuptools import setup, find_packages
5
+ from codecs import open
6
+ from os import path
7
+
8
+ here = path.abspath(path.dirname(__file__))
9
+
10
+ # Get the long description from the README file
11
+ with open(path.join(here, 'README.md'), encoding='utf-8') as f:
12
+ long_description = f.read()
13
+
14
+ def _read_reqs(relpath):
15
+ fullpath = path.join(path.dirname(__file__), relpath)
16
+ with open(fullpath) as f:
17
+ return [s.strip() for s in f.readlines() if (s.strip() and not s.startswith("#"))]
18
+
19
+ REQUIREMENTS = _read_reqs("requirements.txt")
20
+
21
+ setup(
22
+ name='gem_torch',
23
+ version="1.0",
24
+ description='GEM',
25
+ long_description=long_description,
26
+ long_description_content_type='text/markdown',
27
+ url='https://github.com/WalBouss/GEM',
28
+ author='Walid Bousselham, Felix Petersen, Vittorio Ferrari, Hilde Kuehne',
29
+ author_email='',
30
+ classifiers=[
31
+ # How mature is this project? Common values are
32
+ # 3 - Alpha
33
+ # 4 - Beta
34
+ # 5 - Production/Stable
35
+ 'Development Status :: 3 - Alpha',
36
+ 'Intended Audience :: Education',
37
+ 'Intended Audience :: Science/Research',
38
+ 'License :: OSI Approved :: Apache Software License',
39
+ 'Programming Language :: Python :: 3.7',
40
+ 'Programming Language :: Python :: 3.8',
41
+ 'Programming Language :: Python :: 3.9',
42
+ 'Programming Language :: Python :: 3.10',
43
+ 'Topic :: Scientific/Engineering',
44
+ 'Topic :: Scientific/Engineering :: Artificial Intelligence',
45
+ 'Topic :: Software Development',
46
+ 'Topic :: Software Development :: Libraries',
47
+ 'Topic :: Software Development :: Libraries :: Python Modules',
48
+ ],
49
+
50
+ # Note that this is a string of words separated by whitespace, not a list.
51
+ keywords='CLIP pretrained',
52
+ py_modules=["gem"],
53
+ packages=find_packages(exclude=["assets*"]),
54
+ include_package_data=True,
55
+ install_requires=REQUIREMENTS,
56
+ python_requires='>=3.7',
57
+ )
test_examples.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ from gem import create_gem_model, get_gem_img_transform, visualize, available_models
3
+ import torch
4
+ import requests
5
+
6
+
7
+ print(available_models())
8
+
9
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
10
+ model_name = 'ViT-B-16-quickgelu'
11
+ pretrained = 'metaclip_400m'
12
+ gem_model = create_gem_model(model_name=model_name, pretrained=pretrained, device=device)
13
+ gem_model.eval()
14
+
15
+ ###########################
16
+ # Single Image
17
+ ###########################
18
+
19
+ url = "http://images.cocodataset.org/val2017/000000039769.jpg" # cat & remote control
20
+ text = ['remote control', 'cat']
21
+ # image_path = 'path/to/image' #, <-- uncomment to use path
22
+
23
+ image_pil = Image.open(requests.get(url, stream=True).raw)
24
+ # image_pil = Image.open(image_path) # <-- uncomment to use path
25
+
26
+ gem_img_transform = get_gem_img_transform()
27
+ image = gem_img_transform(image_pil).unsqueeze(0).to(device)
28
+
29
+ with torch.no_grad():
30
+ logits = gem_model(image, text)
31
+ visualize(image, text, logits)
32
+ print(logits.shape) # torch.Size([1, 2, 448, 448])
33
+ # visualize(image_pil, text, logits) # <-- works with torch.Tensor and PIL.Image
34
+
35
+ ###########################
36
+ # Batch of Images
37
+ ###########################
38
+ urls = [
39
+ "http://images.cocodataset.org/val2017/000000039769.jpg",
40
+ "https://cdn.vietnambiz.vn/171464876016439296/2021/7/11/headshots16170695297430-1626006880779826347793.jpg",
41
+ "https://preview.redd.it/do-you-think-joker-should-be-unpredictable-enough-to-put-up-v0-6a2ax4ngtlaa1.jpg?auto=webp&s=f8762e6a1b40642bcae5900bac184fc597131503",
42
+ ]
43
+ texts = [
44
+ ['remote control', 'cat'],
45
+ ['elon musk', 'mark zuckerberg', 'jeff bezos', 'bill gates'],
46
+ ['batman', 'joker', 'shoe', 'belt', 'purple suit'],
47
+ ] # note that the number of prompt per image can be different
48
+
49
+ # download images + convert to PIL.Image
50
+ images_pil = [Image.open(requests.get(url, stream=True).raw) for url in urls]
51
+ images = torch.stack([gem_img_transform(img) for img in images_pil]).to(device)
52
+
53
+ with torch.no_grad():
54
+ # return list with logits of size [1, num_prompt, W, H]
55
+ logits_list = gem_model.batched_forward(images, texts)
56
+ print(logits_list[0].shape) # torch.Size([2, 448, 448])
57
+ print(logits_list[1].shape) # torch.Size([4, 448, 448])
58
+ print(logits_list[2].shape) # torch.Size([5, 448, 448])
59
+ for i, _logits in enumerate(logits_list):
60
+ visualize(images[i], texts[i], _logits) # (optional visualization)