Spaces:
Runtime error
Runtime error
from html import escape | |
import requests | |
from io import BytesIO | |
import base64 | |
from multiprocessing.dummy import Pool | |
from PIL import Image, ImageDraw | |
import streamlit as st | |
import pandas as pd, numpy as np | |
import torch | |
from transformers import CLIPProcessor, CLIPModel | |
from transformers import OwlViTProcessor, OwlViTForObjectDetection | |
from transformers.image_utils import ImageFeatureExtractionMixin | |
import tokenizers | |
# Setting DEBUG to True selects the smaller CLIP checkpoint and keeps all
# inference on the CPU even when CUDA is available.
DEBUG = False
MODEL = "vit-base-patch32" if DEBUG else "vit-large-patch14-336"
CLIP_MODEL = f"openai/clip-{MODEL}"
OWL_MODEL = "google/owlvit-base-patch32"
# The GPU is used only outside debug mode and only when torch can see CUDA.
device = torch.device(
    "cuda" if (not DEBUG and torch.cuda.is_available()) else "cpu"
)
HEIGHT = 200  # display height (px) of every result thumbnail
N_RESULTS = 6  # number of images retrieved per query


def _parse_theme_color(value, default=(255, 196, 35)):
    """Convert a ``#RRGGBB`` or ``#RGB`` hex string into an ``(r, g, b)`` tuple.

    Args:
        value: colour string from the Streamlit theme, or ``None``.
        default: tuple returned when ``value`` is ``None`` or is not a hex
            colour of one of those two forms.

    Returns:
        A tuple of three ints in ``0..255``.
    """
    if value is None:
        return default
    digits = value.strip().lstrip("#")
    # Expand the CSS shorthand form, e.g. "fc3" -> "ffcc33".
    if len(digits) == 3:
        digits = "".join(ch * 2 for ch in digits)
    if len(digits) != 6:
        # Anything else (e.g. a colour name) would previously crash the app
        # at import time; fall back to the default instead.
        return default
    try:
        return tuple(int(digits[i : i + 2], 16) for i in (0, 2, 4))
    except ValueError:
        return default


# Outline colour for detection boxes: the theme's primary colour when it is
# set to a parseable hex value, otherwise the fallback (255, 196, 35).
color = _parse_theme_color(st.get_option("theme.primaryColor"))
def load():
    """Load both corpora, the CLIP and OWL-ViT models/processors and the image embeddings.

    Both models are moved to ``device`` and switched to eval mode. Each
    embedding row is divided by its L2 norm so that a dot product with a
    normalised text embedding is a cosine similarity.

    NOTE(review): this runs every time the script executes. Streamlit
    re-executes the script on each widget interaction, so consider caching
    this call (e.g. ``st.cache_resource`` on recent Streamlit). Confirm the
    installed Streamlit version first.

    Returns:
        ``(clip_model, clip_processor, owl_model, owl_processor, df, embeddings)``
        where ``df`` and ``embeddings`` are dicts keyed by corpus id (0 and 1).
    """
    df = {
        corpus_id: pd.read_csv(csv_path)
        for corpus_id, csv_path in enumerate(("data.csv", "data2.csv"))
    }

    clip_model = CLIPModel.from_pretrained(CLIP_MODEL).to(device).eval()
    clip_processor = CLIPProcessor.from_pretrained(CLIP_MODEL)

    owl_model = OwlViTForObjectDetection.from_pretrained(OWL_MODEL).to(device).eval()
    owl_processor = OwlViTProcessor.from_pretrained(OWL_MODEL)

    embeddings = {}
    for corpus_id, prefix in enumerate(("embeddings", "embeddings2")):
        matrix = np.load(f"{prefix}-{MODEL}.npy")
        embeddings[corpus_id] = matrix / np.linalg.norm(matrix, axis=1, keepdims=True)

    return clip_model, clip_processor, owl_model, owl_processor, df, embeddings
# Heavy resources are loaded at module level, i.e. every time this script is
# executed. The helper functions below read these names as globals.
(
    clip_model,
    clip_processor,
    owl_model,
    owl_processor,
    df,
    embeddings,
) = load()
# NOTE(review): `mixin` is not referenced anywhere else in this file; kept in
# case other code imports it.
mixin = ImageFeatureExtractionMixin()
# Attribution line appended to each result's tooltip, keyed by corpus id.
source = dict(
    enumerate(["\nSource: Unsplash", "\nSource: The Movie Database (TMDB)"])
)
def compute_text_embeddings(list_of_strings):
    """Embed strings with the CLIP text encoder.

    Args:
        list_of_strings: the texts to embed (padded to a common length).

    Returns:
        A numpy array with one embedding per input string, each row scaled
        to unit L2 norm.
    """
    tokens = clip_processor(text=list_of_strings, return_tensors="pt", padding=True)
    tokens = tokens.to(device)
    with torch.no_grad():
        features = clip_model.get_text_features(**tokens)
    features = features.detach().cpu().numpy()
    row_norms = np.linalg.norm(features, axis=1, keepdims=True)
    return features / row_norms
