fix streamlit issues and update localization
Files changed:
- localization.py +121 -51
- modeling_hybrid_clip.py +3 -1
- requirements.txt +1 -1
localization.py
CHANGED
@@ -1,7 +1,7 @@
 import streamlit as st
 from text2image import get_model, get_tokenizer, get_image_transform
 from utils import text_encoder
-from torchvision import transforms
+from transformers import AutoProcessor
 from PIL import Image
 from jax import numpy as jnp
 import pandas as pd
@@ -13,30 +13,34 @@ import jax
 import gc
 
 
-preprocess = transforms.Compose(
-    [
-        transforms.ToTensor(),
-        transforms.Normalize(
-            (0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)
-        ),
-    ]
-)
+preprocess = AutoProcessor.from_pretrained("clip-italian/clip-italian")
 
 
-def …
-    …
-    …
+def resize_longer(image, longer_size=224):
+    old_size = image.size
+    ratio = float(longer_size) / max(old_size)
+    new_size = tuple([int(x * ratio) for x in old_size])
     image = image.resize(new_size, Image.ANTIALIAS)
-    …
-    …
-    …
+    return image
+
+
+def pad_to_square(image):
+    (a, b) = image.shape[:2]
+    if a < b:
+        ah = (b - a) // 2
+        padding = ((ah, b - a - ah), (0, 0), (0, 0))
+    else:
+        bh = (a - b) // 2
+        padding = ((0, 0), (bh, a - b - bh), (0, 0))
+    return np.pad(image, padding, mode='constant', constant_values=127)
 
 
 def image_encoder(image, model):
     image = np.transpose(image, (0, 2, 3, 1))
     features = model.get_image_features(image)
-    …
-    …
+    feature_norms = jnp.linalg.norm(features, axis=-1, keepdims=True)
+    features = features / feature_norms
+    return features, feature_norms
 
 
 def gen_image_batch(image_url, image_size=224, pixel_size=10):
@@ -44,64 +48,130 @@ def gen_image_batch(image_url, image_size=224, pixel_size=10):
 
     image_batch = []
     masks = []
+    is_vertical = []
+    is_horizontal = []
+
     image_raw = requests.get(image_url, stream=True).raw
     image = Image.open(image_raw).convert("RGB")
-    image = …
-    gray = np.ones_like(image) * …
-    mask = np.ones_like(image)
+    image = np.array(resize_longer(image, longer_size=image_size))
+    gray = np.ones_like(image) * 127
+    mask = np.ones_like(image[:, :, :1])
 
     image_batch.append(image)
     masks.append(mask)
+    is_vertical.append(True)
+    is_horizontal.append(True)
+
 
-    for i in range(0, …
-        for j in range(i …
+    for i in range(0, image.shape[0] // pixel_size + 1):
+        for j in range(i + 1, image.shape[0] // pixel_size + 2):
             m = mask.copy()
-            m[: …
-            m[min(j …
+            m[:min(i * pixel_size, image_size), :] = 0
+            m[min(j * pixel_size, image_size):, :] = 0
             neg_m = 1 - m
-            image_batch.append(image * m + gray * neg_m)
+            image_batch.append(image.copy() * m + gray * neg_m)
             masks.append(m)
+            is_vertical.append(False)
+            is_horizontal.append(True)
 
-    for i in range(0, …
-        for j in range(i …
+    for i in range(0, image.shape[1] // pixel_size + 1):
+        for j in range(i + 1, image.shape[1] // pixel_size + 2):
             m = mask.copy()
-            m[:, : …
-            m[:, min(j …
+            m[:, :min(i * pixel_size, image_size)] = 0
+            m[:, min(j * pixel_size, image_size):] = 0
             neg_m = 1 - m
-            image_batch.append(image * m + gray * neg_m)
+            image_batch.append(image.copy() * m + gray * neg_m)
             masks.append(m)
+            is_vertical.append(True)
+            is_horizontal.append(False)
 
-    return image_batch, masks
+    return image_batch, masks, is_vertical, is_horizontal
 
 
 def get_heatmap(image_url, text, pixel_size=10, iterations=3):
-    tokenizer = get_tokenizer()
+    # tokenizer = get_tokenizer()
     model = get_model()
    image_size = model.config.vision_config.image_size
-    text_embedding = text_encoder(text, model, tokenizer)
-    images, masks = gen_image_batch(
-        image_url, image_size=image_size, pixel_size=pixel_size
-    )
 
+    images, masks, vertical, horizontal = gen_image_batch(image_url, pixel_size=pixel_size)
     input_image = images[0].copy()
-    images = np.stack([preprocess(image) for image in images], axis=0)
-    image_embeddings = jnp.asarray(image_encoder(images, model))
-
-    sims = []
-    scores = []
-    mask_val = jnp.zeros_like(masks[0])
-
-    for e, m in zip(image_embeddings, masks):
-        sim = jnp.matmul(e, text_embedding.T)
-        sims.append(sim)
-        if len(sims) > 1:
-            scores.append(sim * m)
-            mask_val += 1 - m
 
-    …
+    inputs = preprocess(text=[text], images=images, return_tensors="np")
+
+    image_embeddings, embedding_norms = image_encoder(inputs['pixel_values'], model)
+    text_embedding = model.get_text_features(inputs["input_ids"], inputs["attention_mask"])[0]
+    text_embedding = text_embedding / jnp.linalg.norm(text_embedding, axis=-1, keepdims=True)
+
+    vertical_scores = jnp.zeros((masks[0].shape[1], 512))
+    vertical_masks = jnp.zeros((masks[0].shape[1], 1))
+    horizontal_scores = jnp.zeros((masks[0].shape[0], 512))
+    horizontal_masks = jnp.zeros((masks[0].shape[0], 1))
+
+    for e, n, m, v, h in zip(image_embeddings, embedding_norms, masks, vertical, horizontal):
+        # sim = (jnp.matmul(e, text_embedding.T)) # + 1) / 2
+        # sim = jax.nn.relu(sim)
+        # if full_sim is None:
+        #     full_sim = sim
+        # sim = jax.nn.relu(sim - full_sim)
+        emb = jnp.expand_dims(e, axis=0) * n
+
+        if v:
+            vm = jnp.any(m, axis=0)
+            vertical_scores = vertical_scores + (emb * vm)  # / jnp.mean(vm)
+            vertical_masks = vertical_masks + vm  # / jnp.mean(vm)
+        if h:
+            hm = jnp.any(m, axis=1)
+            horizontal_scores = horizontal_scores + (emb * hm)  # / jnp.mean(hm)
+            horizontal_masks = horizontal_masks + hm  # / jnp.mean(hm)
+
+    embs_1 = jnp.expand_dims(vertical_scores, axis=0) * jnp.expand_dims(jnp.abs(horizontal_scores), axis=1)
+    embs_2 = jnp.expand_dims(jnp.abs(vertical_scores), axis=0) * jnp.expand_dims(horizontal_scores, axis=1)
+    full_embs = jnp.minimum(embs_1, embs_2)
+    mask_sum = jnp.expand_dims(vertical_masks, axis=0) * jnp.expand_dims(horizontal_masks, axis=1)
+
+    print(full_embs.shape)
+
+    # full_embs = full_embs / jnp.linalg.norm(full_embs, axis=-1, keepdims=True)
+    full_embs = full_embs / mask_sum
+
+    orig_shape = full_embs.shape
+    sims = jnp.matmul(jnp.reshape(full_embs, (-1, 512)), text_embedding.T)
+    sims = jnp.reshape(sims, (*orig_shape[:2], 1))
+    # sims = jax.nn.relu(sims)
+
+    # mean_vertical_scores = vertical_scores / vertical_masks
+    # mean_horizontal_scores = horizontal_scores / horizontal_masks
+    # score = jnp.matmul(mean_vertical_scores, mean_horizontal_scores.T)
+    # mask = jnp.matmul(vertical_masks, horizontal_scores.T)
+    # score = score / mask
+
+    score = sims  # jnp.expand_dims(score.T, axis=-1)
+    # score = jax.nn.relu(score) / jnp.max(jnp.abs(score))
+    # score = jax.nn.relu(score - sims[0])
+    # score = jnp.square(score)
+
     for i in range(iterations):
         score = jnp.clip(score - jnp.mean(score), 0, jnp.inf)
+
     score = (score - jnp.min(score)) / (jnp.max(score) - jnp.min(score))
+
+    print(jnp.min(score), jnp.max(score))
+
     return np.asarray(score), input_image
 
 
@@ -144,7 +214,7 @@ def app():
     with col2:
         pixel_size = st.selectbox("Pixel Size", options=range(10, 26, 5), index=2)
 
-    iterations = st.selectbox("Refinement Steps", options=range(…
+    iterations = st.selectbox("Refinement Steps", options=range(1, 6, 1), index=0)
 
     compute = st.button("LOCATE")
 
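What changed, in short: the old get_heatmap embedded every masked image and weighted the 2D masks by their similarity scores directly, while the new version grays out contiguous horizontal or vertical bands, accumulates each band's un-normalized embedding (`e * n`) per visible row or column, recombines the two axes as an outer product, and takes a dot product with the normalized text embedding. Below is a minimal sketch of that scoring scheme, runnable without the model: `dummy_encoder` is a random stand-in for the CLIP image tower, and the toy sizes replace the app's 224x224 input and 512-d embeddings. All names here are illustrative, not the app's API.

```python
import numpy as np

def band_masks(size, pixel_size):
    # every contiguous band [i*pixel_size, j*pixel_size) along one axis,
    # mirroring the i/j double loop in gen_image_batch (edge bands may be
    # empty, matching the diff's loop bounds; they contribute nothing)
    masks = []
    steps = size // pixel_size + 1
    for i in range(steps):
        for j in range(i + 1, steps + 1):
            m = np.zeros(size)
            m[min(i * pixel_size, size):min(j * pixel_size, size)] = 1.0
            masks.append(m)
    return masks

rng = np.random.default_rng(0)
H = W = 32   # toy image side; the app uses the model's image_size
D = 8        # toy embedding width; the app uses 512

def dummy_encoder(band):
    # ignores the band and returns a random vector; in the app this is the
    # CLIP embedding of the band-masked image
    return rng.standard_normal(D)

# accumulate embedding evidence per row and per column, as in get_heatmap
h_scores, h_hits = np.zeros((H, D)), np.zeros((H, 1))
v_scores, v_hits = np.zeros((W, D)), np.zeros((W, 1))
for m in band_masks(H, pixel_size=8):   # horizontal bands -> row evidence
    h_scores += np.outer(m, dummy_encoder(m))
    h_hits += m[:, None]
for m in band_masks(W, pixel_size=8):   # vertical bands -> column evidence
    v_scores += np.outer(m, dummy_encoder(m))
    v_hits += m[:, None]

# outer-product recombination into an (H, W, D) embedding map, then a dot
# product with a (normalized) text embedding yields the heatmap
full = np.minimum(v_scores[None, :, :] * np.abs(h_scores[:, None, :]),
                  np.abs(v_scores[None, :, :]) * h_scores[:, None, :])
full = full / (v_hits[None, :, :] * h_hits[:, None, :])
text_embedding = rng.standard_normal(D)
score = (full.reshape(-1, D) @ text_embedding).reshape(H, W)
print(score.shape)  # (32, 32) relevance map
```

Dividing by the hit counts (mask_sum in the diff) keeps often-covered pixels from dominating simply because more bands include them, and the loop that clips below-mean scores and renormalizes is what the new "Refinement Steps" selector controls.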
modeling_hybrid_clip.py
CHANGED
@@ -136,8 +136,10 @@ class FlaxHybridCLIP(FlaxPreTrainedModel):
     ):
         if input_shape is None:
             input_shape = ((1, 1), (1, config.vision_config.image_size, config.vision_config.image_size, 3))
+
+        print(kwargs)
 
-        module = self.module_class(config=config, dtype=dtype, **kwargs)
+        module = self.module_class(config=config, dtype=dtype)  # , **kwargs)
         super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
 
     def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
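The only functional change here is that **kwargs is no longer forwarded into the Flax module constructor, presumably because extra keyword arguments reaching the module dataclass made instantiation fail (the added print(kwargs) reads as a temporary aid for seeing which ones arrive). A more surgical variant, sketched here as a hypothetical alternative rather than anything in this repo, would forward only the keywords the module class actually declares:

```python
# Hypothetical refinement (not part of this commit): forward only keyword
# arguments the module class declares, instead of dropping them all.
import inspect

def declared_kwargs(cls, kwargs):
    # keep only kwargs matching parameters of cls's constructor
    params = inspect.signature(cls).parameters
    return {k: v for k, v in kwargs.items() if k in params}

# usage inside __init__ (illustrative):
# module = self.module_class(config=config, dtype=dtype,
#                            **declared_kwargs(self.module_class, kwargs))
```

This keeps legitimate module options working while still shielding the constructor from unexpected keywords.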
requirements.txt
CHANGED
@@ -8,4 +8,4 @@ stqdm
 pandas
 requests
 psutil
-streamlit
+streamlit