"""Streamlit demo for Self-Supervised Sketch-based Image Retrieval (S3BIR).

The user draws a sketch on a canvas; the sketch is embedded with the model
and the images whose pre-computed embeddings have the highest dot product
with the sketch embedding are shown as a gallery.
"""
import base64
import os
import pickle as pkl
from html import escape
from io import BytesIO
from multiprocessing.dummy import Pool

import streamlit as st
import torch
from datasets import load_dataset
from huggingface_hub import hf_hub_download, login
from PIL import Image, ImageOps
from streamlit_drawable_canvas import st_canvas
from torchvision import transforms

from src.model_LN_prompt import Model

if 'initialized' not in st.session_state:
    st.session_state.initialized = False

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Height in pixels of every thumbnail in the result gallery.
HEIGHT = 200
# Default number of images returned by a search.
N_RESULTS = 20

# Theme primary colour as an (R, G, B) tuple; blue when the theme sets none.
color = st.get_option("theme.primaryColor")
if color is None:
    color = (0, 0, 255)
else:
    color = tuple(int(color.lstrip("#")[i: i + 2], 16) for i in (0, 2, 4))

# Sidebar text (markdown). Kept flush-left so the heading and the paragraph
# are parsed as markdown blocks regardless of the indentation of the caller.
_DESCRIPTION = """
# Self-Supervised Sketch-based Image Retrieval (S3BIR)

Our approaches, S3BIR-CLIP and S3BIR-DINOv2, can produce a bimodal
sketch-photo feature space from unpaired data without explicit sketch-photo
pairs. Our experiments perform outstandingly in three diverse public datasets
where the models are trained without real sketches.
"""


@st.cache_resource
def initialize_huggingface():
    """Log in to the Hugging Face Hub with the ``HUGGINGFACE_TOKEN`` env var.

    Shows a Streamlit error message when the variable is unset or empty;
    no exception is raised in that case.
    """
    token = os.getenv("HUGGINGFACE_TOKEN")
    if token:
        login(token=token)
    else:
        st.error("HUGGINGFACE_TOKEN not found in environment variables")


@st.cache_resource
def load_model_and_data():
    """Download and load the model, the dataset location and the embeddings.

    Returns:
        A ``(model, path_images, embeddings)`` tuple where ``model`` is the
        network in eval mode on ``device``, ``path_images`` is the local
        directory prefix of the dataset images (with trailing ``/``) and
        ``embeddings`` maps corpus id (``0`` and ``1``) to a list of
        ``(embedding, image_path)`` tuples with paths rewritten to live under
        ``path_images``.
    """
    print("Loading everything...")
    dataset = load_dataset("CHSTR/ecommerce")
    # Index the row first so only one image is fetched; indexing the column
    # first can materialise every image of the split just to read one name.
    first_image = dataset['validation'][0]['image']
    path_images = "/".join(first_image.filename.split("/")[:-3]) + "/"

    # Download model
    path_model = hf_hub_download(
        repo_id="CHSTR/Ecommerce", filename="dinov2_ecommerce.ckpt")

    # Load model
    # NOTE(security): torch.load unpickles the checkpoint; only safe because
    # the file comes from a repository that is trusted by this app.
    model = Model().to(device)
    model_checkpoint = torch.load(path_model, map_location=device)
    model.load_state_dict(model_checkpoint['state_dict'])
    model.eval()

    # Download and load embeddings
    embeddings_file = hf_hub_download(
        repo_id="CHSTR/Ecommerce", filename="ecommerce_demo.pkl")
    # NOTE(security): pickle can execute arbitrary code while loading; the
    # file must come from a trusted repository.
    with open(embeddings_file, "rb") as handle:
        raw_embeddings = pkl.load(handle)

    # Update image paths: keep the last three path components of each stored
    # path and re-root them under the local dataset directory. Both corpus
    # ids are backed by the same embeddings file.
    embeddings = {
        corpus_id: [
            (emb[0], path_images + "/".join(emb[1].split("/")[-3:]))
            for emb in raw_embeddings
        ]
        for corpus_id in (0, 1)
    }

    return model, path_images, embeddings


def compute_sketch(_sketch, model):
    """Return the model's sketch features for ``_sketch`` without gradients.

    Args:
        _sketch: Input tensor; it is moved to ``device`` before the call.
        model: Callable invoked as ``model(tensor, dtype='sketch')``.
    """
    with torch.no_grad():
        sketch_feat = model(_sketch.to(device), dtype='sketch')
    return sketch_feat


