# clip-owlvit / app.py
# Vivien · "Improve presentation" · 6392199
# raw · history · blame · 10.2 kB
from html import escape
import requests
from io import BytesIO
import base64
from multiprocessing.dummy import Pool
from PIL import Image, ImageDraw
import streamlit as st
import pandas as pd, numpy as np
import torch
from transformers import CLIPProcessor, CLIPModel
from transformers import OwlViTProcessor, OwlViTForObjectDetection
from transformers.image_utils import ImageFeatureExtractionMixin
import tokenizers
# When True: smaller CLIP checkpoint and CPU-only execution.
DEBUG = False
MODEL = "vit-base-patch32" if DEBUG else "vit-large-patch14-336"
CLIP_MODEL = f"openai/clip-{MODEL}"
# Same OWL-ViT checkpoint in both modes (plain string: nothing to format).
OWL_MODEL = "google/owlvit-base-patch32"
device = torch.device("cuda" if not DEBUG and torch.cuda.is_available() else "cpu")
# Height in pixels of each thumbnail shown in the results gallery.
HEIGHT = 200
# Default number of images retrieved per query.
N_RESULTS = 6

# Fallback RGB box colour when the theme does not provide a usable one.
_DEFAULT_BOX_COLOR = (255, 196, 35)


def _parse_hex_color(value, default=_DEFAULT_BOX_COLOR):
    """Convert a CSS hex colour string to an ``(r, g, b)`` tuple.

    Accepts ``#RRGGBB`` / ``#RRGGBBAA`` (alpha ignored) as well as the short
    ``#RGB`` / ``#RGBA`` forms. Returns ``default`` for ``None`` or any value
    that is not a hex colour (e.g. a named colour) instead of raising.
    """
    if not isinstance(value, str):
        return default
    digits = value.strip().lstrip("#")
    if len(digits) in (3, 4):
        # Short form: each hex digit is doubled ("f0a" -> "ff00aa").
        digits = "".join(ch * 2 for ch in digits[:3])
    elif len(digits) in (6, 8):
        digits = digits[:6]
    else:
        return default
    try:
        return tuple(int(digits[i : i + 2], 16) for i in (0, 2, 4))
    except ValueError:
        return default


# Box outline colour: the theme's primary colour when it is a hex value.
color = _parse_hex_color(st.get_option("theme.primaryColor"))
@st.cache(allow_output_mutation=True)
def load():
    """Load the result tables, both models with their processors, and embeddings.

    Wrapped in ``st.cache`` so repeated calls can reuse the loaded objects.

    Returns:
        ``(clip_model, clip_processor, owl_model, owl_processor, df, embeddings)``
        where ``df`` and ``embeddings`` are dicts keyed by corpus id
        (0 -> ``data.csv`` / ``embeddings-*.npy``, 1 -> ``data2.csv`` /
        ``embeddings2-*.npy``). Each embedding row is L2-normalised, so a dot
        product with a normalised query vector is a cosine similarity.
    """

    def _ready(model):
        # Move to the selected device and switch to inference mode.
        model.to(device)
        model.eval()
        return model

    csv_files = ("data.csv", "data2.csv")
    df = {corpus_id: pd.read_csv(name) for corpus_id, name in enumerate(csv_files)}

    clip_model = _ready(CLIPModel.from_pretrained(CLIP_MODEL))
    clip_processor = CLIPProcessor.from_pretrained(CLIP_MODEL)
    owl_model = _ready(OwlViTForObjectDetection.from_pretrained(OWL_MODEL))
    owl_processor = OwlViTProcessor.from_pretrained(OWL_MODEL)

    raw_embeddings = {
        0: np.load(f"embeddings-{MODEL}.npy"),
        1: np.load(f"embeddings2-{MODEL}.npy"),
    }
    embeddings = {
        corpus_id: matrix / np.linalg.norm(matrix, axis=1, keepdims=True)
        for corpus_id, matrix in raw_embeddings.items()
    }
    return clip_model, clip_processor, owl_model, owl_processor, df, embeddings
# Load models, tables and embeddings at import time (through the cached `load`).
(
    clip_model,
    clip_processor,
    owl_model,
    owl_processor,
    df,
    embeddings,
) = load()
# NOTE(review): `mixin` is not referenced anywhere else in this file.
mixin = ImageFeatureExtractionMixin()
# Attribution line appended to each result's tooltip, keyed by corpus id.
source = {
    0: "\nSource: Unsplash",
    1: "\nSource: The Movie Database (TMDB)",
}
def compute_text_embeddings(list_of_strings):
    """Encode strings with CLIP's text encoder and L2-normalise each row.

    Returns a NumPy array of text features; rows presumably correspond to the
    input strings, in order.
    """
    batch = clip_processor(text=list_of_strings, return_tensors="pt", padding=True)
    batch = batch.to(device)
    with torch.no_grad():
        features = clip_model.get_text_features(**batch)
    features = features.detach().cpu().numpy()
    row_norms = np.linalg.norm(features, axis=1, keepdims=True)
    return features / row_norms
