Spaces:
Runtime error
Runtime error
WalidBouss
committed on
Commit
·
be1ec96
1
Parent(s):
55fd6e9
Initial commit :tada:
Browse files- .gitattributes +1 -0
- LICENSE +21 -0
- app.py +84 -0
- assets/cats_remote_control.jpeg +3 -0
- assets/elon_jeff_mark.jpeg +3 -0
- gem/__init__.py +1 -0
- gem/gem.py +188 -0
- gem/gem_utils.py +194 -0
- gem/gem_wrapper.py +123 -0
- requirements.txt +15 -0
- setup.py +57 -0
- test_examples.py +60 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
*.jpeg filter=lfs diff=lfs merge=lfs -text
|
LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2023 Walid Bousselham, Felix Petersen, Vittorio Ferrari, Hilde Kuehne.
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
app.py
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from PIL import Image
|
2 |
+
import numpy as np
|
3 |
+
import cv2 as cv2
|
4 |
+
import torch
|
5 |
+
import requests
|
6 |
+
|
7 |
+
import gradio as gr
|
8 |
+
|
9 |
+
import gem
|
10 |
+
|
11 |
+
|
12 |
+
# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Dropdown label -> (OpenCLIP architecture name, pretrained-weights tag).
_MODELS = {
    "OpenAI": ('ViT-B-16', 'openai'),
    "MetaCLIP": ('ViT-B-16-quickgelu', 'metaclip_400m'),
    "OpenCLIP": ('ViT-B-16', 'laion400m_e32'),
}

# Default checkpoint; taken from _MODELS so it cannot drift from the entry
# that the weights Dropdown pre-selects ("MetaCLIP").
model_name, pretrained = _MODELS["MetaCLIP"]
preprocess = gem.get_gem_img_transform()

# Module-level model shared by the callbacks; change_weights() rebinds it.
gem_model = gem.create_gem_model(model_name=model_name, pretrained=pretrained, device=device)
gem_model.eval()  # the demo only runs inference
image_source = "image"
|
25 |
+
|
26 |
+
def change_weights(pretrained_weights):
    """Reload the global GEM model when the weights Dropdown changes.

    :param pretrained_weights: key of ``_MODELS`` chosen in the Dropdown.
    :raises KeyError: if ``pretrained_weights`` is not a key of ``_MODELS``.
    """
    global gem_model
    _new_model = _MODELS[pretrained_weights]
    print(_new_model)
    # Build the replacement completely before rebinding the global, so a
    # failed download leaves the previously loaded model in place.
    new_model = gem.create_gem_model(model_name=_new_model[0], pretrained=_new_model[1], device=device)
    new_model.eval()  # inference only
    gem_model = new_model
|
33 |
+
|
34 |
+
def change_to_url(url):
    """Download the image behind ``url`` and return it as an RGB PIL image.

    This is wired to the URL textbox's ``change`` event, so it is also
    called when the box is cleared; a blank value leaves the image
    component untouched instead of failing inside ``requests``.

    :param url: address of the image to fetch.
    :return: ``PIL.Image.Image`` in RGB mode, or a no-op ``gr.update()`` when
        ``url`` is blank.
    :raises requests.RequestException: on network errors, timeouts or a
        non-success HTTP status.
    """
    import io  # local import: only this callback needs it

    if not url or not url.strip():
        return gr.update()

    response = requests.get(url.strip(), timeout=10)
    response.raise_for_status()
    # ``content`` is the fully decoded body; the ``raw`` stream is not
    # content-decoded, which can break on compressed responses.
    img_pil = Image.open(io.BytesIO(response.content)).convert('RGB')
    return img_pil
|
37 |
+
|
38 |
+
def viz_func(url, image, text, model_weights):
    """Compute the GEM heat map of ``text`` on ``image`` and overlay it.

    :param url: unused here; present because the Interface passes every input.
    :param image: PIL image coming from the ``gr.Image(type="pil")`` input.
    :param text: text prompt to localise.
    :param model_weights: unused here; the model itself is swapped by
        ``change_weights``.
    :return: RGB ``uint8`` array blending 40% image with 60% heat map.
    :raises gr.Error: if no image has been provided.
    """
    if image is None:
        raise gr.Error("Please upload an image or provide an image URL first.")
    image = image.convert('RGB')

    image_torch = preprocess(image).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = gem_model(image_torch, [text])
    # One image and one prompt were passed, so take that single 2-D map.
    heat = logits[0, 0].detach().cpu().numpy()
    heat = np.nan_to_num(heat)  # NaN would give undefined uint8 values

    # Resize the picture to the heat-map resolution instead of assuming 448.
    height, width = heat.shape
    img_cv = cv2.cvtColor(np.array(image.resize((width, height))), cv2.COLOR_RGB2BGR)
    heat_map = cv2.applyColorMap((heat * 255).astype('uint8'), cv2.COLORMAP_JET)

    viz = 0.4 * img_cv + 0.6 * heat_map
    return cv2.cvtColor(viz.astype('uint8'), cv2.COLOR_BGR2RGB)