def image_search(_query, corpus, model, embeddings, n_results=N_RESULTS):
    """Rank the corpus images by dot-product similarity to a sketch.

    Args:
        _query: Sketch tensor passed to ``compute_sketch``.
        corpus: Corpus name; ``"Unsplash"`` selects corpus id 0, anything
            else selects corpus id 1.
        model: Model used to embed the sketch.
        embeddings: Mapping from corpus id to a list of
            ``(embedding, image_path)`` tuples.
        n_results: Maximum number of results; clamped to the corpus size.

    Returns:
        ``(results, scores)`` where ``results`` is a list of one-element
        tuples ``(image_path,)`` ordered from most to least similar and
        ``scores`` is the tensor of the matching dot products.
    """
    query_embedding = compute_sketch(_query, model)
    corpus_id = 0 if corpus == "Unsplash" else 1
    entries = embeddings[corpus_id]
    if not entries:
        return [], torch.empty(0, device=device)

    image_features = torch.tensor([item[0] for item in entries]).to(device)
    # Assumes the query embedding is 2-D with a single row, so column 0 of
    # the product holds one score per corpus image - TODO confirm.
    dot_product = (image_features @ query_embedding.T)[:, 0]

    # topk raises when k exceeds the number of candidates.
    k = min(n_results, len(entries))
    _, max_indices = torch.topk(
        dot_product, k, dim=0, largest=True, sorted=True)

    results = [(entries[i][1],) for i in max_indices.cpu().tolist()]
    return results, dot_product[max_indices]


@st.cache_data
def make_square(img_path, fill_color=(255, 255, 255)):
    """Open an image and paste it onto an RGB canvas of the same size.

    Despite the name, the canvas keeps the source width and height (the
    caller relies on that to resize with the original aspect ratio); the
    effect is a conversion to RGB on a ``fill_color`` background.

    Returns:
        ``(image, width, height)`` of the source image.
    """
    with Image.open(img_path) as img:
        x, y = img.size
        new_img = Image.new("RGB", (x, y), fill_color)
        new_img.paste(img)
    return new_img, x, y


@st.cache_data
def get_images(paths):
    """Load every image in ``paths`` through ``make_square``.

    Returns:
        Three parallel lists ``(images, widths, heights)``; all empty when
        ``paths`` is empty.
    """
    processed = [make_square(path) for path in paths]
    if not processed:
        return [], [], []
    imgs, xs, ys = zip(*processed)
    return list(imgs), list(xs), list(ys)


@st.cache_data
def convert_pil_to_base64(image):
    """Encode a PIL image as JPEG and return the base64 bytes."""
    img_buffer = BytesIO()
    image.save(img_buffer, format="JPEG")
    byte_data = img_buffer.getvalue()
    base64_str = base64.b64encode(byte_data)
    return base64_str


def get_html(url_list, encoded_images):
    """Build the HTML gallery for the retrieved images.

    Args:
        url_list: Sequence of tuples whose first item is the image path; it
            is HTML-escaped and used as the thumbnail tooltip.
        encoded_images: Base64-encoded JPEG bytes, parallel to ``url_list``.

    Returns:
        An HTML string with one ``<img>`` per result inside a flex ``<div>``.
    """
    tags = [
        f"<img title='{escape(item[0])}' "
        f"style='height: {HEIGHT}px; margin: 5px' "
        f"src='data:image/jpeg;base64,{encoded.decode()}'>"
        for item, encoded in zip(url_list, encoded_images)
    ]
    return (
        "<div style='margin-top: 20px; max-width: 1200px; display: flex; "
        "flex-wrap: wrap; justify-content: space-evenly'>"
        + "".join(tags)
        + "</div>"
    )


def main():
    """Render the app: sidebar, drawing canvas and the retrieval gallery."""
    # Load the heavy resources once per session.
    if not st.session_state.initialized:
        initialize_huggingface()
        (st.session_state.model,
         st.session_state.path_images,
         st.session_state.embeddings) = load_model_and_data()
        st.session_state.initialized = True

    st.sidebar.markdown(_DESCRIPTION)
    stroke_width = st.sidebar.slider("Stroke width: ", 1, 25, 5)

    # styles
    # NOTE(review): the style block is blank in this revision, so this call
    # injects no CSS; it is kept as the hook for custom rules.
    st.markdown(
        """ """,
        unsafe_allow_html=True,
    )

    st.title("S3BIR App")
    _, col, _ = st.columns((1, 1, 1))
    with col:
        canvas_result = st_canvas(
            background_color="#eee",
            stroke_width=stroke_width,
            update_streamlit=True,
            height=300,
            width=300,
            key="color_annotation_app",
        )

    corpus = ["Ecommerce"]
    st.columns((1, 3, 1))

    if canvas_result.image_data is not None:
        # Turn the canvas pixels into a normalised 1 x 3 x 224 x 224 tensor.
        draw = Image.fromarray(canvas_result.image_data.astype("uint8"))
        draw = ImageOps.pad(draw.convert("RGB"), size=(224, 224))
        draw_tensor = transforms.ToTensor()(draw)
        draw_tensor = transforms.Resize((224, 224))(draw_tensor)
        draw_tensor = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )(draw_tensor)
        draw_tensor = draw_tensor.unsqueeze(0)

        retrieved, _ = image_search(
            draw_tensor, corpus[0],
            st.session_state.model, st.session_state.embeddings)
        imgs, xs, ys = get_images([x[0] for x in retrieved])

        # Scale every image to HEIGHT pixels tall, keeping its aspect ratio;
        # the width is clamped to 1 so extremely tall images stay resizable.
        encoded_images = [
            convert_pil_to_base64(
                img.resize((max(1, int(x * HEIGHT / y)), HEIGHT)))
            for img, x, y in zip(imgs, xs, ys)
        ]
        st.markdown(get_html(retrieved, encoded_images),
                    unsafe_allow_html=True)


if __name__ == "__main__":
    main()