import os
import base64
import pickle as pkl
from html import escape
from io import BytesIO
from multiprocessing.dummy import Pool

import streamlit as st
import torch
from datasets import load_dataset
from huggingface_hub import hf_hub_download, login
from PIL import Image, ImageOps
from streamlit_drawable_canvas import st_canvas
from torchvision import transforms

from src.model_LN_prompt import Model

if 'initialized' not in st.session_state:
    st.session_state.initialized = False

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
HEIGHT = 200
N_RESULTS = 20

# Resolve the theme's primary color ("#RRGGBB") to an RGB tuple, falling
# back to blue when no theme color is configured.
color = st.get_option("theme.primaryColor")
if color is None:
    color = (0, 0, 255)
else:
    color = tuple(int(color.lstrip("#")[i:i + 2], 16) for i in (0, 2, 4))


@st.cache_resource
def initialize_huggingface():
    token = os.getenv("HUGGINGFACE_TOKEN")
    if token:
        login(token=token)
    else:
        st.error("HUGGINGFACE_TOKEN not found in environment variables")


@st.cache_resource
def load_model_and_data():
    print("Loading everything...")
    dataset = load_dataset("CHSTR/ecommerce")
    # Recover the local root directory the dataset images were extracted to.
    path_images = "/".join(
        dataset['validation']['image'][0].filename.split("/")[:-3]) + "/"

    # Download the checkpoint and load the model.
    path_model = hf_hub_download(
        repo_id="CHSTR/Ecommerce", filename="dinov2_ecommerce.ckpt")
    model = Model().to(device)
    model_checkpoint = torch.load(path_model, map_location=device)
    model.load_state_dict(model_checkpoint['state_dict'])
    model.eval()

    # Download the precomputed image embeddings. The demo serves the same
    # embedding file under both corpus ids; loading it once is enough, since
    # the loop below rebuilds each corpus list independently.
    embeddings_file = hf_hub_download(
        repo_id="CHSTR/Ecommerce", filename="ecommerce_demo.pkl")
    with open(embeddings_file, "rb") as f:
        shared_embeddings = pkl.load(f)
    embeddings = {0: shared_embeddings, 1: shared_embeddings}

    # Prefix each embedding's relative image path with the local root.
    for corpus_id in [0, 1]:
        embeddings[corpus_id] = [
            (emb[0], path_images + "/".join(emb[1].split("/")[-3:]))
            for emb in embeddings[corpus_id]
        ]

    return model, path_images, embeddings


def compute_sketch(_sketch, model):
    # Embed the query sketch with the sketch branch of the model.
    with torch.no_grad():
        sketch_feat = model(_sketch.to(device), dtype='sketch')
    return sketch_feat


def image_search(_query, corpus, model, embeddings, n_results=N_RESULTS):
    query_embedding = compute_sketch(_query, model)
    corpus_id = 0 if corpus == "Unsplash" else 1

    # Score every corpus image against the sketch embedding and keep the
    # n_results highest dot products.
    image_features = torch.tensor(
        [item[0] for item in embeddings[corpus_id]]).to(device)
    dot_product = (image_features @ query_embedding.T)[:, 0]
    _, max_indices = torch.topk(
        dot_product, n_results, dim=0, largest=True, sorted=True)

    # Map the top-k indices back to image paths.
    path_to_label = {path: idx
                     for idx, (_, path) in enumerate(embeddings[corpus_id])}
    label_to_path = {idx: path for path, idx in path_to_label.items()}
    label_of_images = torch.tensor(
        [path_to_label[item[1]] for item in embeddings[corpus_id]]).to(device)
    return [
        (label_to_path[i],)
        for i in label_of_images[max_indices].cpu().numpy().tolist()
    ], dot_product[max_indices]


@st.cache_data
def make_square(img_path, fill_color=(255, 255, 255)):
    # Pad the image onto a square canvas (centered) so results render at a
    # uniform size; the original dimensions are returned alongside it.
    img = Image.open(img_path)
    x, y = img.size
    size = max(x, y)
    new_img = Image.new("RGB", (size, size), fill_color)
    new_img.paste(img, ((size - x) // 2, (size - y) // 2))
    return new_img, x, y


@st.cache_data
def get_images(paths):
    processed = [make_square(path) for path in paths]
    imgs, xs, ys = zip(*processed)
    return list(imgs), list(xs), list(ys)


@st.cache_data
def convert_pil_to_base64(image):
    # Encode as JPEG; note this returns raw base64 bytes, which the caller
    # must decode before embedding in an HTML data: URI.
    img_buffer = BytesIO()
    image.save(img_buffer, format="JPEG")
    byte_data = img_buffer.getvalue()
    base64_str = base64.b64encode(byte_data)
    return base64_str


def get_html(url_list, encoded_images):
    html = "