def image_search(query, corpus, n_results=N_RESULTS):
    """Return the ``n_results`` best-matching images of ``corpus`` for ``query``.

    ``corpus == "Unsplash"`` selects corpus 0; any other value selects corpus 1.
    Results are ``(path, tooltip, link)`` tuples, best match first, and each
    tooltip has the corpus attribution appended.
    """
    query_embedding = compute_text_embeddings([query])
    if corpus == "Unsplash":
        corpus_id = 0
    else:
        corpus_id = 1

    similarities = (embeddings[corpus_id] @ query_embedding.T)[:, 0]
    # Indices sorted by decreasing similarity, truncated to n_results.
    ranked = np.argsort(similarities)[::-1][:n_results]

    table = df[corpus_id]
    attribution = source[corpus_id]
    matches = []
    for row_index in ranked:
        row = table.iloc[row_index]
        matches.append((row.path, row.tooltip + attribution, row.link))
    return matches
def make_square(img, fill_color=(255, 255, 255)):
    """Centre ``img`` on a square RGB canvas filled with ``fill_color``.

    The canvas side is the larger of the image's width and height.

    Returns:
        ``(square_image, original_width, original_height)``.
    """
    width, height = img.size
    side = max(width, height)
    canvas = Image.new("RGB", (side, side), fill_color)
    offset = (int((side - width) / 2), int((side - height) / 2))
    canvas.paste(img, offset)
    return canvas, width, height
@st.cache(allow_output_mutation=True, show_spinner=False)
def get_images(paths):
    """Download the images at ``paths`` concurrently and pad each to a square.

    Returns:
        Three lists aligned with ``paths``: the square images (see
        ``make_square``), their original widths, and their original heights.
    """

    def process_image(path):
        # Bound the request so a single unresponsive host cannot hang the app.
        response = requests.get(path, timeout=30)
        return make_square(Image.open(BytesIO(response.content)))

    # Use the thread pool as a context manager so its worker threads are shut
    # down after each call; previously a new pool leaked on every search.
    # map() blocks until every download has finished, so terminating on exit
    # is safe.
    with Pool(N_RESULTS) as pool:
        processed = pool.map(process_image, paths)

    imgs = [img for img, _, _ in processed]
    xs = [x for _, x, _ in processed]
    ys = [y for _, _, y in processed]
    return imgs, xs, ys
@st.cache(
    hash_funcs={
        tokenizers.Tokenizer: lambda _obj: None,
        tokenizers.AddedToken: lambda _obj: None,
        torch.nn.parameter.Parameter: lambda _obj: None,
    },
    allow_output_mutation=True,
    show_spinner=False,
)
def apply_owl_model(owl_queries, images):
    """Run OWL-ViT on ``images`` with the text queries ``owl_queries``.

    Returns whatever ``owl_processor.post_process`` produces for the batch,
    using each image's ``(height, width)`` as its target size.
    """
    # PIL's ``size`` is (width, height); the target sizes are (height, width).
    sizes_hw = [image.size[::-1] for image in images]
    batch = owl_processor(text=owl_queries, images=images, return_tensors="pt")
    batch = batch.to(device)
    with torch.no_grad():
        outputs = owl_model(**batch)
    target_sizes = torch.Tensor(sizes_hw).to(device)
    return owl_processor.post_process(outputs=outputs, target_sizes=target_sizes)
