Spaces:
Running
Running
import streamlit as st | |
from text2image import get_model, get_tokenizer, get_image_transform | |
from utils import text_encoder | |
from transformers import AutoProcessor | |
from PIL import Image | |
from jax import numpy as jnp | |
import pandas as pd | |
import numpy as np | |
import requests | |
import psutil | |
import time | |
import jax | |
import gc | |
# Hugging Face checkpoint whose processor (text tokenization + image
# preprocessing, both used by ``get_heatmap``) is shared module-wide.
_PROCESSOR_CHECKPOINT = "clip-italian/clip-italian"
# Loaded once, when this module is imported.
preprocess = AutoProcessor.from_pretrained(_PROCESSOR_CHECKPOINT)
def resize_longer(image, longer_size=224):
    """Resize a PIL image so that its longer side is ``longer_size`` pixels.

    The aspect ratio is preserved. Both target sides are rounded to the
    nearest integer (so the longer side is exactly ``longer_size`` even when
    the float ratio is slightly below the exact value) and never drop below
    one pixel, which keeps extremely thin images resizable.

    Args:
        image: ``PIL.Image.Image`` to resize.
        longer_size: target length of the longer side, in pixels.

    Returns:
        The resized ``PIL.Image.Image``.
    """
    width, height = image.size
    ratio = float(longer_size) / max(width, height)
    new_size = tuple(max(1, round(side * ratio)) for side in (width, height))
    # ``Image.ANTIALIAS`` was removed in Pillow 10; it was an alias of LANCZOS.
    # ``Image.Resampling`` exists from Pillow 9.1 on; older releases expose the
    # constant directly on the ``Image`` module.
    lanczos = getattr(Image, "Resampling", Image).LANCZOS
    return image.resize(new_size, lanczos)
