# Hugging Face Spaces page header (status: Running)
from html import escape | |
from io import BytesIO | |
import base64 | |
from multiprocessing.dummy import Pool | |
from PIL import Image, ImageDraw | |
import streamlit as st | |
import pandas as pd | |
import numpy as np | |
import torch | |
# from transformers import CLIPProcessor, CLIPModel | |
# from transformers import OwlViTProcessor, OwlViTForObjectDetection | |
# from transformers.image_utils import ImageFeatureExtractionMixin | |
import pickle as pkl | |
# sketches | |
from streamlit_drawable_canvas import st_canvas | |
from PIL import Image, ImageOps | |
from torchvision import transforms | |
# model | |
import os | |
# Does not recognize the `src` folder that is two levels down
from src.model_LN_prompt import Model | |
from src.options import opts | |
from datasets import load_dataset | |
DEBUG = False
# CLIP / OWL-ViT model identifiers. NOTE(review): they appear unused by the
# active code paths in this file (the transformers imports are commented out).
MODEL = "vit-base-patch32" if DEBUG else "vit-large-patch14-336"
CLIP_MODEL = f"openai/clip-{MODEL}"
OWL_MODEL = "google/owlvit-base-patch32"

# Use CUDA only outside debug mode and only when a GPU is present.
device = torch.device("cuda" if not DEBUG and torch.cuda.is_available() else "cpu")

HEIGHT = 350  # rendered height (px) of every result image
N_RESULTS = 5  # number of hits returned per query

from huggingface_hub import hf_hub_download, login

token = os.getenv("HUGGINGFACE_TOKEN")
# Authenticate only when a token is configured: login() without a token
# falls back to an interactive prompt, which a headless server cannot answer.
if token:
    login(token=token)


def _hex_to_rgb(value, default=(0, 255, 0)):
    """Parse a "#RRGGBB" or "#RGB" hex color into an (r, g, b) tuple.

    Returns *default* for None/empty input or anything that is not a 3- or
    6-digit hex string, instead of crashing the app at import time.
    """
    if not value:
        return default
    digits = value.strip().lstrip("#")
    if len(digits) == 3:
        # CSS shorthand: "#0f0" means "#00ff00".
        digits = "".join(ch * 2 for ch in digits)
    if len(digits) != 6:
        return default
    try:
        return tuple(int(digits[i:i + 2], 16) for i in (0, 2, 4))
    except ValueError:
        return default


# Outline color for result boxes: the theme's primary color, else green.
color = _hex_to_rgb(st.get_option("theme.primaryColor"))
@st.cache_resource
def load():
    """Download and build everything the app needs: dataset, model, embeddings.

    Cached with ``st.cache_resource`` so the downloads, checkpoint loading and
    unpickling run once per server process. This module's top level calls
    ``load()`` unconditionally, and Streamlit re-executes the script on every
    widget interaction.

    Returns:
        tuple: ``(model, path_images, embeddings)``.
            ``model``: ``Model`` in eval mode with the downloaded checkpoint.
            ``path_images``: ``dataset['features']['image']['filename']``.
            ``embeddings``: ``{0: <docexp_embeddings.pkl>,
            1: <dino_flicker_docexplore_groundingDINO.pkl>}``.
    """
    # Download the dataset.
    dataset = load_dataset("CHSTR/docexplore")
    print(dataset)
    print(dataset['features'])
    # NOTE(review): callers build paths as path_images + "page<N>.jpg", so this
    # is expected to be a directory-like string prefix; confirm against the
    # dataset schema.
    path_images = dataset['features']['image']['filename']

    path_model = hf_hub_download(
        repo_id="CHSTR/DocExplore", filename="epoch=16-mAP=0.66_triplet.ckpt"
    )
    model = Model()
    # Alternative checkpoints noted by the authors:
    #   model_60k_images_073.ckpt -> trained on 60k images without PiDiNet
    #   modified_model_083.ckpt   -> trained on 60k images with PiDiNet
    #   original_model_083.ckpt   -> original model, 60k images with PiDiNet
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    model_checkpoint = torch.load(path_model, map_location=device)
    model.load_state_dict(model_checkpoint['state_dict'])
    model.eval()
    print("Modelo cargado exitosamente")

    embeddings_file_1 = hf_hub_download(
        repo_id="CHSTR/DocExplore", filename="dino_flicker_docexplore_groundingDINO.pkl"
    )
    embeddings_file_0 = hf_hub_download(
        repo_id="CHSTR/DocExplore", filename="docexp_embeddings.pkl"
    )
    # NOTE(review): pickle.load can execute arbitrary code; these files must
    # only ever come from the project's own, trusted Hub repository.
    embeddings = {}
    for key, file_path in ((0, embeddings_file_0), (1, embeddings_file_1)):
        # Context manager closes each file instead of leaking the handle.
        with open(file_path, "rb") as handle:
            embeddings[key] = pkl.load(handle)
    return model, path_images, embeddings
