import streamlit as st
from text2image import get_model, get_tokenizer, get_image_transform
from utils import text_encoder
from torchvision import transforms
from PIL import Image
from jax import numpy as jnp
import pandas as pd
import numpy as np
import requests
import psutil
import time
import jax
import gc
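
# The normalization constants below are CLIP's standard per-channel RGB
# mean and std, so inputs match what the vision tower was trained on.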
preprocess = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            (0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)
        ),
    ]
)


def pad_to_square(image, size=224):
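    """Resize `image` so its longer side equals `size`, then pad to a gray square."""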
    ratio = float(size) / max(image.size)
    new_size = tuple([int(x * ratio) for x in image.size])
    # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
    image = image.resize(new_size, Image.LANCZOS)
    new_image = Image.new("RGB", size=(size, size), color=(128, 128, 128))
    new_image.paste(image, ((size - new_size[0]) // 2, (size - new_size[1]) // 2))
    return new_image


def image_encoder(image, model):
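    """Embed a batch of preprocessed images and L2-normalize each feature vector."""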
    image = np.transpose(image, (0, 2, 3, 1))  # NCHW -> NHWC, as the Flax model expects
    features = model.get_image_features(image)
    features /= jnp.linalg.norm(features, axis=-1, keepdims=True)
    return features


def gen_image_batch(image_url, image_size=224, pixel_size=10):
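    """Download the image and build a batch of strip-occluded copies of it.

    Returns the image batch (the first entry is the unmasked image) and the
    binary masks marking the visible region of each copy.
    """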
    n_pixels = image_size // pixel_size + 1

    image_batch = []
    masks = []
    image_raw = requests.get(image_url, stream=True).raw
    image = Image.open(image_raw).convert("RGB")
    image = pad_to_square(image, size=image_size)

    gray = np.ones_like(image) * 128
    mask = np.ones_like(image)

    image_batch.append(image)
    masks.append(mask)
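    # Keep a horizontal band of rows between strips i and j; gray out the rest.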
    for i in range(0, n_pixels):
        for j in range(i + 1, n_pixels):
            m = mask.copy()
            m[: min(i * pixel_size, image_size) + 1, :] = 0
            m[min(j * pixel_size, image_size) + 1 :, :] = 0
            neg_m = 1 - m
            image_batch.append(image * m + gray * neg_m)
            masks.append(m)
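    # The same for vertical bands of columns.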
    for i in range(0, n_pixels + 1):
        for j in range(i + 1, n_pixels + 1):
            m = mask.copy()
            m[:, : min(i * pixel_size + 1, image_size)] = 0
            m[:, min(j * pixel_size + 1, image_size) :] = 0
            neg_m = 1 - m
            image_batch.append(image * m + gray * neg_m)
            masks.append(m)

    return image_batch, masks


def get_heatmap(image_url, text, pixel_size=10, iterations=3):
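    """Occlusion-based localization: embed strip-masked copies of the image,
    accumulate text-image similarity over the visible regions, and normalize
    the accumulated score into a [0, 1] heatmap."""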
    tokenizer = get_tokenizer()
    model = get_model()
    image_size = model.config.vision_config.image_size

    text_embedding = text_encoder(text, model, tokenizer)
    images, masks = gen_image_batch(
        image_url, image_size=image_size, pixel_size=pixel_size
    )

    input_image = images[0].copy()
    images = np.stack([preprocess(image) for image in images], axis=0)
    image_embeddings = jnp.asarray(image_encoder(images, model))

    sims = []
    scores = []
    mask_val = jnp.zeros_like(masks[0])
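    # Score each visible region by its similarity to the text; sims[0] is the
    # unmasked baseline that gets subtracted below.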
    for e, m in zip(image_embeddings, masks):
        sim = jnp.matmul(e, text_embedding.T)
        sims.append(sim)
        if len(sims) > 1:
            scores.append(sim * m)
            mask_val += 1 - m

    score = jnp.mean(jnp.clip(jnp.array(scores) - sims[0], 0, jnp.inf), axis=0)
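    # Refinement: repeatedly subtract the mean and clip at zero to suppress
    # background noise, then rescale the result to [0, 1].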
    for i in range(iterations):
        score = jnp.clip(score - jnp.mean(score), 0, jnp.inf)
    score = (score - jnp.min(score)) / (jnp.max(score) - jnp.min(score))
    return np.asarray(score), input_image


def app():
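    """Streamlit page for the zero-shot localization demo."""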
    st.title("Zero-Shot Localization")
    st.markdown(
        """
        ### Ciao!

        Here you can find an example of zero-shot localization that shows you where in an image the model sees an object.

        The object location is computed by masking different areas of the image and looking at
        how the similarity to the image description changes. If you want to have a look at the implementation in detail,
        you can find it in [this Colab](https://colab.research.google.com/drive/10neENr1DEAFq_GzsLqBDo0gZ50hOhkOr?usp=sharing).

        On the two parameters:

        + The *pixel size* defines the resolution of the localization map: a pixel size of 15 means
          that each 15×15 block of the original image becomes one pixel in the heatmap.
        + The *refinement iterations* are a cheap operation to reduce background noise. Too few iterations
          leave a lot of noise; too many shrink the heatmap too much.

        Italian mode on!

        For example, try typing "gatto" (cat) or "cane" (dog) in the label field and click "LOCATE"!
        """
    )
    image_url = st.text_input(
        "You can input the URL of an image here...",
        value="https://www.tuttosuigatti.it/files/styles/full_width/public/images/featured/205/cani-e-gatti.jpg?itok=WAAiTGS6",
    )

    MAX_ITER = 1

    col1, col2 = st.columns([0.75, 0.25])

    with col2:
        pixel_size = st.selectbox("Pixel Size", options=range(10, 26, 5), index=2)
        iterations = st.selectbox("Refinement Steps", options=range(3, 30, 3), index=0)
        compute = st.button("LOCATE")

    with col1:
        caption = st.text_input("Insert label...")

    if compute:
        with st.spinner("Waiting for resources..."):
            sleep_time = 5
            while psutil.cpu_percent() > 50:
                time.sleep(sleep_time)

        if not caption or not image_url:
            st.error("Please provide both an image URL and a label")
        else:
            with st.spinner(
                "Computing... This might take up to a few minutes depending on the current load.\n"
                "Otherwise, you can use this [Colab notebook](https://colab.research.google.com/drive/10neENr1DEAFq_GzsLqBDo0gZ50hOhkOr?usp=sharing)"
            ):
                heatmap, image = get_heatmap(image_url, caption, pixel_size, iterations)
            with col1:
                st.image(image, use_column_width=True)
                st.image(heatmap, use_column_width=True)
                st.image(np.asarray(image) / 255.0 * heatmap, use_column_width=True)
            gc.collect()
    elif image_url:
        image = requests.get(
            image_url,
            stream=True,
        ).raw
        image = Image.open(image).convert("RGB")
        with col1:
            st.image(image)
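

# Hypothetical entry point for running this page on its own; in the original
# Space, app() is presumably invoked by a top-level multipage wrapper instead.
if __name__ == "__main__":
    app()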