4rtemi5 commited on
Commit
e45c79f
·
1 Parent(s): 76826f4

fix streamlit issues and update localization

Browse files
Files changed (3) hide show
  1. localization.py +121 -51
  2. modeling_hybrid_clip.py +3 -1
  3. requirements.txt +1 -1
localization.py CHANGED
@@ -1,7 +1,7 @@
1
  import streamlit as st
2
  from text2image import get_model, get_tokenizer, get_image_transform
3
  from utils import text_encoder
4
- from torchvision import transforms
5
  from PIL import Image
6
  from jax import numpy as jnp
7
  import pandas as pd
@@ -13,30 +13,34 @@ import jax
13
  import gc
14
 
15
 
16
- preprocess = transforms.Compose(
17
- [
18
- transforms.ToTensor(),
19
- transforms.Normalize(
20
- (0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)
21
- ),
22
- ]
23
- )
24
 
25
 
26
- def pad_to_square(image, size=224):
27
- ratio = float(size) / max(image.size)
28
- new_size = tuple([int(x * ratio) for x in image.size])
 
29
  image = image.resize(new_size, Image.ANTIALIAS)
30
- new_image = Image.new("RGB", size=(size, size), color=(128, 128, 128))
31
- new_image.paste(image, ((size - new_size[0]) // 2, (size - new_size[1]) // 2))
32
- return new_image
 
 
 
 
 
 
 
 
 
33
 
34
 
35
  def image_encoder(image, model):
36
  image = np.transpose(image, (0, 2, 3, 1))
37
  features = model.get_image_features(image)
38
- features /= jnp.linalg.norm(features, keepdims=True)
39
- return features
 
40
 
41
 
42
  def gen_image_batch(image_url, image_size=224, pixel_size=10):
@@ -44,64 +48,130 @@ def gen_image_batch(image_url, image_size=224, pixel_size=10):
44
 
45
  image_batch = []
46
  masks = []
 
 
 
47
  image_raw = requests.get(image_url, stream=True).raw
48
  image = Image.open(image_raw).convert("RGB")
49
- image = pad_to_square(image, size=image_size)
50
- gray = np.ones_like(image) * 128
51
- mask = np.ones_like(image)
52
 
53
  image_batch.append(image)
54
  masks.append(mask)
 
 
 
55
 
56
- for i in range(0, n_pixels):
57
- for j in range(i + 1, n_pixels):
58
  m = mask.copy()
59
- m[: min(i * pixel_size, image_size) + 1, :] = 0
60
- m[min(j * pixel_size, image_size) + 1 :, :] = 0
61
  neg_m = 1 - m
62
- image_batch.append(image * m + gray * neg_m)
63
  masks.append(m)
 
 
64
 
65
- for i in range(0, n_pixels + 1):
66
- for j in range(i + 1, n_pixels + 1):
67
  m = mask.copy()
68
- m[:, : min(i * pixel_size + 1, image_size)] = 0
69
- m[:, min(j * pixel_size + 1, image_size) :] = 0
70
  neg_m = 1 - m
71
- image_batch.append(image * m + gray * neg_m)
72
  masks.append(m)
 
 
73
 
74
- return image_batch, masks
75
 
76
 
77
  def get_heatmap(image_url, text, pixel_size=10, iterations=3):
78
- tokenizer = get_tokenizer()
79
  model = get_model()
80
  image_size = model.config.vision_config.image_size
81
- text_embedding = text_encoder(text, model, tokenizer)
82
- images, masks = gen_image_batch(
83
- image_url, image_size=image_size, pixel_size=pixel_size
84
- )
85
 
 
86
  input_image = images[0].copy()
87
- images = np.stack([preprocess(image) for image in images], axis=0)
88
- image_embeddings = jnp.asarray(image_encoder(images, model))
89
-
90
- sims = []
91
- scores = []
92
- mask_val = jnp.zeros_like(masks[0])
93
-
94
- for e, m in zip(image_embeddings, masks):
95
- sim = jnp.matmul(e, text_embedding.T)
96
- sims.append(sim)
97
- if len(sims) > 1:
98
- scores.append(sim * m)
99
- mask_val += 1 - m
100
 
101
- score = jnp.mean(jnp.clip(jnp.array(scores) - sims[0], 0, jnp.inf), axis=0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
  for i in range(iterations):
103
  score = jnp.clip(score - jnp.mean(score), 0, jnp.inf)
 
104
  score = (score - jnp.min(score)) / (jnp.max(score) - jnp.min(score))
 
 
 
105
  return np.asarray(score), input_image
106
 
107
 
@@ -144,7 +214,7 @@ def app():
144
  with col2:
145
  pixel_size = st.selectbox("Pixel Size", options=range(10, 26, 5), index=2)
146
 
147
- iterations = st.selectbox("Refinement Steps", options=range(3, 30, 3), index=0)
148
 
149
  compute = st.button("LOCATE")
150
 
 
1
  import streamlit as st
2
  from text2image import get_model, get_tokenizer, get_image_transform
3
  from utils import text_encoder
4
+ from transformers import AutoProcessor
5
  from PIL import Image
6
  from jax import numpy as jnp
7
  import pandas as pd
 
13
  import gc
14
 
15
 
16
+ preprocess = AutoProcessor.from_pretrained("clip-italian/clip-italian")
 
 
 
 
 
 
 
17
 
18
 
19
+ def resize_longer(image, longer_size=224):
20
+ old_size = image.size
21
+ ratio = float(longer_size) / max(old_size)
22
+ new_size = tuple([int(x * ratio) for x in old_size])
23
  image = image.resize(new_size, Image.ANTIALIAS)
24
+ return image
25
+
26
+
27
+ def pad_to_square(image):
28
+ (a,b)=image.shape[:2]
29
+ if a<b:
30
+ ah = (b - a) // 2
31
+ padding=((ah,b - a -ah), (0,0), (0,0))
32
+ else:
33
+ bh = (a - b) // 2
34
+ padding=((0,0), (bh,a-b-bh), (0,0))
35
+ return np.pad(image, padding,mode='constant',constant_values=127)
36
 
37
 
38
  def image_encoder(image, model):
39
  image = np.transpose(image, (0, 2, 3, 1))
40
  features = model.get_image_features(image)
41
+ feature_norms = jnp.linalg.norm(features, axis=-1, keepdims=True)
42
+ features = features / feature_norms
43
+ return features, feature_norms
44
 
45
 
46
  def gen_image_batch(image_url, image_size=224, pixel_size=10):
 
48
 
49
  image_batch = []
50
  masks = []
51
+ is_vertical = []
52
+ is_horizontal = []
53
+
54
  image_raw = requests.get(image_url, stream=True).raw
55
  image = Image.open(image_raw).convert("RGB")
56
+ image = np.array(resize_longer(image, longer_size=image_size))
57
+ gray = np.ones_like(image) * 127
58
+ mask = np.ones_like(image[:,:,:1])
59
 
60
  image_batch.append(image)
61
  masks.append(mask)
62
+ is_vertical.append(True)
63
+ is_horizontal.append(True)
64
+
65
 
66
+ for i in range(0, image.shape[0] // pixel_size + 1):
67
+ for j in range(i+1, image.shape[0] // pixel_size + 2):
68
  m = mask.copy()
69
+ m[:min(i*pixel_size, image_size), :] = 0
70
+ m[min(j*pixel_size, image_size):, :] = 0
71
  neg_m = 1 - m
72
+ image_batch.append(image.copy() * m + gray * neg_m)
73
  masks.append(m)
74
+ is_vertical.append(False)
75
+ is_horizontal.append(True)
76
 
77
+ for i in range(0, image.shape[1] // pixel_size + 1):
78
+ for j in range(i+1, image.shape[1] // pixel_size + 2):
79
  m = mask.copy()
80
+ m[:, :min(i*pixel_size, image_size)] = 0
81
+ m[:, min(j*pixel_size, image_size):] = 0
82
  neg_m = 1 - m
83
+ image_batch.append(image.copy() * m + gray * neg_m)
84
  masks.append(m)
85
+ is_vertical.append(True)
86
+ is_horizontal.append(False)
87
 
88
+ return image_batch, masks, is_vertical, is_horizontal
89
 
90
 
91
  def get_heatmap(image_url, text, pixel_size=10, iterations=3):
92
+ # tokenizer = get_tokenizer()
93
  model = get_model()
94
  image_size = model.config.vision_config.image_size
 
 
 
 
95
 
96
+ images, masks, vertical, horizontal = gen_image_batch(image_url, pixel_size=pixel_size)
97
  input_image = images[0].copy()
 
 
 
 
 
 
 
 
 
 
 
 
 
98
 
99
+ inputs = preprocess(text=[text], images=images, return_tensors="np")
100
+
101
+ image_embeddings, embedding_norms = image_encoder(inputs['pixel_values'], model)
102
+ text_embedding = model.get_text_features(inputs["input_ids"], inputs["attention_mask"])[0]
103
+ text_embedding = text_embedding / jnp.linalg.norm(text_embedding, axis=-1, keepdims=True)
104
+
105
+ vertical_scores = jnp.zeros((masks[0].shape[1], 512))
106
+ vertical_masks = jnp.zeros((masks[0].shape[1], 1))
107
+ horizontal_scores = jnp.zeros((masks[0].shape[0], 512))
108
+ horizontal_masks = jnp.zeros((masks[0].shape[0], 1))
109
+
110
+ for e, n, m, v, h in zip(image_embeddings, embedding_norms, masks, vertical, horizontal):
111
+ # sim = (jnp.matmul(e, text_embedding.T)) # + 1) / 2
112
+
113
+ # sim = jax.nn.relu(sim)
114
+
115
+ # if full_sim is None:
116
+ # full_sim = sim
117
+ # sim = jax.nn.relu(sim - full_sim)
118
+ emb = jnp.expand_dims(e, axis=0) * n
119
+
120
+ if v:
121
+ vm = jnp.any(m, axis=0)
122
+ vertical_scores = vertical_scores + (emb * vm) #/ jnp.mean(vm)
123
+ vertical_masks = vertical_masks + vm #/ jnp.mean(vm)
124
+ if h:
125
+ hm = jnp.any(m, axis=1)
126
+ horizontal_scores = horizontal_scores + (emb * hm) #/ jnp.mean(hm)
127
+ horizontal_masks = horizontal_masks + hm #/ jnp.mean(hm)
128
+
129
+
130
+ embs_1 = jnp.expand_dims((vertical_scores), axis=0) * jnp.expand_dims(jnp.abs(horizontal_scores), axis=1)
131
+ embs_2 = jnp.expand_dims(jnp.abs(vertical_scores), axis=0) * jnp.expand_dims((horizontal_scores), axis=1)
132
+ full_embs = jnp.minimum(embs_1, embs_2)
133
+ mask_sum = jnp.expand_dims(vertical_masks, axis=0) * jnp.expand_dims(horizontal_masks, axis=1)
134
+
135
+ print(full_embs.shape)
136
+
137
+ #full_embs = full_embs / jnp.linalg.norm(full_embs, axis=-1, keepdims=True)
138
+ full_embs = (full_embs / mask_sum)
139
+
140
+ orig_shape = full_embs.shape
141
+ sims = jnp.matmul(jnp.reshape(full_embs, (-1, 512)), text_embedding.T)
142
+ sims = jnp.reshape(sims, (*orig_shape[:2], 1))
143
+ #sims = jax.nn.relu(sims)
144
+
145
+
146
+
147
+
148
+
149
+
150
+ # mean_vertical_scores = vertical_scores / vertical_masks
151
+ # mean_horizontal_scores = horizontal_scores / horizontal_masks
152
+
153
+ # print(mean_vertical_score)
154
+ # print(mean_horizontal_score)
155
+
156
+ # score = jnp.matmul(mean_vertical_scores, mean_horizontal_scores.T)
157
+
158
+ #mask = jnp.matmul(vertical_masks, horizontal_scores.T)
159
+ #score = score / mask
160
+
161
+ score = sims # jnp.expand_dims(score.T, axis=-1)
162
+ #score = jax.nn.relu(score) / jnp.max(jnp.abs(score))
163
+
164
+ #score = jax.nn.relu(score - sims[0])
165
+
166
+ # score = jnp.square(score)
167
+
168
  for i in range(iterations):
169
  score = jnp.clip(score - jnp.mean(score), 0, jnp.inf)
170
+
171
  score = (score - jnp.min(score)) / (jnp.max(score) - jnp.min(score))
172
+
173
+ print(jnp.min(score), jnp.max(score))
174
+
175
  return np.asarray(score), input_image
176
 
177
 
 
214
  with col2:
215
  pixel_size = st.selectbox("Pixel Size", options=range(10, 26, 5), index=2)
216
 
217
+ iterations = st.selectbox("Refinement Steps", options=range(1, 6, 1), index=0)
218
 
219
  compute = st.button("LOCATE")
220
 
modeling_hybrid_clip.py CHANGED
@@ -136,8 +136,10 @@ class FlaxHybridCLIP(FlaxPreTrainedModel):
136
  ):
137
  if input_shape is None:
138
  input_shape = ((1, 1), (1, config.vision_config.image_size, config.vision_config.image_size, 3))
 
 
139
 
140
- module = self.module_class(config=config, dtype=dtype, **kwargs)
141
  super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
142
 
143
  def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
 
136
  ):
137
  if input_shape is None:
138
  input_shape = ((1, 1), (1, config.vision_config.image_size, config.vision_config.image_size, 3))
139
+
140
+ print(kwargs)
141
 
142
+ module = self.module_class(config=config, dtype=dtype) # , **kwargs)
143
  super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
144
 
145
  def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
requirements.txt CHANGED
@@ -8,4 +8,4 @@ stqdm
8
  pandas
9
  requests
10
  psutil
11
- streamlit==1.2.0
 
8
  pandas
9
  requests
10
  psutil
11
+ streamlit