4rtemi5 committed
Commit ea3b7ec
Parent: 0264d55

add localization and examples

app.py CHANGED
@@ -1,6 +1,7 @@
  import streamlit as st
  import image2text
  import text2image
+ import localization
  import home
  import examples
  from PIL import Image
@@ -9,6 +10,7 @@ PAGES = {
      "Introduction": home,
      "Text to Image": text2image,
      "Image to Text": image2text,
+     "Localization": localization,
      "Examples & Applications": examples,
  }
 
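The rest of app.py is not shown in this diff. For context, a minimal sketch of how a PAGES dict like this is typically dispatched in a multipage Streamlit app; the sidebar widget and its label here are assumptions, not the app's actual code:

    # Hypothetical dispatcher -- the unchanged part of app.py is not in this diff.
    # Every PAGES value is a module exposing an app() entry point, which is
    # exactly the interface the new localization.py provides.
    selection = st.sidebar.radio("Menu", list(PAGES.keys()))
    PAGES[selection].app()  # e.g. renders localization.app() when selected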
examples.py CHANGED
@@ -81,6 +81,20 @@ def app():
     col2.markdown("*A rustic chair*")
     col2.image("static/img/examples/sedia_rustica.jpeg", use_column_width=True)

+    st.markdown("## Localization")
+
+    st.subheader("Un gatto")
+    st.markdown("*A cat*")
+    st.image("static/img/examples/un_gatto.png", use_column_width=True)
+
+    st.subheader("Due gatti")
+    st.markdown("*Two cats*")
+    st.image("static/img/examples/due_gatti.png", use_column_width=True)
+
+    st.subheader("Un bambino")
+    st.markdown("*A child*")
+    st.image("static/img/examples/child_on_slide.png", use_column_width=True)
+
     st.markdown("## Image Classification")
     st.markdown(
         "We report this cool example provided by the "
localization.py ADDED
@@ -0,0 +1,157 @@
+ import streamlit as st
+ from text2image import get_model, get_tokenizer, get_image_transform
+ from utils import text_encoder
+ from torchvision import transforms
+ from PIL import Image
+ from jax import numpy as jnp
+ import pandas as pd
+ import numpy as np
+ import requests
+ import jax
+ import gc
+
+
+ preprocess = transforms.Compose([  # CLIP image preprocessing
+     transforms.ToTensor(),
+     transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
+ ])
+
+
+ def pad_to_square(image, size=224):  # letterbox the image onto a gray square canvas
+     ratio = float(size) / max(image.size)
+     new_size = tuple(int(x * ratio) for x in image.size)
+     image = image.resize(new_size, Image.ANTIALIAS)
+     new_image = Image.new("RGB", size=(size, size), color=(128, 128, 128))
+     new_image.paste(image, ((size - new_size[0]) // 2, (size - new_size[1]) // 2))
+     return new_image
+
+
+ def image_encoder(image, model):
+     image = np.transpose(image, (0, 2, 3, 1))  # NCHW -> NHWC for the JAX model
+     features = model.get_image_features(image)
+     features /= jnp.linalg.norm(features, axis=-1, keepdims=True)  # L2-normalize per image
+     return features
+
+
+ def gen_image_batch(image_url, image_size=224, pixel_size=10):  # occluded copies plus their masks
+     n_pixels = image_size // pixel_size + 1
+
+     image_batch = []
+     masks = []
+     image_raw = requests.get(image_url, stream=True).raw
+     image = Image.open(image_raw).convert("RGB")
+     image = pad_to_square(image, size=image_size)
+     gray = np.ones_like(image) * 128
+     mask = np.ones_like(image)
+
+     image_batch.append(image)  # the unmasked image comes first: its score is the baseline
+     masks.append(mask)
+
+     for i in range(0, n_pixels):  # horizontal strips; everything else is grayed out
+         for j in range(i + 1, n_pixels):
+             m = mask.copy()
+             m[:min(i * pixel_size, image_size) + 1, :] = 0
+             m[min(j * pixel_size, image_size) + 1:, :] = 0
+             neg_m = 1 - m
+             image_batch.append(image * m + gray * neg_m)
+             masks.append(m)
+
+     for i in range(0, n_pixels + 1):  # vertical strips
+         for j in range(i + 1, n_pixels + 1):
+             m = mask.copy()
+             m[:, :min(i * pixel_size + 1, image_size)] = 0
+             m[:, min(j * pixel_size + 1, image_size):] = 0
+             neg_m = 1 - m
+             image_batch.append(image * m + gray * neg_m)
+             masks.append(m)
+
+     return image_batch, masks
+
+
+ def get_heatmap(image_url, text, pixel_size=10, iterations=3):
+     tokenizer = get_tokenizer()
+     model = get_model()
+     image_size = model.config.vision_config.image_size
+     text_embedding = text_encoder(text, model, tokenizer)
+     images, masks = gen_image_batch(image_url, image_size=image_size, pixel_size=pixel_size)
+
+     input_image = images[0].copy()
+     images = np.stack([preprocess(image) for image in images], axis=0)
+     image_embeddings = jnp.asarray(image_encoder(images, model))
+
+     sims = []
+     scores = []
+     mask_val = jnp.zeros_like(masks[0])
+
+     for e, m in zip(image_embeddings, masks):  # spread each similarity over the visible pixels
+         sim = jnp.matmul(e, text_embedding.T)
+         sims.append(sim)
+         if len(sims) > 1:
+             scores.append(sim * m)
+             mask_val += 1 - m
+
+     score = jnp.mean(jnp.clip(jnp.array(scores) - sims[0], 0, jnp.inf), axis=0)  # gain over the baseline
+     for i in range(iterations):
+         score = jnp.clip(score - jnp.mean(score), 0, jnp.inf)  # sharpen by clipping below the mean
+     score = (score - jnp.min(score)) / (jnp.max(score) - jnp.min(score))  # rescale to [0, 1]
+     return np.asarray(score), input_image
+
+
+ def app():
+     st.title("Zero-Shot Localization")
+     st.markdown(
+         """
+
+         ### 👋 Ciao!
+
+         Here you can find an example of zero-shot localization that will show you where in an image the model sees an object.
+
+         🤌 Italian mode on! 🤌
+
+         For example, try typing "gatto" (cat) or "cane" (dog) in the label field and click "LOCATE"!
+
+         """
+     )
+
+     image_url = st.text_input(
+         "You can input the URL of an image here...",
+         value="https://www.tuttosuigatti.it/files/styles/full_width/public/images/featured/205/cani-e-gatti.jpg?itok=WAAiTGS6",
+     )
+
+     MAX_ITER = 1
+
+     col1, col2 = st.beta_columns([3, 1])
+
+     with col2:
+         pixel_size = st.selectbox(
+             "Pixel Size", options=range(10, 21, 5), index=0
+         )
+
+         iterations = st.selectbox(
+             "Refinement Steps", options=range(3, 30, 3), index=0
+         )
+
+         compute = st.button("LOCATE")
+
+     with col1:
+         caption = st.text_input("Insert label...")
+
+     if compute:
+
+         if not caption or not image_url:
+             st.error("Please choose one image and at least one label")
+         else:
+             with st.spinner("Computing..."):
+                 heatmap, image = get_heatmap(image_url, caption, pixel_size, iterations)
+
+             with col1:
+                 st.image(image, use_column_width=True)
+                 st.image(heatmap, use_column_width=True)
+                 st.image(np.asarray(image) / 255.0 * heatmap, use_column_width=True)
+         gc.collect()
+
+     elif image_url:
+         image_raw = requests.get(image_url, stream=True).raw
+         image = Image.open(image_raw).convert("RGB")
+         with col1:
+             st.image(image)
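In short, get_heatmap localizes by occlusion: gen_image_batch keeps only a horizontal or vertical strip of the image and grays out the rest, every occluded copy is scored against the text embedding, and the average per-pixel gain over the unoccluded baseline becomes the heatmap. A minimal sketch of driving these functions outside Streamlit, assuming the same text2image and utils modules are importable; the URL, label, and output filename are placeholders:

    # Hypothetical standalone driver for the functions above -- not part of this commit.
    import numpy as np
    from PIL import Image
    from localization import get_heatmap

    heatmap, image = get_heatmap(
        "https://example.com/cat.jpg",  # placeholder image URL
        "un gatto",                     # Italian label, as the demo expects
        pixel_size=10,
        iterations=3,
    )
    # Reproduce the overlay the app displays: dark where the model does not see the label.
    overlay = np.asarray(image) / 255.0 * heatmap
    Image.fromarray((overlay * 255).astype(np.uint8)).save("gatto_heatmap.png")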
static/img/examples/child_on_slide.png ADDED
static/img/examples/due_gatti.png ADDED
static/img/examples/un_gatto.png ADDED