print("Cargando modelos...")
model, path_images, embeddings = load()

# Display labels keyed like `embeddings` (0 -> SAM corpus, 1 -> GroundingDINO).
source = dict(enumerate(["\nDocExplore SAM", "\nDocExplore GroundingDINO"]))

# Sidebar control for the canvas brush size.
stroke_width = st.sidebar.slider("Stroke width: ", min_value=1, max_value=25, value=5)

# 224x224 resize, tensor conversion and the standard ImageNet mean/std
# normalization.
dataset_transforms = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
def compute_text_embeddings(sketch):
    """Encode a preprocessed sketch with the global model (no autograd).

    The name is kept for compatibility: the input is a sketch image tensor,
    not text. It is moved to ``device`` and passed to
    ``model(..., dtype='sketch')``.

    Returns:
        The model's output for the sketch branch.
    """
    with torch.no_grad():
        return model(sketch.to(device), dtype='sketch')
def image_search(query, corpus, n_results=N_RESULTS):
    """Rank the precomputed corpus regions by dot-product similarity to a sketch.

    Args:
        query: preprocessed sketch tensor, encoded via compute_text_embeddings.
        corpus: "DocExplore SAM" selects embeddings[0]; any other value
            selects embeddings[1].
        n_results: maximum number of hits. It is clamped to the corpus size
            because torch.topk raises when k exceeds the number of entries.

    Returns:
        tuple: ``(paths, boxes, scores)``, ordered by descending score.
            ``paths``: list of 1-tuples ``(image_path,)``.
            ``boxes``: tensor of the hits' boxes.
            ``scores``: tensor of the hits' dot products.
    """
    query_embedding = compute_text_embeddings(query)
    corpus_id = 0 if corpus == "DocExplore SAM" else 1
    entries = embeddings[corpus_id]
    # NOTE(review): entries are read as (feature, bbox, page label) triples;
    # inferred from the variable names, so confirm against the pickles.
    image_features = torch.tensor([item[0] for item in entries]).to(device)
    bbox_of_images = torch.tensor([item[1] for item in entries]).to(device)
    label_of_images = torch.tensor([item[2] for item in entries]).to(device)

    dot_product = (image_features @ query_embedding.T)[:, 0]
    k = min(n_results, dot_product.shape[0])
    _, max_indices = torch.topk(dot_product, k, dim=0, largest=True, sorted=True)

    page_labels = label_of_images[max_indices].tolist()
    paths = [(path_images + "page" + str(label) + ".jpg",) for label in page_labels]
    return paths, bbox_of_images[max_indices], dot_product[max_indices]
def make_square(img, fill_color=(255, 255, 255)):
    """Copy *img* onto a fresh RGB canvas of the SAME size.

    Despite the name, no square padding is applied. The canvas keeps the
    original ``(width, height)``, and draw_reshape_encode resizes using the
    returned width/height. Padding to a real square here would distort that
    resize.

    Args:
        img: source PIL image; pasting converts it to RGB.
        fill_color: canvas background color.

    Returns:
        tuple: ``(rgb_image, width, height)``.
    """
    width, height = img.size
    canvas = Image.new("RGB", (width, height), fill_color)
    canvas.paste(img)
    return canvas, width, height
def get_images(paths):
    """Open the result images concurrently and normalize them via make_square.

    Args:
        paths: iterable of image file paths.

    Returns:
        tuple: ``(imgs, xs, ys)``: RGB images and their widths and heights,
        in the same order as *paths*.
    """
    def process_image(path):
        # make_square pastes (and therefore loads) the pixels, so the file
        # handle can be closed as soon as it returns.
        with Image.open(path) as img:
            return make_square(img)

    # Thread pool (multiprocessing.dummy). The context manager shuts the
    # worker threads down instead of leaking a new pool on every rerun.
    with Pool(N_RESULTS) as pool:
        processed = pool.map(process_image, paths)

    imgs, xs, ys = [], [], []
    for img, x, y in processed:
        imgs.append(img)
        xs.append(x)
        ys.append(y)
    return imgs, xs, ys