|
51 |
+
|
52 |
+
# Input components, named so the event wiring below does not depend on
# list positions.
url_box = gr.Textbox(label="url to the image", )
image_box = gr.Image(type="pil")
prompt_box = gr.Textbox(label="Text Prompt")
weights_dropdown = gr.Dropdown(
    ["OpenAI", "MetaCLIP", "OpenCLIP"],
    label="Pretrained Weights",
    value="MetaCLIP",
    info='It can take a few second for the model to be updated.',
)
inputs = [url_box, image_box, prompt_box, weights_dropdown]

with gr.Blocks() as demo:
    # Selecting other weights reloads the model; editing the URL fetches
    # the picture into the image component.
    weights_dropdown.change(fn=change_weights, inputs=[weights_dropdown])
    url_box.change(fn=change_to_url, outputs=image_box, inputs=url_box)

    interact = gr.Interface(
        title="GEM: Grounding Everything Module (link to paper/code)",
        description="Grounding Everything: Emerging Localization Properties in Vision-Language Transformers",
        fn=viz_func,
        inputs=inputs,
        outputs=["image"],
    )

    # Each example fills the image and the prompt only.
    example_rows = [
        ["assets/cats_remote_control.jpeg", "cat"],
        ["assets/cats_remote_control.jpeg", "remote control"],
        ["assets/elon_jeff_mark.jpeg", "elon musk"],
        ["assets/elon_jeff_mark.jpeg", "mark zuckerberg"],
        ["assets/elon_jeff_mark.jpeg", "jeff bezos"],
    ]
    gr.Examples(example_rows, [image_box, prompt_box])

demo.launch()
|
assets/cats_remote_control.jpeg
ADDED
Git LFS Details
|
assets/elon_jeff_mark.jpeg
ADDED
Git LFS Details
|
gem/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .gem import *
|
gem/gem.py
ADDED
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
from typing import Any, Union, List, Optional, Tuple, Dict
|
3 |
+
import open_clip
|
4 |
+
from open_clip.constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from torchvision import transforms
|
8 |
+
import matplotlib.pyplot as plt
|
9 |
+
from PIL import Image
|
10 |
+
import numpy as np
|
11 |
+
import cv2 as cv2
|
12 |
+
|
13 |
+
from .gem_wrapper import GEMWrapper
|
14 |
+
|
15 |
+
|
16 |
+
_MODELS = {
|
17 |
+
# B/32
|
18 |
+
"ViT-B/32": [
|
19 |
+
"openai",
|
20 |
+
"laion400m_e31",
|
21 |
+
"laion400m_e32",
|
22 |
+
"laion2b_e16",
|
23 |
+
"laion2b_s34b_b79k",
|
24 |
+
],
|
25 |
+
|
26 |
+
"ViT-B/32-quickgelu": [
|
27 |
+
"metaclip_400m",
|
28 |
+
"metaclip_fullcc"
|
29 |
+
],
|
30 |
+
# B/16
|
31 |
+
"ViT-B/16": [
|
32 |
+
"openai",
|
33 |
+
"laion400m_e31",
|
34 |
+
"laion400m_e32",
|
35 |
+
"laion2b_s34b_b88k",
|
36 |
+
],
|
37 |
+
"ViT-B/16-quickgelu": [
|
38 |
+
"metaclip_400m",
|
39 |
+
"metaclip_fullcc",
|
40 |
+
],
|
41 |
+
"ViT-B/16-plus-240": [
|
42 |
+
"laion400m_e31",
|
43 |
+
"laion400m_e32"
|
44 |
+
],
|
45 |
+
# L/14
|
46 |
+
"ViT-L/14": [
|
47 |
+
"openai",
|
48 |
+
"laion400m_e31",
|
49 |
+
"laion400m_e32",
|
50 |
+
"laion2b_s32b_b82k",
|
51 |
+
],
|
52 |
+
"ViT-L/14-quickgelu": [
|
53 |
+
"metaclip_400m",
|
54 |
+
"metaclip_fullcc"
|
55 |
+
],
|
56 |
+
"ViT-L/14-336": [
|
57 |
+
"openai",
|
58 |
+
]
|
59 |
+
}
|
60 |
+
|
61 |
+
def available_models() -> List[str]:
|
62 |
+
"""Returns the names of available GEM-VL models"""
|
63 |
+
# _str = "".join([": ".join([key, value]) + "\n" for key, values in _MODELS2.items() for value in values])
|
64 |
+
_str = "".join([": ".join([key + " "*(20 - len(key)), value]) + "\n" for key, values in _MODELS.items() for value in values])
|
65 |
+
return _str
|
66 |
+
|
67 |
+
def get_tokenizer(
        model_name: str = '',
        context_length: Optional[int] = None,
        **kwargs,
):
    """Return the tokenizer that ``open_clip.get_tokenizer`` builds.

    All arguments are forwarded unchanged; this function adds no logic.
    """
    tokenizer = open_clip.get_tokenizer(
        model_name=model_name,
        context_length=context_length,
        **kwargs,
    )
    return tokenizer