def keep_best_boxes(
    boxes,
    scores,
    score_threshold=0.1,
    max_iou=0.8,
    containment_ratio=0.9,
    score_margin=1.1,
):
    """Reduce detections to a set of non-redundant boxes.

    Boxes scoring below ``score_threshold`` are dropped, and the coordinates
    of the rest are rounded. Overlapping pairs are then resolved greedily, in
    input order:

    * if their IoU exceeds ``max_iou``, the lower-scoring box is removed
      (on a tie, the earlier box is removed);
    * otherwise, a box whose area is more than ``containment_ratio`` covered
      by the other box is removed unless its score is at least
      ``score_margin`` times the other's.

    A box that has been removed no longer suppresses later boxes.

    Args:
        boxes: rows of ``(xmin, ymin, xmax, ymax)`` exposing ``.tolist()``
            (e.g. tensor or array rows).
        scores: per-box scores aligned with ``boxes``.
        score_threshold: minimum score for a box to be considered.
        max_iou: IoU above which two boxes count as duplicates.
        containment_ratio: fraction of a box's area that must lie inside the
            other box for the coverage rule to apply.
        score_margin: score factor a covered box needs in order to survive.

    Returns:
        List of ``[xmin, ymin, xmax, ymax]`` lists for the kept boxes, in
        input order.
    """

    def _safe_ratio(numerator, denominator):
        # Rounded boxes can be degenerate (zero width or height), which used
        # to raise ZeroDivisionError. With no area to compare, report 0.
        return numerator / denominator if denominator else 0.0

    candidates = []
    for box, score in zip(boxes, scores):
        box = [round(coord, 0) for coord in box.tolist()]
        if score >= score_threshold:
            candidates.append((box, float(score)))

    to_ignore = set()
    for i in range(len(candidates) - 1):
        if i in to_ignore:
            continue
        for j in range(i + 1, len(candidates)):
            if j in to_ignore:
                continue
            (xmin1, ymin1, xmax1, ymax1), score1 = candidates[i]
            (xmin2, ymin2, xmax2, ymax2), score2 = candidates[j]
            if xmax1 < xmin2 or xmax2 < xmin1 or ymax1 < ymin2 or ymax2 < ymin1:
                continue  # disjoint boxes
            # For overlapping boxes, the two middle sorted edges bound the
            # intersection along each axis.
            xmin_inter, xmax_inter = sorted([xmin1, xmax1, xmin2, xmax2])[1:3]
            ymin_inter, ymax_inter = sorted([ymin1, ymax1, ymin2, ymax2])[1:3]
            area_inter = (xmax_inter - xmin_inter) * (ymax_inter - ymin_inter)
            area1 = (xmax1 - xmin1) * (ymax1 - ymin1)
            area2 = (xmax2 - xmin2) * (ymax2 - ymin2)
            iou = _safe_ratio(area_inter, area1 + area2 - area_inter)
            if iou > max_iou:
                if score1 > score2:
                    to_ignore.add(j)
                else:
                    to_ignore.add(i)
                    break
            else:
                if _safe_ratio(area_inter, area1) > containment_ratio:
                    if score1 < score_margin * score2:
                        to_ignore.add(i)
                if _safe_ratio(area_inter, area2) > containment_ratio:
                    if score_margin * score1 > score2:
                        to_ignore.add(j)
                if i in to_ignore:
                    # Match the IoU branch: a removed box must not go on to
                    # suppress later candidates.
                    break
    return [box for k, (box, _) in enumerate(candidates) if k not in to_ignore]
def convert_pil_to_base64(image):
    """Save ``image`` as JPEG in memory and return its base64 encoding.

    The result is ASCII ``bytes``, not ``str``.
    """
    with BytesIO() as jpeg_buffer:
        image.save(jpeg_buffer, format="JPEG")
        jpeg_bytes = jpeg_buffer.getvalue()
    return base64.b64encode(jpeg_bytes)
def draw_reshape_encode(img, boxes, x, y):
    """Draw ``boxes`` on a copy of ``img`` and return it as a base64 JPEG thumbnail.

    ``img`` is assumed to be the padded square built by ``make_square`` for an
    original image of width ``x`` and height ``y``, with ``boxes`` given in that
    square's pixel coordinates as ``[xmin, ymin, xmax, ymax]``. The padding is
    cropped away and the result is resized to ``HEIGHT`` pixels tall.
    """
    image = img.copy()
    draw = ImageDraw.Draw(image)
    new_x, new_y = int(x * HEIGHT / y), HEIGHT
    # Scale the outline so it stays visible after the resize to HEIGHT.
    # Pillow draws no outline at all when width == 0, which the unclamped
    # formula produced for images shorter than HEIGHT.
    line_width = max(2, 2 * int(y / HEIGHT))
    for box in boxes:
        draw.rectangle(
            (tuple(box[:2]), tuple(box[2:])), outline=color, width=line_width
        )
    # Remove the padding with the same integer offset make_square pasted at.
    # The old float bounds were rounded by Pillow and could be off by a pixel.
    if x > y:
        top = int((x - y) / 2)
        image = image.crop((0, top, x, top + y))
    else:
        left = int((y - x) / 2)
        image = image.crop((left, 0, left + x, y))
    return convert_pil_to_base64(image.resize((new_x, new_y)))
