4rtemi5 committed
Commit 3fc62db
Parent: e45c79f

fix encoder loading

Files changed (2):
  1. localization.py +17 -38
  2. utils.py +8 -6
localization.py CHANGED
@@ -1,7 +1,7 @@
 import streamlit as st
 from text2image import get_model, get_tokenizer, get_image_transform
 from utils import text_encoder
-from transformers import AutoProcessor
+from torchvision import transforms
 from PIL import Image
 from jax import numpy as jnp
 import pandas as pd
@@ -13,7 +13,16 @@ import jax
 import gc
 
 
-preprocess = AutoProcessor.from_pretrained("clip-italian/clip-italian")
+preprocess = transforms.Compose(
+    [
+        transforms.ToTensor(),
+        transforms.Resize(224),
+        transforms.Normalize(
+            (0.48145466, 0.4578275, 0.40821073),
+            (0.26862954, 0.26130258, 0.27577711)
+        ),
+    ]
+)
 
 
 def resize_longer(image, longer_size=224):
@@ -89,18 +98,16 @@ def gen_image_batch(image_url, image_size=224, pixel_size=10):
 
 
 def get_heatmap(image_url, text, pixel_size=10, iterations=3):
-    # tokenizer = get_tokenizer()
+    tokenizer = get_tokenizer()
     model = get_model()
     image_size = model.config.vision_config.image_size
 
     images, masks, vertical, horizontal = gen_image_batch(image_url, pixel_size=pixel_size)
     input_image = images[0].copy()
 
-    inputs = preprocess(text=[text], images=images, return_tensors="np")
-
-    image_embeddings, embedding_norms = image_encoder(inputs['pixel_values'], model)
-    text_embedding = model.get_text_features(inputs["input_ids"], inputs["attention_mask"])[0]
-    text_embedding = text_embedding / jnp.linalg.norm(text_embedding, axis=-1, keepdims=True)
+    images = np.stack([preprocess(pad_to_square(image)) for image in images], axis=0)
+    image_embeddings, embedding_norms = image_encoder(images, model)
+    text_embeddings, _ = text_encoder(text, model, tokenizer)
 
     vertical_scores = jnp.zeros((masks[0].shape[1], 512))
     vertical_masks = jnp.zeros((masks[0].shape[1], 1))
@@ -131,39 +138,11 @@ def get_heatmap(image_url, text, pixel_size=10, iterations=3):
     embs_2 = jnp.expand_dims(jnp.abs(vertical_scores), axis=0) * jnp.expand_dims((horizontal_scores), axis=1)
     full_embs = jnp.minimum(embs_1, embs_2)
     mask_sum = jnp.expand_dims(vertical_masks, axis=0) * jnp.expand_dims(horizontal_masks, axis=1)
-
-    print(full_embs.shape)
-
-    #full_embs = full_embs / jnp.linalg.norm(full_embs, axis=-1, keepdims=True)
     full_embs = (full_embs / mask_sum)
 
     orig_shape = full_embs.shape
-    sims = jnp.matmul(jnp.reshape(full_embs, (-1, 512)), text_embedding.T)
-    sims = jnp.reshape(sims, (*orig_shape[:2], 1))
-    #sims = jax.nn.relu(sims)
-
-
-
-
-
-
-    # mean_vertical_scores = vertical_scores / vertical_masks
-    # mean_horizontal_scores = horizontal_scores / horizontal_masks
-
-    # print(mean_vertical_score)
-    # print(mean_horizontal_score)
-
-    # score = jnp.matmul(mean_vertical_scores, mean_horizontal_scores.T)
-
-    #mask = jnp.matmul(vertical_masks, horizontal_scores.T)
-    #score = score / mask
-
-    score = sims # jnp.expand_dims(score.T, axis=-1)
-    #score = jax.nn.relu(score) / jnp.max(jnp.abs(score))
-
-    #score = jax.nn.relu(score - sims[0])
-
-    # score = jnp.square(score)
+    sims = jnp.matmul(jnp.reshape(full_embs, (-1, 512)), text_embeddings.T)
+    score = jnp.reshape(sims, (*orig_shape[:2], 1))
 
     for i in range(iterations):
         score = jnp.clip(score - jnp.mean(score), 0, jnp.inf)
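For context on the localization.py change: the commit drops the AutoProcessor-based preprocessing and instead pads each PIL image to a square, runs it through the torchvision pipeline above (ToTensor, Resize(224), Normalize with the CLIP mean/std), and stacks the results into a NumPy batch for image_encoder. Below is a minimal sketch of that path; the pad_to_square helper is a hypothetical stand-in, since its actual implementation lives elsewhere in the repo and may differ.

# Sketch of the new preprocessing path, not the repo's exact code.
# pad_to_square is a stand-in for the helper referenced in localization.py.
import numpy as np
from PIL import Image
from torchvision import transforms

preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(224),
    transforms.Normalize(
        (0.48145466, 0.4578275, 0.40821073),   # CLIP image mean
        (0.26862954, 0.26130258, 0.27577711),  # CLIP image std
    ),
])

def pad_to_square(image, fill=(0, 0, 0)):
    # Hypothetical helper: pad the shorter side so the image becomes square
    # before the Resize(224) step.
    w, h = image.size
    side = max(w, h)
    canvas = Image.new("RGB", (side, side), fill)
    canvas.paste(image, ((side - w) // 2, (side - h) // 2))
    return canvas

images = [Image.new("RGB", (320, 180)) for _ in range(4)]
batch = np.stack([preprocess(pad_to_square(im)) for im in images], axis=0)
print(batch.shape)  # (4, 3, 224, 224): channels-first batch handed to image_encoder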
utils.py CHANGED
@@ -34,18 +34,20 @@ def text_encoder(text, model, tokenizer):
         padding="max_length",
         return_tensors="np",
     )
-    embedding = model.get_text_features(inputs["input_ids"], inputs["attention_mask"])[
-        0
-    ]
-    embedding /= jnp.linalg.norm(embedding)
-    return jnp.expand_dims(embedding, axis=0)
+    embedding = model.get_text_features(
+        inputs["input_ids"],
+        inputs["attention_mask"])[0]
+    norms = jnp.linalg.norm(embedding, axis=-1, keepdims=True)
+    embedding = embedding / norms
+    return jnp.expand_dims(embedding, axis=0), norms
 
 
 def image_encoder(image, model):
     image = image.permute(1, 2, 0).numpy()
     image = jnp.expand_dims(image, axis=0)  # add batch size
     features = model.get_image_features(image,)
-    features /= jnp.linalg.norm(features, axis=-1, keepdims=True)
+    norms = jnp.linalg.norm(features, axis=-1, keepdims=True)
+    features = features / norms
     return features
 
 
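As a reading aid for the utils.py change: both encoders now divide their output by its L2 norm (and text_encoder additionally returns the norms), so cosine similarity between image and text embeddings reduces to a plain matrix product, which is what the sims = jnp.matmul(...) line in localization.py relies on. A small self-contained sketch with dummy 512-dimensional features standing in for the model outputs:

# Sketch of the normalization convention used by the updated encoders.
# The feature arrays are dummies standing in for model.get_image_features /
# model.get_text_features outputs; only the normalization mirrors the commit.
from jax import numpy as jnp

image_features = jnp.ones((4, 512))        # stand-in for get_image_features output
text_features = jnp.full((1, 512), 0.5)    # stand-in for get_text_features output

# Same normalization as image_encoder / text_encoder after this commit:
image_norms = jnp.linalg.norm(image_features, axis=-1, keepdims=True)
text_norms = jnp.linalg.norm(text_features, axis=-1, keepdims=True)
image_embeddings = image_features / image_norms
text_embeddings = text_features / text_norms

# With unit-length embeddings, the matmul in get_heatmap is a cosine similarity.
sims = jnp.matmul(image_embeddings, text_embeddings.T)
print(sims.shape)  # (4, 1), values in [-1, 1]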