|
74 |
+
|
75 |
+
|
76 |
+
def get_gem_img_transform(
        img_size: Union[int, Tuple[int, int]] = (448, 448),
        mean: Optional[Tuple[float, ...]] = None,
        std: Optional[Tuple[float, ...]] = None,
):
    """Build the image preprocessing pipeline used with GEM models.

    :param img_size: target size handed to ``transforms.Resize``.
    :param mean: normalisation mean; a falsy value selects ``OPENAI_DATASET_MEAN``.
    :param std: normalisation std; a falsy value selects ``OPENAI_DATASET_STD``.
    :return: ``transforms.Compose`` doing bicubic resize, tensor conversion
        and normalisation, in that order.
    """
    norm_mean = mean or OPENAI_DATASET_MEAN
    norm_std = std or OPENAI_DATASET_STD
    steps = [
        transforms.Resize(size=img_size, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize(norm_mean, norm_std),
    ]
    return transforms.Compose(steps)
|
89 |
+
|
90 |
+
|
91 |
+
def create_gem_model(
        model_name: str,
        pretrained: Optional[str] = None,
        gem_depth: int = 7,
        ss_attn_iter: int = 1,
        ss_attn_temp: Optional[float] = None,
        precision: str = 'fp32',
        device: Union[str, torch.device] = 'cpu',
        jit: bool = False,
        force_quick_gelu: bool = False,
        force_custom_text: bool = False,
        force_patch_dropout: Optional[float] = None,
        force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
        force_preprocess_cfg: Optional[Dict[str, Any]] = None,
        pretrained_image: bool = False,
        pretrained_hf: bool = True,
        cache_dir: Optional[str] = None,
        output_dict: Optional[bool] = None,
        require_pretrained: bool = False,
        **model_kwargs,
):
    """Create an OpenCLIP model and wrap it in a ``GEMWrapper``.

    :param model_name: architecture name; ``/`` is replaced by ``-`` so
        names such as ``ViT-B/16`` are accepted.
    :param pretrained: pretrained-weights tag.
    :param gem_depth: forwarded to ``GEMWrapper`` as ``depth``.
    :param ss_attn_iter: forwarded to ``GEMWrapper``.
    :param ss_attn_temp: forwarded to ``GEMWrapper``.
    :param model_kwargs: extra keyword arguments for ``open_clip.create_model``.
    :return: the ``GEMWrapper`` instance.

    All other parameters are forwarded to ``open_clip.create_model`` under
    the same names.
    """
    model_name = model_name.replace("/", "-")
    logging.info(f'Loading pretrained {model_name} from pretrained weights {pretrained}...')
    # Forward by keyword: a positional call silently mis-binds every
    # argument if the upstream signature gains or reorders a parameter.
    open_clip_model = open_clip.create_model(
        model_name,
        pretrained,
        precision=precision,
        device=device,
        jit=jit,
        force_quick_gelu=force_quick_gelu,
        force_custom_text=force_custom_text,
        force_patch_dropout=force_patch_dropout,
        force_image_size=force_image_size,
        force_preprocess_cfg=force_preprocess_cfg,
        pretrained_image=pretrained_image,
        pretrained_hf=pretrained_hf,
        cache_dir=cache_dir,
        output_dict=output_dict,
        require_pretrained=require_pretrained,
        **model_kwargs,
    )
    tokenizer = open_clip.get_tokenizer(model_name=model_name)

    gem_model = GEMWrapper(model=open_clip_model, tokenizer=tokenizer, depth=gem_depth,
                           ss_attn_iter=ss_attn_iter, ss_attn_temp=ss_attn_temp)
    logging.info(f'Loaded GEM-{model_name} from pretrained weights {pretrained}!')
    return gem_model
|
123 |
+
|
124 |
+
def create_model_and_transforms(
        model_name: str,
        pretrained: Optional[str] = None,
        gem_depth: int = 7,
        precision: str = 'fp32',
        device: Union[str, torch.device] = 'cpu',
        jit: bool = False,
        force_quick_gelu: bool = False,
        force_custom_text: bool = False,
        force_patch_dropout: Optional[float] = None,
        force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
        force_preprocess_cfg: Optional[Dict[str, Any]] = None,
        pretrained_image: bool = False,
        pretrained_hf: bool = True,
        cache_dir: Optional[str] = None,
        output_dict: Optional[bool] = None,
        require_pretrained: bool = False,
        **model_kwargs,
):
    """Create a GEM model together with its image transform.

    :param model_kwargs: ``img_size``, ``mean`` and ``std`` go to
        ``get_gem_img_transform``; every other entry (including
        ``ss_attn_iter`` / ``ss_attn_temp``) goes to ``create_gem_model``.
    :return: ``(gem_model, transform)``.

    The named parameters are forwarded to ``create_gem_model`` unchanged.
    """
    # Send each extra keyword to exactly one callee; neither of them
    # accepts the other's keywords.
    transform_kwargs = {
        key: model_kwargs.pop(key)
        for key in ('img_size', 'mean', 'std')
        if key in model_kwargs
    }

    # Keyword forwarding is required here: create_gem_model has
    # ss_attn_iter and ss_attn_temp between gem_depth and precision, so a
    # positional call shifts every later argument by two slots.
    gem_model = create_gem_model(
        model_name,
        pretrained,
        gem_depth=gem_depth,
        precision=precision,
        device=device,
        jit=jit,
        force_quick_gelu=force_quick_gelu,
        force_custom_text=force_custom_text,
        force_patch_dropout=force_patch_dropout,
        force_image_size=force_image_size,
        force_preprocess_cfg=force_preprocess_cfg,
        pretrained_image=pretrained_image,
        pretrained_hf=pretrained_hf,
        cache_dir=cache_dir,
        output_dict=output_dict,
        require_pretrained=require_pretrained,
        **model_kwargs,
    )

    transform = get_gem_img_transform(**transform_kwargs)
    return gem_model, transform
|
149 |
+
|
150 |
+
def visualize(image, text, logits, alpha=0.6, save_path=None):
    """Show ``image`` and one heat-map overlay per entry of ``text``.

    :param image: ``PIL.Image.Image``, or a ``torch.Tensor`` normalised with
        the OpenAI mean/std (a leading batch axis of size 1 is squeezed).
    :param text: iterable of class names, paired with the maps in ``logits``.
    :param logits: tensor whose last two axes are (height, width); a leading
        batch axis of size 1 is squeezed. Values are scaled by 255 and cast
        to ``uint8``, so they are presumably expected in [0, 1].
    :param alpha: weight of the heat map in the blend.
    :param save_path: when not ``None``, each overlay is also written to
        ``heatmap_{cls_name}.png`` in the current working directory.
        NOTE(review): the value is only used as an on/off flag, as before.
    :raises TypeError: if ``image`` is neither a PIL image nor a tensor.
    """
    # The last two axes are (rows, cols) = (height, width); PIL's resize
    # expects (width, height).
    height, width = logits.shape[-2:]
    if isinstance(image, Image.Image):
        image = image.convert('RGB').resize((width, height))
    elif isinstance(image, torch.Tensor):
        if image.ndim > 3:
            image = image.squeeze(0)
        std = torch.tensor(OPENAI_DATASET_STD)[:, None, None]
        mean = torch.tensor(OPENAI_DATASET_MEAN)[:, None, None]
        # Undo the normalisation; clamp so values just outside [0, 1] do not
        # wrap around in the uint8 cast.
        image_unormed = (image.detach().cpu().float() * std + mean).clamp(0, 1)
        image = Image.fromarray((image_unormed.permute(1, 2, 0).numpy() * 255).astype('uint8'))
        if image.size != (width, height):
            image = image.resize((width, height))
    else:
        # Raising a plain string is itself a TypeError in Python 3; raise a
        # proper exception that carries the message.
        raise TypeError(f'image should be either of type PIL.Image.Image or torch.Tensor but found {type(image)}')

    # plot image
    plt.imshow(image)
    plt.axis('off')
    plt.tight_layout()
    plt.show()

    if logits.ndim > 3:
        logits = logits.squeeze(0)
    logits = logits.detach().cpu().numpy()

    img_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    logits = (logits * 255).astype('uint8')
    heat_maps = [cv2.applyColorMap(logit, cv2.COLORMAP_JET) for logit in logits]

    vizs = [(1 - alpha) * img_cv + alpha * heat_map for heat_map in heat_maps]
    for viz, cls_name in zip(vizs, text):
        viz = cv2.cvtColor(viz.astype('uint8'), cv2.COLOR_BGR2RGB)
        plt.imshow(viz)
        plt.title(cls_name)
        plt.axis('off')
        plt.tight_layout()
        # Save before show(): after show() returns, the figure may already
        # be closed and savefig would write an empty canvas.
        if save_path is not None:
            plt.savefig(f'heatmap_{cls_name}.png')
        plt.show()
|
gem/gem_utils.py
ADDED
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional, List
|
2 |
+
import math
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
|
8 |
+
from open_clip.transformer import _expand_token, to_2tuple
|
9 |
+
|
10 |
+
|
11 |
+
|
12 |
+
def resample_abs_pos_embed(
        posemb,
        new_size: List[int],
        old_size: Optional[List[int]] = None,
        num_prefix_tokens: int = 1,
        interpolation: str = 'bicubic',
        antialias: bool = True
):
    """Interpolate absolute position embeddings to a new token grid.

    :param posemb: embeddings indexed as ``[batch, tokens, dim]``, with
        ``num_prefix_tokens`` prefix tokens first on the token axis.
    :param new_size: target grid size.
    :param old_size: source grid size; when falsy a square grid is derived
        from the number of non-prefix tokens.
    :param num_prefix_tokens: count of leading tokens kept as they are.
    :param interpolation: ``mode`` passed to ``F.interpolate``.
    :param antialias: ``antialias`` passed to ``F.interpolate``.
    :return: ``posemb`` itself when the grids already match, otherwise the
        prefix tokens followed by the resampled grid tokens.
    """
    target_hw = to_2tuple(new_size)
    if not old_size:
        # assume a square source grid
        old_size = int(math.sqrt(posemb.shape[1] - num_prefix_tokens))
    source_hw = to_2tuple(old_size)
    if target_hw == source_hw:
        return posemb

    if num_prefix_tokens:
        prefix_tokens = posemb[:, :num_prefix_tokens]
        grid_tokens = posemb[:, num_prefix_tokens:]
    else:
        prefix_tokens = None
        grid_tokens = posemb

    # tokens -> 2-D feature map -> resize -> tokens
    feature_map = grid_tokens.reshape(1, source_hw[0], source_hw[1], -1).permute(0, 3, 1, 2)
    feature_map = F.interpolate(feature_map, size=target_hw, mode=interpolation, antialias=antialias)
    resized = feature_map.permute(0, 2, 3, 1).reshape(1, target_hw[0] * target_hw[1], -1)

    if prefix_tokens is None:
        return resized
    return torch.cat([prefix_tokens, resized], dim=1)
|
44 |
+
|
45 |
+
class SelfSelfAttention(nn.Module):
    """Attention layer returning both a self-self (GEM) and a regular output.

    ``forward`` computes ordinary q-k attention for the original path and,
    for the GEM path, attention weights derived from the similarity of the
    values, keys and queries with themselves; the three results are averaged.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., ss_attn_iter=1,
                 ss_attn_temp=None):
        """
        :param dim: embedding width (split evenly over ``num_heads``).
        :param num_heads: number of attention heads.
        :param qkv_bias: whether the fused qkv projection has a bias.
        :param qk_scale: scale of the q-k logits; falsy selects ``head_dim ** -0.5``.
        :param attn_drop: dropout probability on the original attention weights.
        :param proj_drop: dropout probability after the output projection.
        :param ss_attn_iter: number of self-self refinement iterations.
        :param ss_attn_temp: fixed inverse temperature; ``None`` derives it
            from the mean token norm of the input.
        """
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.ss_attn_iter = ss_attn_iter
        self.ss_attn_temp = ss_attn_temp

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    @staticmethod
    def _self_self_weights(feats, inv_temp):
        """L2-normalise ``feats`` and return them with their self-similarity softmax."""
        feats = F.normalize(feats, dim=-1)
        weights = ((feats @ feats.transpose(-2, -1)) * inv_temp).softmax(dim=-1)
        return feats, weights

    def forward(self, x, attn_bias=None, prev_attn=None):
        """
        :param x: tokens indexed as ``[N, B, C]`` (transposed to ``[B, N, C]`` here).
        :param attn_bias: unused.
        :param prev_attn: unused.
        :return: ``[x_gem, x_ori]``, both indexed as ``[N, B, C]``.
        """
        x = x.transpose(0, 1)
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        self.v_values = v  # kept on the module, as before

        # original self-attention for the original path
        attn_ori = ((q @ k.transpose(-2, -1)) * self.scale).softmax(dim=-1)
        attn_ori = self.attn_drop(attn_ori)
        x_ori = (attn_ori @ v).transpose(1, 2).reshape(B, N, C)
        x_ori = self.proj_drop(self.proj(x_ori))

        # GEM: inverse temperature, either adaptive or fixed
        if self.ss_attn_temp is None:
            pre_norm = torch.norm(x, dim=-1).mean(dim=-1, keepdim=True).unsqueeze(1).unsqueeze(-1)
            inv_temp = pre_norm * self.scale
        else:
            inv_temp = self.ss_attn_temp

        # Refine the value, key and query streams independently: each
        # iteration re-mixes a stream with its own self-similarity weights.
        streams = [v, k, q]
        for _ in range(self.ss_attn_iter):
            refined = []
            for feats in streams:
                feats, weights = self._self_self_weights(feats, inv_temp)
                refined.append(weights @ feats)
            streams = refined

        # Assignment to V: the final weights of every stream aggregate the values.
        outs = [self._self_self_weights(feats, inv_temp)[1] @ v for feats in streams]
        xs = (outs[0] + outs[1] + outs[2]) / 3

        x = xs.transpose(1, 2).reshape(B, N, C)
        x = self.proj_drop(self.proj(x))

        return [x.transpose(0, 1), x_ori.transpose(0, 1)]
|
124 |
+
|
125 |
+
|
126 |
+
class GEMResidualBlock(nn.Module):
    """Wrap a residual block whose ``attn`` returns a (GEM, original) pair.

    The original stream goes through the complete block (attention and MLP
    residuals); the GEM stream only receives the attention residual.
    """

    def __init__(self, res_block):
        super().__init__()
        self.res_block = res_block

    def forward(self,
                q_x: torch.Tensor,
                k_x: Optional[torch.Tensor] = None,
                v_x: Optional[torch.Tensor] = None,
                attn_mask: Optional[torch.Tensor] = None,
                ):
        """
        :param q_x: a tensor (first wrapped block) or the ``[x_gem, x_ori]``
            list returned by a previous wrapped block.
        :param k_x: unused.
        :param v_x: unused.
        :param attn_mask: unused.
        :return: ``[x_gem, x_ori]``.
        """
        block = self.res_block
        if isinstance(q_x, list):
            x_gem, q_x = q_x
        else:
            # first wrapped block: both streams start from the same input
            x_gem = q_x

        gem_delta, ori_delta = block.attn(x=block.ln_1(q_x))
        gem_delta = block.ls_1(gem_delta)
        ori_delta = block.ls_1(ori_delta)

        # original stream: attention residual followed by the MLP residual
        x_ori = q_x + ori_delta
        x_ori = x_ori + block.ls_2(block.mlp(block.ln_2(x_ori)))

        # GEM stream: attention residual only
        return [x_gem + gem_delta, x_ori]
|
150 |
+
|
151 |
+
class GEMViT(nn.Module):
    """Thin ``nn.Module`` holder for a vision transformer."""

    def __init__(self, vit):
        """
        :param vit: the vision transformer module to hold.
        """
        # nn.Module must be initialised before a sub-module is assigned;
        # without this call the assignment below raises AttributeError.
        super().__init__()
        self.vit = vit
|
154 |
+
|
155 |
+
def modified_vit_forward(self, x: torch.Tensor):
    """Replacement ``forward`` for a vision transformer, returning two streams.

    Meant to be bound to a ViT instance: it uses the ViT's ``conv1``,
    ``class_embedding``, ``positional_embedding``, ``patch_dropout``,
    ``ln_pre``, ``transformer``, ``ln_post`` and ``proj``. Unlike a stock
    forward it resamples the positional embedding when the token grid
    differs, and it expects ``self.transformer`` to return a
    ``(x_gem, x)`` pair.

    :param x: image batch.
    :return: ``[x_gem, x]``, i.e. GEM features and original features, after
        ``ln_post`` and (when set) the output projection.
    """
    x = self.conv1(x)  # shape = [*, width, grid, grid]
    grid_h, grid_w = x.shape[2:]
    x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
    x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]

    # class embeddings and positional embeddings
    x = torch.cat([_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x], dim=1)
    # shape = [*, grid ** 2 + 1, width]

    # positional_embedding has no batch axis (it is unsqueezed below), so
    # its token count is axis 0; axis 1 is the embedding width and must not
    # be compared with the number of tokens.
    if x.shape[1] != self.positional_embedding.shape[0]:
        pos_emb = resample_abs_pos_embed(self.positional_embedding.unsqueeze(0),
                                         new_size=[grid_h, grid_w],
                                         num_prefix_tokens=1,
                                         interpolation='bicubic',
                                         antialias=True)
    else:
        pos_emb = self.positional_embedding

    x = x + pos_emb.to(x.dtype)

    x = self.patch_dropout(x)
    x = self.ln_pre(x)

    x = x.permute(1, 0, 2)  # NLD -> LND
    x_gem, x = self.transformer(x)
    x = x.permute(1, 0, 2)  # LND -> NLD
    x_gem = x_gem.permute(1, 0, 2)  # LND -> NLD

    # Apply the final layer norm and projection to both streams
    x = self.ln_post(x)
    x_gem = self.ln_post(x_gem)
    if self.proj is not None:
        x = x @ self.proj
        x_gem = x_gem @ self.proj

    return [x_gem, x]
|
gem/gem_wrapper.py
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from einops import rearrange
|
6 |
+
from open_clip.transformer import VisionTransformer
|
7 |
+
|
8 |
+
from .gem_utils import SelfSelfAttention, GEMResidualBlock, modified_vit_forward
|
9 |
+
|
10 |
+
|
11 |
+
class GEMWrapper(nn.Module):
    """Wrap an OpenCLIP model so its vision tower also yields GEM features.

    On construction the attention of the last ``depth - 1`` visual residual
    blocks is replaced by ``SelfSelfAttention`` (re-using the pretrained
    weights), the blocks are wrapped in ``GEMResidualBlock`` and the ViT's
    ``forward`` is replaced by ``modified_vit_forward``.
    """

    def __init__(self, model, tokenizer, depth=7, ss_attn_iter=1, ss_attn_temp=None):
        """
        :param model: OpenCLIP model exposing ``visual`` and ``encode_text``.
        :param tokenizer: callable turning a list of strings into token tensors.
        :param depth: blocks ``-1 .. -(depth - 1)`` of the visual transformer are modified.
        :param ss_attn_iter: forwarded to ``SelfSelfAttention``.
        :param ss_attn_temp: forwarded to ``SelfSelfAttention``.
        """
        super(GEMWrapper, self).__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.depth = depth
        self.ss_attn_iter = ss_attn_iter
        self.ss_attn_temp = ss_attn_temp
        # NOTE(review): only the first patch dimension is read, i.e. square
        # patches are assumed.
        self.patch_size = self.model.visual.patch_size[0]
        self.apply_gem()

    def apply_gem(self):
        """Swap attention layers and the forward function of the visual encoder in place."""
        resblocks = self.model.visual.transformer.resblocks
        for i in range(1, self.depth):
            block = resblocks[-i]
            # Extract info from the original ViT
            num_heads = block.attn.num_heads
            dim = int(block.attn.head_dim * num_heads)
            qkv_bias = True
            # Init the self-self attention layer
            ss_attn = SelfSelfAttention(dim=dim, num_heads=num_heads, qkv_bias=qkv_bias,
                                        ss_attn_iter=self.ss_attn_iter, ss_attn_temp=self.ss_attn_temp)
            # Copy necessary weights
            ss_attn.qkv.weight.data = block.attn.in_proj_weight.clone()
            ss_attn.qkv.bias.data = block.attn.in_proj_bias.clone()
            ss_attn.proj.weight.data = block.attn.out_proj.weight.clone()
            ss_attn.proj.bias.data = block.attn.out_proj.bias.clone()
            # Swap the original Attention with our SelfSelfAttention
            block.attn = ss_attn
            # Wrap Residual block to handle SelfSelfAttention outputs
            resblocks[-i] = GEMResidualBlock(block)
        # Modify ViT's forward function
        self.model.visual.forward = modified_vit_forward.__get__(self.model.visual, VisionTransformer)
        return

    def encode_text(self, text: list):
        """Embed ``text`` with the prompt template ``a photo of a {cls}.``.

        :param text: list of class names.
        :return: L2-normalised embeddings with a leading axis of size 1: ``[1, len(text), dim]``.
        """
        prompts = [f'a photo of a {cls}.' for cls in text]
        # Take the device from the model parameters: ``visual.proj`` may be
        # None (modified_vit_forward checks for that), which would break a
        # ``visual.proj.device`` lookup.
        model_device = next(self.model.parameters()).device
        tokenized_prompts = self.tokenizer(prompts).to(model_device)
        text_embedding = self.model.encode_text(tokenized_prompts)
        text_embedding = F.normalize(text_embedding, dim=-1)
        return text_embedding.unsqueeze(0)

    def min_max(self, logits):
        """Rescale every ``[b, prompt]`` map of ``logits`` to the range [0, 1].

        :param logits: tensor ``[B, num_prompt, H, W]``.
        :return: tensor of the same shape; a constant map becomes all zeros
            instead of NaN.
        """
        B, num_prompt = logits.shape[:2]
        flat = logits.reshape(B, num_prompt, -1)
        logits_min = flat.min(dim=-1, keepdim=True)[0].unsqueeze(-1)
        logits_max = flat.max(dim=-1, keepdim=True)[0].unsqueeze(-1)
        # Guard the division: max == min would otherwise give 0 / 0.
        span = (logits_max - logits_min).clamp_min(torch.finfo(logits.dtype).eps)
        logits = (logits - logits_min) / span
        return logits

    def forward(self, image: torch.Tensor, text: list, normalize: bool = True, return_ori: bool =False):
        """
        :param image: torch.Tensor [1, 3, H, W]
        :param text: list[]
        :param normalize: bool - if True performs min-max normalization
        :param return_ori: bool - if True uses the features from the original visual encoder
        :return: torch.Tensor [1, num_prompt, H, W]
        """
        # Image
        H, W = image.shape[-2:]
        feat_gem, feat_ori = self.model.visual(image)
        image_feat = feat_ori if return_ori else feat_gem
        image_feat = F.normalize(image_feat, dim=-1)  # [1, N, dim]

        # Text
        text_embeddings = self.encode_text(text)  # [1, num_prompt, dim]

        # Image-Text matching (token 0 is the class token and is dropped)
        img_txt_matching = image_feat[:, 1:] @ text_embeddings.transpose(-1, -2)  # [1, N, num_prompt]
        img_txt_matching = rearrange(img_txt_matching, 'b (h w) c -> b c h w',
                                     h=H // self.patch_size, w=W // self.patch_size)  # [1, num_prompt, h, w]

        # Interpolate
        img_txt_matching = F.interpolate(img_txt_matching, size=(H, W), mode='bilinear')  # [1, num_prompt, H, W]

        # Heat Maps
        if normalize:
            img_txt_matching = self.min_max(img_txt_matching)
        return img_txt_matching

    def batched_forward(self, image: torch.Tensor, text: list, normalize: bool = True, return_ori: bool =False):
        """
        :param image: torch.Tensor [B, 3, H, W]
        :param text: list[list[]]
        :param normalize: bool - if True performs min-max normalization
        :param return_ori: bool - if True uses the features from the original visual encoder
        :return: list of B tensors, the i-th one of shape [len(text[i]), H, W]
        """
        L = len(text)
        cumm_idx = np.cumsum([len(t) for t in text]).tolist()
        B, _, H, W = image.shape
        assert B == L, f'Number of prompts L: {L} should be the same as number of images B: {B}.'

        # Image
        feat_gem, feat_ori = self.model.visual(image)
        image_feat = feat_ori if return_ori else feat_gem
        image_feat = F.normalize(image_feat, dim=-1)  # [B, N, dim]

        # Text: the prompts of all images are embedded in one flat batch
        flatten_text = [t for sub_text in text for t in sub_text]
        text_embeddings = self.encode_text(flatten_text)  # [1, total_prompts, dim]

        # Image-Text matching
        # NOTE(review): the factor 100 is not applied in forward(), so raw
        # (normalize=False) outputs of the two methods differ in scale.
        img_txt_matching = 100 * image_feat[:, 1:] @ text_embeddings.transpose(-1, -2)  # [B, N, total_prompts]
        img_txt_matching = rearrange(img_txt_matching, 'b (h w) c -> b c h w',
                                     h=H // self.patch_size, w=W // self.patch_size)  # [B, total_prompts, h, w]

        # Interpolate
        img_txt_matching = F.interpolate(img_txt_matching, size=(H, W), mode='bilinear')  # [B, total_prompts, H, W]

        # Heat Maps
        if normalize:
            img_txt_matching = self.min_max(img_txt_matching)  # [B, total_prompts, H, W]

        # unflatten: cut the prompt axis per image, then keep for image i
        # only the maps computed from its own prompts
        img_txt_matching = torch.tensor_split(img_txt_matching, cumm_idx[:-1], dim=1)
        img_txt_matching = [itm[i] for i, itm in enumerate(img_txt_matching)]
        return img_txt_matching
|
requirements.txt
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch>=1.9.0
|
2 |
+
torchvision
|
3 |
+
regex
|
4 |
+
ftfy
|
5 |
+
tqdm
|
6 |
+
huggingface_hub
|
7 |
+
sentencepiece
|
8 |
+
protobuf
|
9 |
+
timm
|
10 |
+
einops
|
11 |
+
open_clip_torch
|
12 |
+
opencv-python
|
13 |
+
matplotlib
|
14 |
+
numpy
|
15 |
+
requests
|
setup.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" Setup
|
2 |
+
Adapted from https://github.com/mlfoundations/open_clip
|
3 |
+
"""
|
4 |
+
from setuptools import setup, find_packages
|
5 |
+
from codecs import open
|
6 |
+
from os import path
|
7 |
+
|
8 |
+
here = path.abspath(path.dirname(__file__))

# Get the long description from the README file
with open(path.join(here, 'README.md'), encoding='utf-8') as f:
    long_description = f.read()


def _read_reqs(relpath):
    """Return the non-empty, non-comment lines of a requirements file.

    :param relpath: path of the file, relative to this setup script.
    """
    fullpath = path.join(path.dirname(__file__), relpath)
    with open(fullpath, encoding='utf-8') as f:
        return [s.strip() for s in f.readlines() if (s.strip() and not s.startswith("#"))]


REQUIREMENTS = _read_reqs("requirements.txt")

setup(
    name='gem_torch',
    version="1.0",
    description='GEM',
    long_description=long_description,
    long_description_content_type='text/markdown',
    url='https://github.com/WalBouss/GEM',
    author='Walid Bousselham, Felix Petersen, Vittorio Ferrari, Hilde Kuehne',
    author_email='',
    classifiers=[
        # How mature is this project? Common values are
        #   3 - Alpha
        #   4 - Beta
        #   5 - Production/Stable
        'Development Status :: 3 - Alpha',
        'Intended Audience :: Education',
        'Intended Audience :: Science/Research',
        # Matches the LICENSE file shipped with the project (MIT).
        'License :: OSI Approved :: MIT License',
        'Programming Language :: Python :: 3.7',
        'Programming Language :: Python :: 3.8',
        'Programming Language :: Python :: 3.9',
        'Programming Language :: Python :: 3.10',
        'Topic :: Scientific/Engineering',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'Topic :: Software Development',
        'Topic :: Software Development :: Libraries',
        'Topic :: Software Development :: Libraries :: Python Modules',
    ],

    # Note that this is a string of words separated by whitespace, not a list.
    keywords='CLIP pretrained',
    # ``gem`` is a package (gem/__init__.py) and is picked up by
    # find_packages(); it is not listed as a py_module because there is no
    # top-level gem.py.
    packages=find_packages(exclude=["assets*"]),
    include_package_data=True,
    install_requires=REQUIREMENTS,
    python_requires='>=3.7',
)
|
test_examples.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from PIL import Image
|
2 |
+
from gem import create_gem_model, get_gem_img_transform, visualize, available_models
|
3 |
+
import torch
|
4 |
+
import requests
|
5 |
+
|
6 |
+
|
7 |
+
def _load_image(url):
    """Download ``url`` and return it as an RGB PIL image.

    Fails fast on HTTP errors and timeouts, and converts to RGB as
    app.py's ``change_to_url`` does, so grayscale or RGBA downloads do not
    reach the transform with an unexpected channel count.
    """
    response = requests.get(url, stream=True, timeout=30)
    response.raise_for_status()
    return Image.open(response.raw).convert('RGB')


print(available_models())

device = "cuda:0" if torch.cuda.is_available() else "cpu"
model_name = 'ViT-B-16-quickgelu'
pretrained = 'metaclip_400m'
gem_model = create_gem_model(model_name=model_name, pretrained=pretrained, device=device)
gem_model.eval()

###########################
# Single Image
###########################

url = "http://images.cocodataset.org/val2017/000000039769.jpg"  # cat & remote control
text = ['remote control', 'cat']
# image_path = 'path/to/image'  # uncomment to use a local path

image_pil = _load_image(url)
# image_pil = Image.open(image_path)  # uncomment to use a local path

gem_img_transform = get_gem_img_transform()
image = gem_img_transform(image_pil).unsqueeze(0).to(device)

with torch.no_grad():
    logits = gem_model(image, text)
    visualize(image, text, logits)
    print(logits.shape)  # torch.Size([1, 2, 448, 448])
    # visualize(image_pil, text, logits)  # works with torch.Tensor and PIL.Image

###########################
# Batch of Images
###########################
urls = [
    "http://images.cocodataset.org/val2017/000000039769.jpg",
    "https://cdn.vietnambiz.vn/171464876016439296/2021/7/11/headshots16170695297430-1626006880779826347793.jpg",
    "https://preview.redd.it/do-you-think-joker-should-be-unpredictable-enough-to-put-up-v0-6a2ax4ngtlaa1.jpg?auto=webp&s=f8762e6a1b40642bcae5900bac184fc597131503",
]
texts = [
    ['remote control', 'cat'],
    ['elon musk', 'mark zuckerberg', 'jeff bezos', 'bill gates'],
    ['batman', 'joker', 'shoe', 'belt', 'purple suit'],
]  # note that the number of prompt per image can be different

# download images + convert to PIL.Image
images_pil = [_load_image(u) for u in urls]
images = torch.stack([gem_img_transform(img) for img in images_pil]).to(device)

with torch.no_grad():
    # returns one tensor per image, of size [num_prompt, H, W]
    logits_list = gem_model.batched_forward(images, texts)
    print(logits_list[0].shape)  # torch.Size([2, 448, 448])
    print(logits_list[1].shape)  # torch.Size([4, 448, 448])
    print(logits_list[2].shape)  # torch.Size([5, 448, 448])
    for i, _logits in enumerate(logits_list):
        visualize(images[i], texts[i], _logits)  # (optional visualization)
|