def keep_best_boxes(boxes, scores, score_threshold=0.1, max_iou=0.8):
    """Filter boxes by score, then suppress duplicate and nested boxes.

    Args:
        boxes: iterable of objects whose ``tolist()`` yields
            ``[xmin, ymin, xmax, ymax]``.
        scores: per-box scores aligned with *boxes*.
        score_threshold: boxes scoring below this are discarded.
        max_iou: for a pair whose IoU exceeds this, only the higher-scoring
            box is kept. Otherwise, a box lying more than 90% inside the other
            is dropped unless it scores clearly (10%) higher.

    Returns:
        list: the kept boxes (coordinates rounded), in input order.
    """
    candidates = []
    for box, score in zip(boxes, scores):
        if score >= score_threshold:
            candidates.append(([round(v, 0) for v in box.tolist()], float(score)))

    to_ignore = set()
    for i in range(len(candidates) - 1):
        if i in to_ignore:
            continue
        box_i, score_i = candidates[i]
        xmin1, ymin1, xmax1, ymax1 = box_i
        for j in range(i + 1, len(candidates)):
            if j in to_ignore:
                continue
            box_j, score_j = candidates[j]
            xmin2, ymin2, xmax2, ymax2 = box_j
            # Disjoint boxes never suppress each other.
            if xmax1 < xmin2 or xmax2 < xmin1 or ymax1 < ymin2 or ymax2 < ymin1:
                continue
            # For overlapping intervals the intersection is the middle two of
            # the four sorted endpoints.
            xmin_inter, xmax_inter = sorted([xmin1, xmax1, xmin2, xmax2])[1:3]
            ymin_inter, ymax_inter = sorted([ymin1, ymax1, ymin2, ymax2])[1:3]
            area_inter = (xmax_inter - xmin_inter) * (ymax_inter - ymin_inter)
            area1 = (xmax1 - xmin1) * (ymax1 - ymin1)
            area2 = (xmax2 - xmin2) * (ymax2 - ymin2)
            union = area1 + area2 - area_inter
            # Zero-area (degenerate) boxes previously raised ZeroDivisionError.
            iou = area_inter / union if union > 0 else 0.0
            if iou > max_iou:
                if score_i > score_j:
                    to_ignore.add(j)
                else:
                    to_ignore.add(i)
                    break
            else:
                if area1 > 0 and area_inter / area1 > 0.9 and score_i < 1.1 * score_j:
                    to_ignore.add(i)
                    # A dropped box must not go on to suppress later boxes.
                    break
                if area2 > 0 and area_inter / area2 > 0.9 and 1.1 * score_i > score_j:
                    to_ignore.add(j)
    return [box for idx, (box, _) in enumerate(candidates) if idx not in to_ignore]
def convert_pil_to_base64(image):
    """Encode *image* as JPEG and return the base64 representation as bytes."""
    with BytesIO() as buffer:
        image.save(buffer, format="JPEG")
        jpeg_bytes = buffer.getvalue()
    return base64.b64encode(jpeg_bytes)
def draw_reshape_encode(img, boxes, x, y):
    """Outline one box on a copy of *img*, scale it to HEIGHT px tall, base64 it.

    Args:
        img: PIL image to annotate. It is not modified; a copy is drawn on.
        boxes: a single box whose ``tolist()`` gives
            ``[x_min, y_min, x_max, y_max]``.
        x: width used to compute the output width.
        y: height used to compute the output width.

    Returns:
        bytes: base64-encoded JPEG produced by convert_pil_to_base64.
    """
    flat_box = boxes.tolist()
    canvas = img.copy()
    pen = ImageDraw.Draw(canvas)
    # Keep the aspect ratio while forcing the height to HEIGHT.
    target_width = int(x * HEIGHT / y)
    print("box:", flat_box)
    pen.rectangle(
        [(flat_box[0], flat_box[1]), (flat_box[2], flat_box[3])],
        outline=color,
        width=10,
    )
    return convert_pil_to_base64(canvas.resize((target_width, HEIGHT)))
def get_html(url_list, encoded_images):
    """Build a flex-wrapped HTML gallery of inline base64 JPEG images.

    Args:
        url_list: sequence whose items' first element is used (HTML-escaped)
            as each image's title.
        encoded_images: base64 bytes, index-aligned with *url_list*.

    Returns:
        str: a single ``<div>`` containing one ``<img>`` per entry.
    """
    opening = "<div style='margin-top: 20px; max-width: 1200px; display: flex; flex-wrap: wrap; justify-content: space-evenly'>"
    tags = []
    for idx, entry in enumerate(url_list):
        raw_title, payload = entry[0], encoded_images[idx]
        tags.append(
            f"<img title='{escape(raw_title)}' style='height: {HEIGHT}px; margin: 1px' src='data:image/jpeg;base64,{payload.decode()}'>"
        )
    return opening + "".join(tags) + "</div>"