def image_search(query, corpus, n_results=N_RESULTS):
    """Return the ``n_results`` corpus images most similar to ``query``.

    Args:
        query: free-text search query embedded with CLIP.
        corpus: ``"Unsplash"`` selects corpus 0; any other value selects corpus 1.
        n_results: number of hits to return, best first.

    Returns:
        A list of ``(path, tooltip, link)`` tuples. ``tooltip`` has the
        corpus attribution line appended.
    """
    corpus_id = 0 if corpus == "Unsplash" else 1
    query_embedding = compute_text_embeddings([query])
    scores = (embeddings[corpus_id] @ query_embedding.T)[:, 0]
    # Descending order of similarity, truncated to the requested count.
    best_first = np.argsort(scores)[::-1][:n_results]

    table = df[corpus_id]
    attribution = source[corpus_id]
    hits = []
    for row_index in best_first:
        row = table.iloc[row_index]
        hits.append((row.path, row.tooltip + attribution, row.link))
    return hits
def make_square(img, fill_color=(255, 255, 255)):
    """Centre ``img`` on a square RGB canvas whose side is its longer edge.

    Args:
        img: source PIL image.
        fill_color: RGB colour used for the padding.

    Returns:
        ``(square_image, original_width, original_height)``.
    """
    width, height = img.size
    side = max(width, height)
    canvas = Image.new("RGB", (side, side), fill_color)
    # Integer offsets; the padding is split evenly (floor on the top/left).
    offset = ((side - width) // 2, (side - height) // 2)
    canvas.paste(img, offset)
    return canvas, width, height
def get_images(paths, timeout=10):
    """Download images concurrently and pad each one to a square.

    Args:
        paths: iterable of image URLs.
        timeout: per-request timeout in seconds passed to ``requests.get``.
            Without it, a stalled server would hang the app indefinitely.

    Returns:
        Three parallel lists ``(imgs, xs, ys)``: the squared images and the
        original widths and heights.

    Raises:
        requests.HTTPError: if a download returns an HTTP error status. Such
        a response would otherwise reach PIL as an undecodable body.
    """

    def process_image(path):
        response = requests.get(path, timeout=timeout)
        response.raise_for_status()
        return make_square(Image.open(BytesIO(response.content)))

    paths = list(paths)
    if not paths:
        return [], [], []

    # Thread pool (multiprocessing.dummy): the work is network-bound. The
    # context manager tears the worker threads down; previously a new pool
    # was leaked on every call.
    with Pool(min(N_RESULTS, len(paths))) as pool:
        processed = pool.map(process_image, paths)

    imgs, xs, ys = (list(column) for column in zip(*processed))
    return imgs, xs, ys
def apply_owl_model(owl_queries, images):
    """Run OWL-ViT detection on ``images`` with the given text queries.

    Args:
        owl_queries: one list of text queries per image (``main`` passes
            ``[[owl_query]] * len(imgs)``).
        images: PIL images.

    Returns:
        The processor's post-processed detections, one entry per image
        (``main`` reads the ``"boxes"`` and ``"scores"`` keys).

    NOTE(review): ``post_process`` is deprecated in newer transformers
    releases in favour of ``post_process_object_detection``. Verify against
    the pinned version.
    """
    encoded = owl_processor(text=owl_queries, images=images, return_tensors="pt")
    encoded = encoded.to(device)
    with torch.no_grad():
        outputs = owl_model(**encoded)
    # PIL sizes are (width, height); target sizes are passed as (height, width).
    sizes = torch.Tensor([(im.size[1], im.size[0]) for im in images]).to(device)
    return owl_processor.post_process(outputs=outputs, target_sizes=sizes)
def keep_best_boxes(boxes, scores, score_threshold=0.1, max_iou=0.8):
    """Filter detections by score and drop redundant overlapping boxes.

    Boxes are compared pairwise in input order (greedily, not sorted by score):

    * if two boxes overlap with IoU > ``max_iou``, the lower-scoring one is
      dropped (on a tie the earlier box is dropped);
    * otherwise, a box whose area is more than 90% covered by the other box is
      dropped unless its score is at least 10% higher than the other's.

    A box that has been dropped no longer suppresses later boxes.
    Zero-area boxes are handled without dividing by zero.

    Args:
        boxes: iterable of ``(xmin, ymin, xmax, ymax)`` boxes; each item must
            support ``.tolist()`` (e.g. a torch tensor or numpy row).
        scores: scores aligned with ``boxes``.
        score_threshold: minimum score for a box to be considered at all.
        max_iou: IoU above which two boxes count as duplicates.

    Returns:
        A list of kept boxes, each a list of 4 coordinates rounded to whole
        numbers, in input order.
    """
    candidates = [
        ([round(v, 0) for v in box.tolist()], float(score))
        for box, score in zip(boxes, scores)
        if score >= score_threshold
    ]

    to_ignore = set()
    for i in range(len(candidates) - 1):
        if i in to_ignore:
            continue
        (xmin1, ymin1, xmax1, ymax1), score_i = candidates[i]
        area1 = (xmax1 - xmin1) * (ymax1 - ymin1)
        for j in range(i + 1, len(candidates)):
            if j in to_ignore:
                continue
            (xmin2, ymin2, xmax2, ymax2), score_j = candidates[j]
            if xmax1 < xmin2 or xmax2 < xmin1 or ymax1 < ymin2 or ymax2 < ymin1:
                continue  # disjoint boxes never suppress each other
            area_inter = (min(xmax1, xmax2) - max(xmin1, xmin2)) * (
                min(ymax1, ymax2) - max(ymin1, ymin2)
            )
            area2 = (xmax2 - xmin2) * (ymax2 - ymin2)
            union = area1 + area2 - area_inter
            # Degenerate boxes can give a zero union; treat them as non-overlapping.
            iou = area_inter / union if union > 0 else 0.0
            if iou > max_iou:
                if score_i > score_j:
                    to_ignore.add(j)
                    continue
                to_ignore.add(i)
                break  # i is gone; it must not suppress anything else
            # Box i mostly inside box j: keep j unless i is clearly more confident.
            if area1 > 0 and area_inter / area1 > 0.9 and score_i < 1.1 * score_j:
                to_ignore.add(i)
                break  # i is gone; it must not suppress anything else
            # Box j mostly inside box i: keep i unless j is clearly more confident.
            if area2 > 0 and area_inter / area2 > 0.9 and 1.1 * score_i > score_j:
                to_ignore.add(j)
    return [box for k, (box, _) in enumerate(candidates) if k not in to_ignore]
def convert_pil_to_base64(image):
    """Serialise ``image`` as JPEG and return the base64-encoded bytes.

    Args:
        image: an object with a PIL-style ``save(fp, format=...)`` method.

    Returns:
        ``bytes`` holding the base64 encoding of the JPEG data.
    """
    with BytesIO() as buffer:
        image.save(buffer, format="JPEG")
        jpeg_bytes = buffer.getvalue()
    return base64.b64encode(jpeg_bytes)
def draw_reshape_encode(img, boxes, x, y):
    """Draw boxes on a squared image, undo the square padding, resize and encode.

    Args:
        img: square image as produced by ``make_square`` (side ``max(x, y)``).
        boxes: ``[xmin, ymin, xmax, ymax]`` boxes in ``img`` pixel coordinates.
        x: original image width before squaring.
        y: original image height before squaring.

    Returns:
        Base64-encoded JPEG bytes of the annotated image, ``HEIGHT`` px tall
        with the original aspect ratio.
    """
    image = img.copy()
    draw = ImageDraw.Draw(image)
    new_x, new_y = int(x * HEIGHT / y), HEIGHT
    # Scale the outline with the later downscale so it stays visible on screen.
    # Clamp to 1: for images shorter than HEIGHT the old formula gave 0, and a
    # zero width draws no outline in Pillow.
    line_width = max(1, 2 * int(y / HEIGHT))
    for box in boxes:
        draw.rectangle((tuple(box[:2]), tuple(box[2:])), outline=color, width=line_width)
    # Remove make_square's padding using the same integer offsets it pasted
    # with. Float offsets were rounded by Pillow and could be off by one pixel.
    if x > y:
        top = (x - y) // 2
        image = image.crop((0, top, x, top + y))
    else:
        left = (y - x) // 2
        image = image.crop((left, 0, left + x, y))
    return convert_pil_to_base64(image.resize((new_x, new_y)))
def get_html(url_list, encoded_images):
    """Build an HTML flex gallery of base64 JPEG thumbnails.

    Args:
        url_list: search hits; element 1 of each entry (the tooltip) becomes
            the HTML-escaped ``title`` of its image.
        encoded_images: base64 JPEG bytes, indexed in step with ``url_list``.

    Returns:
        A single HTML ``<div>`` string containing one ``<img>`` per hit.
    """
    opening = "<div style='margin-top: 20px; max-width: 1200px; display: flex; flex-wrap: wrap; justify-content: space-evenly'>"
    tags = [
        f"<img title='{escape(entry[1])}' style='height: {HEIGHT}px; margin: 5px' src='data:image/jpeg;base64,{encoded_images[index].decode()}'>"
        for index, entry in enumerate(url_list)
    ]
    return opening + "".join(tags) + "</div>"
description = """ | |
# Search and Detect | |
This demo illustrates how to both retrieve images containing certain objects and locate these objects with a simple text query. | |
**Enter your query and hit enter** | |
**Tip 1**: if your query includes "/", the left and right parts will be used to respectively retrieve images and locate objects. For example, if you want to retrieve pictures with several cats but locate individual cats, you can type "cats / cat". | |
**Tip 2**: change the score threshold to adjust the sensitivity of the object detection step. | |
*Built with OpenAI's [CLIP](https://openai.com/blog/clip/) model and Google's [OWL-ViT](https://arxiv.org/abs/2205.06230) model, 🤗 Hugging Face's [transformers library](https://huggingface.co/transformers/), [Streamlit](https://streamlit.io/), 25k images from [Unsplash](https://unsplash.com/) and 8k images from [The Movie Database (TMDB)](https://www.themoviedb.org/)* | |
""" | |
div_style = { | |
"display": "flex", | |
"justify-content": "center", | |
"flex-wrap": "wrap", | |
} | |
# Page-level CSS injected by main(): widens the layout, centres the widget
# rows, narrows the sidebar and hides Streamlit's main menu.
_PAGE_CSS = """
<style>
    .block-container{
        max-width: 1200px;
    }
    div.row-widget > div{
        flex-direction:row;
        display: flex;
        justify-content: center;
    }
    div.row-widget.stRadio > div > label{
        margin-left: 5px;
        margin-right: 5px;
    }
    .row-widget {
        margin-top: -25px;
    }
    section>div:first-child {
        padding-top: 30px;
    }
    div.appview-container > section:first-child{
        max-width: 320px;
    }
    #MainMenu {
        visibility: hidden;
    }
</style>"""


def _split_query(query):
    """Split a user query into ``(clip_query, owl_query)``.

    If the query contains "/", everything before the last "/" is the
    retrieval (CLIP) query and the part after it is the detection (OWL-ViT)
    query. Otherwise the whole query is used for both. When one side is
    empty (e.g. "cats /"), it falls back to the other side so neither model
    gets an empty prompt. Both parts are empty only for blank or "/"-only
    input.
    """
    query = query.strip()
    if "/" not in query:
        return query, query
    clip_part, _, owl_part = query.rpartition("/")
    clip_query, owl_query = clip_part.strip(), owl_part.strip()
    return clip_query or owl_query, owl_query or clip_query


def main():
    """Render the page: CSS, sidebar help, query widgets and annotated results."""
    st.markdown(_PAGE_CSS, unsafe_allow_html=True)
    st.sidebar.markdown(description)

    _, c, _ = st.columns((1, 3, 1))
    query = c.text_input("", value="koala")
    corpus = c.radio("", ["Unsplash", "Movies"])
    score_threshold = c.slider(
        "Score threshold", min_value=0.01, max_value=0.2, value=0.03, step=0.01
    )

    clip_query, owl_query = _split_query(query)
    if not clip_query:
        return  # blank or "/"-only input: nothing to search for

    retrieved = image_search(clip_query, corpus)
    imgs, xs, ys = get_images([hit[0] for hit in retrieved])
    if not imgs:
        return  # nothing retrieved; OWL-ViT cannot run on an empty batch

    results = apply_owl_model([[owl_query]] * len(imgs), imgs)
    encoded_images = []
    for img, x, y, detections in zip(imgs, xs, ys, results):
        boxes = keep_best_boxes(
            detections["boxes"],
            detections["scores"],
            score_threshold=score_threshold,
        )
        encoded_images.append(draw_reshape_encode(img, boxes, x, y))
    st.markdown(get_html(retrieved, encoded_images), unsafe_allow_html=True)


if __name__ == "__main__":
    main()