def get_html(url_list, encoded_images):
    """Build the results gallery as one flex-wrapped ``<div>`` of inline images.

    Each ``url_list`` entry supplies its tooltip at index 1 (HTML-escaped into
    the ``title`` attribute). The matching ``encoded_images`` entry holds
    base64 JPEG bytes, embedded as a ``data:`` URI.
    """
    opening = "<div style='margin-top: 20px; max-width: 1200px; display: flex; flex-wrap: wrap; justify-content: space-evenly'>"
    tags = []
    for position, entry in enumerate(url_list):
        title, encoded = entry[1], encoded_images[position]
        tags.append(
            f"<img title='{escape(title)}' style='height: {HEIGHT}px; margin: 5px' src='data:image/jpeg;base64,{encoded.decode()}'>"
        )
    return opening + "".join(tags) + "</div>"
# Sidebar help text (Markdown). Paragraphs are separated by blank lines;
# without them Markdown merges the lines into a single paragraph.
description = """
# Search and Detect

This demo illustrates how to both retrieve images containing certain objects and locate these objects with a simple text query.

**Enter your query and hit enter**

**Tip 1**: if your query includes "/", the part left (resp. right) of "/" will be used to retrieve images (resp. locate objects). For example, if you want to retrieve pictures with several cats but locate individual cats, you can type "cats / cat".

**Tip 2**: change the score threshold to adjust the sensitivity of the object detection step.

*Built with OpenAI's [CLIP](https://openai.com/blog/clip/) model and Google's [OWL-ViT](https://arxiv.org/abs/2205.06230) model, 🤗 Hugging Face's [transformers library](https://huggingface.co/transformers/), [Streamlit](https://streamlit.io/), 25k images from [Unsplash](https://unsplash.com/) and 8k images from [The Movie Database (TMDB)](https://www.themoviedb.org/)*
"""
# NOTE(review): `div_style` is not referenced anywhere else in this file.
div_style = {
    "display": "flex",
    "justify-content": "center",
    "flex-wrap": "wrap",
}
# Page-level CSS injected at the top of `main` (content kept verbatim).
_PAGE_CSS = """
<style>
.block-container{
max-width: 1200px;
}
div.row-widget > div{
flex-direction:row;
display: flex;
justify-content: center;
}
div.row-widget.stRadio > div > label{
margin-left: 5px;
margin-right: 5px;
}
.row-widget {
margin-top: -25px;
}
section>div:first-child {
padding-top: 30px;
}
div.reportview-container > section:first-child{
max-width: 320px;
}
#MainMenu {
visibility: hidden;
}
</style>"""


def main():
    """Render the page: CSS, sidebar help, query controls and the result gallery.

    When the query contains "/", the text before the last "/" is used for
    CLIP image retrieval and the text after it for OWL-ViT detection.
    Otherwise the whole query is used for both. If one side of the "/" is
    empty, the other side is used for both models. Each retrieved image is
    annotated with its kept boxes and shown in an HTML gallery.
    """
    st.markdown(_PAGE_CSS, unsafe_allow_html=True)
    st.sidebar.markdown(description)
    _, c, _ = st.columns((1, 3, 1))
    query = c.text_input("", value="koala")
    corpus = c.radio("", ["Unsplash", "Movies"])
    score_threshold = c.slider(
        "Score threshold", min_value=0.01, max_value=0.1, value=0.03, step=0.01
    )
    # Blank or whitespace-only input: nothing to search for.
    if not query.strip():
        return

    if "/" in query:
        # rpartition splits on the LAST "/": "a/b / c" -> ("a/b ", "/", " c").
        clip_part, _, owl_part = query.rpartition("/")
        clip_query, owl_query = clip_part.strip(), owl_part.strip()
        # "cats /" or "/ cat": reuse the non-empty side instead of sending an
        # empty prompt to one of the models.
        clip_query = clip_query or owl_query
        owl_query = owl_query or clip_query
        if not clip_query:
            return  # the query was only slashes and whitespace
    else:
        clip_query, owl_query = query, query

    retrieved = image_search(clip_query, corpus)
    imgs, xs, ys = get_images([item[0] for item in retrieved])
    # The same single-entry query list for every image.
    results = apply_owl_model([[owl_query]] * len(imgs), imgs)

    encoded_images = []
    for img, x, y, detections in zip(imgs, xs, ys, results):
        boxes = keep_best_boxes(
            detections["boxes"],
            detections["scores"],
            score_threshold=score_threshold,
        )
        encoded_images.append(draw_reshape_encode(img, boxes, x, y))
    st.markdown(get_html(retrieved, encoded_images), unsafe_allow_html=True)
if __name__ == "__main__":
    # Entry point when this file is executed as the main script.
    main()