# Markdown help text rendered in the sidebar.
description = """
# Sketch-based Detection
This app retrieves images from the [DocExplore](https://www.docexplore.eu/?lang=en) dataset based on a sketch query.
**Tip 1**: you can draw a sketch in the canvas.
**Tip 2**: you can change the size of the stroke with the slider.
The model utilized in this application is a DINOv2, which was trained in a self-supervised manner on the Flickr25k dataset.
"""

# CSS flexbox properties for a centered, wrapping row.
div_style = dict(
    [
        ("display", "flex"),
        ("justify-content", "center"),
        ("flex-wrap", "wrap"),
    ]
)
def _inject_page_css():
    """Inject page-level CSS: container widths, centered radio row, hidden menu."""
    st.markdown(
        """
        <style>
        .block-container{
            max-width: 1600px;
        }
        div.row-widget > div{
            flex-direction: row;
            display: flex;
            justify-content: center;
        }
        div.row-widget.stRadio > div > label{
            margin-left: 5px;
            margin-right: 5px;
        }
        .row-widget {
            margin-top: -25px;
        }
        section > div:first-child {
            padding-top: 30px;
        }
        div.appview-container > section:first-child{
            max-width: 320px;
        }
        #MainMenu {
            visibility: hidden;
        }
        .stMarkdown {
            display: grid;
            place-items: center;
        }
        </style>
        """,
        unsafe_allow_html=True,
    )


def _sketch_to_tensor(image_data, background="#eee"):
    """Convert the canvas pixel array into a normalized (1, 3, 224, 224) batch.

    Args:
        image_data: the canvas' pixel array; converted to uint8 and opened
            with PIL.
        background: color that transparent pixels are flattened onto. It
            should match the canvas ``background_color``.
    """
    sketch = Image.fromarray(image_data.astype("uint8"))
    if sketch.mode == "RGBA":
        # A bare convert("RGB") drops alpha, so fully transparent pixels keep
        # whatever RGB they carry (possibly black). Compositing onto the
        # background leaves opaque pixels unchanged.
        backdrop = Image.new("RGBA", sketch.size, background)
        sketch = Image.alpha_composite(backdrop, sketch)
    sketch = ImageOps.pad(sketch.convert("RGB"), size=(224, 224))
    # NOTE(review): writes the latest sketch to a fixed path shared by every
    # session; kept from the original, so confirm whether anything reads it.
    sketch.save("draw.jpg")
    # Same resize/ToTensor/Normalize pipeline used elsewhere in this module;
    # the 224x224 resize is a no-op after the pad above.
    return dataset_transforms(sketch).unsqueeze(0)


def _render_results(sketch_tensor, corpus):
    """Search *corpus* with the sketch and render the hits as an HTML gallery."""
    retrieved, bbox_of_images, _scores = image_search(sketch_tensor, corpus)
    imgs, xs, ys = get_images([entry[0] for entry in retrieved])
    encoded_images = [
        draw_reshape_encode(img, bbox_of_images[idx], x, y)
        for idx, (img, x, y) in enumerate(zip(imgs, xs, ys))
    ]
    st.markdown(get_html(retrieved, encoded_images), unsafe_allow_html=True)


def main():
    """Build the page: CSS, sidebar help, sketch canvas + corpus picker, results.

    The left column holds the drawing canvas and the corpus selector; the
    right column shows the retrieved pages. If the canvas returns no pixel
    data, the results column is left empty.
    """
    _inject_page_css()
    st.sidebar.markdown(description)
    st.title("One-Shot Detection")

    canvas_background = "#eee"
    left_col, right_col = st.columns([0.2, 0.8])

    with left_col:
        canvas_result = st_canvas(
            background_color=canvas_background,
            stroke_width=stroke_width,
            update_streamlit=True,
            height=300,
            width=300,
            key="color_annotation_app",
        )
        corpus = st.radio("", ["DocExplore SAM", "DocExplore GroundingDINO"], index=0)

    with right_col:
        if canvas_result.image_data is None:
            return
        sketch_tensor = _sketch_to_tensor(canvas_result.image_data, background=canvas_background)
        _render_results(sketch_tensor, corpus)
if __name__ == "__main__":
    # Script entry point: build and render the app.
    main()