def pad_to_square(image, fill_value=127):
    """Pad an array so that its first two axes have the same length.

    The shorter of axes 0 and 1 is padded on both sides with the constant
    ``fill_value``; when the difference is odd, the extra row/column goes to
    the end. Any further axes (e.g. a trailing channel axis) are not padded,
    so both 2-D (grayscale) and 3-D (color) arrays are supported.

    Args:
        image: ``numpy`` array with at least two dimensions.
        fill_value: constant used for the padded area (default 127).

    Returns:
        A new padded ``numpy.ndarray`` (zero-width padding if already square).
    """
    height, width = image.shape[:2]
    side = max(height, width)
    pad_rows = side - height
    pad_cols = side - width
    padding = [
        (pad_rows // 2, pad_rows - pad_rows // 2),
        (pad_cols // 2, pad_cols - pad_cols // 2),
    ]
    # Leave every remaining axis untouched, whatever the array's rank.
    padding += [(0, 0)] * (image.ndim - 2)
    return np.pad(image, padding, mode="constant", constant_values=fill_value)
def image_encoder(image, model):
    """Embed a 4-D pixel batch and L2-normalise each embedding.

    Axis 1 of ``image`` (presumably channels — confirm against the
    processor output) is moved to the end before the batch is passed to
    ``model.get_image_features``.

    Args:
        image: 4-D array batch; reordered with axes ``(0, 2, 3, 1)``.
        model: object providing ``get_image_features``.

    Returns:
        ``(unit_features, norms)``: the model features divided by their L2
        norm along the last axis, and those norms with the last axis kept
        (size 1) so they can be multiplied back onto the unit features.
    """
    channels_last = np.transpose(image, (0, 2, 3, 1))
    raw_features = model.get_image_features(channels_last)
    norms = jnp.linalg.norm(raw_features, axis=-1, keepdims=True)
    unit_features = raw_features / norms
    return unit_features, norms
def gen_image_batch(image_url, image_size=224, pixel_size=10, timeout=30):
    """Download an image and build a batch of band-occluded views of it.

    The image is fetched from ``image_url``, converted to RGB and resized
    with ``resize_longer`` so that its longer side is ``image_size``. The
    returned batch contains:

    * index 0: the unmodified image (flagged both vertical and horizontal);
    * for every pair ``i < j``, a view that keeps only the rows
      ``[i * pixel_size, j * pixel_size)`` and paints everything else gray
      (127) — flagged horizontal;
    * the same for column bands — flagged vertical.

    Bands whose visible region would be empty are skipped: they only yield
    uniformly gray views, which carry no localization signal.

    Args:
        image_url: URL of the image to download.
        image_size: target length of the longer image side, in pixels.
        pixel_size: band granularity, in pixels.
        timeout: seconds to wait for the HTTP server (optional).

    Returns:
        ``(image_batch, masks, is_vertical, is_horizontal)`` — four parallel
        lists. Each mask has shape ``(H, W, 1)`` and is 1 where the original
        pixels are kept and 0 where they are grayed out.

    Raises:
        requests.RequestException: on connection errors, timeouts or an HTTP
            error status.
    """
    with requests.get(image_url, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        # Let urllib3 undo gzip/deflate content encoding before PIL reads.
        response.raw.decode_content = True
        # convert() forces the pixel data to load before the response closes.
        image = Image.open(response.raw).convert("RGB")
    image = np.array(resize_longer(image, longer_size=image_size))
    gray = np.full_like(image, 127)
    full_mask = np.ones_like(image[:, :, :1])

    image_batch = [image]
    masks = [full_mask]
    is_vertical = [True]
    is_horizontal = [True]

    # axis 0 -> bands of rows (horizontal), axis 1 -> bands of columns (vertical)
    for axis, vertical in ((0, False), (1, True)):
        length = image.shape[axis]
        last = length // pixel_size
        for i in range(last + 1):
            start = i * pixel_size
            if start >= length:
                # Nothing left to keep: every later view would be all gray.
                break
            for j in range(i + 1, last + 2):
                band = slice(start, j * pixel_size)  # slicing clamps at `length`
                mask = np.zeros_like(full_mask)
                if axis == 0:
                    mask[band, :] = 1
                else:
                    mask[:, band] = 1
                image_batch.append(image * mask + gray * (1 - mask))
                masks.append(mask)
                is_vertical.append(vertical)
                is_horizontal.append(not vertical)
    return image_batch, masks, is_vertical, is_horizontal
def get_heatmap(image_url, text, pixel_size=10, iterations=3):
    """Estimate where in an image the model associates ``text``, by occlusion.

    ``gen_image_batch`` expands the image into the full view plus row-band
    and column-band occlusions. All views are embedded (un-normalised
    features). For each column, the embeddings of the vertical views that keep
    that column visible are summed; likewise per row for horizontal views.
    Column and row sums are combined into a per-pixel embedding, averaged by
    the number of contributing views, and scored against the normalised
    text embedding. The score map is sharpened ``iterations`` times
    (subtract the mean, clip at zero) and finally min-max scaled to [0, 1].

    Args:
        image_url: URL of the image to localize in.
        text: label to localize.
        pixel_size: band granularity (pixels) of the occlusions.
        iterations: number of mean-subtract-and-clip refinement passes.

    Returns:
        ``(heatmap, input_image)``. ``heatmap`` is a ``numpy`` array of shape
        ``(H, W, 1)`` with values in [0, 1] (all zeros when the score map is
        constant). ``input_image`` is the resized, unoccluded RGB image.
    """
    model = get_model()
    image_size = model.config.vision_config.image_size
    images, masks, vertical, horizontal = gen_image_batch(
        image_url, image_size=image_size, pixel_size=pixel_size
    )
    input_image = images[0].copy()

    # NOTE(review): the views are generally not square. If the processor
    # center-crops (as CLIP processors commonly do), the image borders are
    # dropped before embedding. `pad_to_square` may have been intended here.
    # Confirm against the processor config.
    inputs = preprocess(text=[text], images=images, return_tensors="np")
    image_embeddings, embedding_norms = image_encoder(inputs["pixel_values"], model)
    # Multiply the norms back in: every view contributes its raw features.
    embeddings = image_embeddings * embedding_norms  # (N, D)

    text_embedding = model.get_text_features(
        inputs["input_ids"], inputs["attention_mask"]
    )[0]
    text_embedding = text_embedding / jnp.linalg.norm(
        text_embedding, axis=-1, keepdims=True
    )

    # visible[n, y, x] is True where view n keeps the original pixel.
    visible = np.stack(masks)[..., 0].astype(bool)
    # Column (row) visibility, counted only for views flagged vertical (horizontal).
    col_weight = jnp.asarray(
        visible.any(axis=1) & np.asarray(vertical)[:, None], dtype=embeddings.dtype
    )  # (N, W)
    row_weight = jnp.asarray(
        visible.any(axis=2) & np.asarray(horizontal)[:, None], dtype=embeddings.dtype
    )  # (N, H)

    vertical_scores = col_weight.T @ embeddings  # (W, D)
    horizontal_scores = row_weight.T @ embeddings  # (H, D)
    vertical_masks = col_weight.sum(axis=0)[:, None]  # (W, 1)
    horizontal_masks = row_weight.sum(axis=0)[:, None]  # (H, 1)

    # Combine column and row evidence per pixel -> (H, W, D).
    embs_1 = vertical_scores[None, :, :] * jnp.abs(horizontal_scores)[:, None, :]
    embs_2 = jnp.abs(vertical_scores)[None, :, :] * horizontal_scores[:, None, :]
    full_embs = jnp.minimum(embs_1, embs_2)
    # Average over contributing views; the full view (index 0) is flagged
    # both ways with an all-ones mask, so every count is >= 1.
    mask_sum = vertical_masks[None, :, :] * horizontal_masks[:, None, :]
    full_embs = full_embs / mask_sum

    score = (full_embs @ text_embedding)[..., None]  # (H, W, 1)

    for _ in range(iterations):
        score = jnp.clip(score - jnp.mean(score), 0, jnp.inf)
    low, high = jnp.min(score), jnp.max(score)
    if high > low:
        score = (score - low) / (high - low)
    else:
        # A constant map has no localization signal; avoid 0/0 -> NaN.
        score = jnp.zeros_like(score)
    return np.asarray(score), input_image
def app():
    """Render the zero-shot localization page and run it on demand.

    The page shows an image URL field, a label field, and selectors for pixel
    size and refinement steps. When "LOCATE" is pressed, it waits while CPU
    load is above 50%. It then computes the heatmap with ``get_heatmap`` and
    shows the image, the heatmap and the image weighted by the heatmap.
    Without a button press, it only previews the image at the given URL.
    Network or image-decoding failures are reported with ``st.error``
    instead of crashing the page.
    """
    st.title("Zero-Shot Localization")
    st.markdown(
        """
### π Ciao!

Here you can find an example for zero-shot localization that will show you where in an image the model sees an object.
The object location is computed by masking different areas of the image and looking at
how the similarity to the image description changes. If you want to have a look at the implementation in detail,
you can find it in [this Colab](https://colab.research.google.com/drive/10neENr1DEAFq_GzsLqBDo0gZ50hOhkOr?usp=sharing).

On the two parameters:

+ the *pixel size* defines the resolution of the localization map. A pixel size of 15 means
that 15 pixels in the original image will form 1 pixel in the heatmap.
+ The *refinement iterations* are just a cheap operation to reduce background noise. Too few iterations will leave a lot of noise.
Too many will shrink the heatmap too much.

π€ Italian mode on! π€

For example, try typing "gatto" (cat) or "cane" (dog) in the space for label and click "locate"!
"""
    )
    image_url = st.text_input(
        "You can input the URL of an image here...",
        value="https://www.tuttosuigatti.it/files/styles/full_width/public/images/featured/205/cani-e-gatti.jpg?itok=WAAiTGS6",
    )

    col1, col2 = st.columns([0.75, 0.25])
    with col2:
        pixel_size = st.selectbox("Pixel Size", options=range(10, 26, 5), index=2)
        iterations = st.selectbox("Refinement Steps", options=range(1, 6, 1), index=0)
        compute = st.button("LOCATE")
    with col1:
        caption = st.text_input("Insert label...")

    if compute:
        with st.spinner("Waiting for resources..."):
            sleep_time = 5
            # Back off while the shared machine is busy.
            while psutil.cpu_percent() > 50:
                time.sleep(sleep_time)
        if not caption or not image_url:
            st.error("Please choose one image and at least one label")
        else:
            try:
                with st.spinner(
                    "Computing... This might take up to a few minutes depending on the current load π \n"
                    "Otherwise, you can use this [Colab notebook](https://colab.research.google.com/drive/10neENr1DEAFq_GzsLqBDo0gZ50hOhkOr?usp=sharing)"
                ):
                    heatmap, image = get_heatmap(image_url, caption, pixel_size, iterations)
            # requests errors and PIL's UnidentifiedImageError both derive from OSError.
            except (requests.RequestException, OSError) as err:
                st.error(f"Could not load the image: {err}")
            else:
                with col1:
                    st.image(image, use_column_width=True)
                    st.image(heatmap, use_column_width=True)
                    st.image(np.asarray(image) / 255.0 * heatmap, use_column_width=True)
            gc.collect()
    elif image_url:
        try:
            with requests.get(image_url, stream=True, timeout=30) as response:
                response.raise_for_status()
                # Undo gzip/deflate content encoding before PIL reads the stream.
                response.raw.decode_content = True
                # convert() loads the pixels before the response is closed.
                preview = Image.open(response.raw).convert("RGB")
        except (requests.RequestException, OSError) as err:
            st.error(f"Could not load the image: {err}")
        else:
            with col1:
                st.image(preview)