import streamlit as st
import torch
import clip
from PIL import Image
import os
import numpy as np
import chromadb
import requests
import tempfile
import time
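
# NOTE: `clip` above is assumed to be OpenAI's reference CLIP package
# (installed from the openai/CLIP GitHub repo), not the unrelated `clip`
# project on PyPI; clip.load() and clip.tokenize() come from the OpenAI package.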

# ----- Setup -----
st.set_page_config(page_title="CLIP Image Search", layout="wide")
CACHE_DIR = tempfile.gettempdir()
CHROMA_PATH = os.path.join(CACHE_DIR, "chroma_db")
DEMO_DIR = os.path.join(CACHE_DIR, "demo_images")
os.makedirs(DEMO_DIR, exist_ok=True)
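
# NOTE: tempfile.gettempdir() is ephemeral storage; after a container restart
# the Chroma index, cached model weights, and demo images are gone and will be
# re-downloaded and re-indexed on first use.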

# ----- Session State Init -----
if 'dataset_loaded' not in st.session_state:
    st.session_state.dataset_loaded = False
if 'dataset_name' not in st.session_state:
    st.session_state.dataset_name = None
if 'demo_images' not in st.session_state:
    st.session_state.demo_images = []
if 'user_images' not in st.session_state:
    st.session_state.user_images = []

# ----- Load CLIP Model -----
if 'model' not in st.session_state:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, preprocess = clip.load("ViT-B/32", device=device, download_root=CACHE_DIR)
    st.session_state.model = model
    st.session_state.preprocess = preprocess
    st.session_state.device = device
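
# ViT-B/32 produces 512-dimensional embeddings. Keeping the model in
# session_state avoids reloading it on every Streamlit rerun; a common
# alternative is an @st.cache_resource-decorated loader, which would also
# share one model instance across user sessions.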

# ----- Initialize ChromaDB -----
if 'chroma_client' not in st.session_state:
    st.session_state.chroma_client = chromadb.PersistentClient(path=CHROMA_PATH)
    st.session_state.demo_collection = st.session_state.chroma_client.get_or_create_collection(
        name="demo_images", metadata={"hnsw:space": "cosine"}
    )
    st.session_state.user_collection = st.session_state.chroma_client.get_or_create_collection(
        name="user_images", metadata={"hnsw:space": "cosine"}
    )
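
# With hnsw:space set to "cosine", Chroma returns cosine *distances*
# (1 - cosine similarity, in [0, 2]), which is why the search code below
# converts results back with `similarity = 1 - distance`.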

# ----- Sidebar -----
with st.sidebar:
    st.title("🧠 CLIP Search App")
    st.markdown("Choose a dataset to begin:")
    # Clicking a button selects the dataset and flags it for (re)indexing
    # on the next rerun.
    if st.button("📦 Load Demo Images"):
        st.session_state.dataset_name = "demo"
        st.session_state.dataset_loaded = False
    if st.button("📤 Upload Your Images"):
        st.session_state.dataset_name = "user"
        st.session_state.dataset_loaded = False

# ----- Helper -----
def download_image_with_retry(url, path, retries=3, delay=1.0):
    """Download `url` to `path`, retrying on errors. Returns True on success."""
    for _ in range(retries):
        try:
            r = requests.get(url, timeout=10)
            if r.status_code == 200:
                with open(path, 'wb') as f:
                    f.write(r.content)
                return True
        except Exception:
            time.sleep(delay)
    return False
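
# The preprocess -> encode_image -> flatten pattern below is repeated three
# times (demo indexing, user uploads, image queries). A small helper like this
# hypothetical one would capture it; the code below keeps the inline version:
#
# def embed_image(img: Image.Image) -> list[float]:
#     tensor = st.session_state.preprocess(img).unsqueeze(0).to(st.session_state.device)
#     with torch.no_grad():
#         return st.session_state.model.encode_image(tensor).cpu().numpy().flatten().tolist()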

# ----- Main App -----
left, right = st.columns([2, 1])

with left:
    st.title("🔍 CLIP-Based Image Search")

    # ----- Load Demo -----
    if st.session_state.dataset_name == "demo" and not st.session_state.dataset_loaded:
        with st.spinner("Downloading and indexing demo images..."):
            # Clear any previously indexed demo entries before re-adding.
            existing = st.session_state.demo_collection.get()["ids"]
            if existing:
                st.session_state.demo_collection.delete(ids=existing)
            demo_image_paths, demo_images = [], []
            for i in range(50):
                path = os.path.join(DEMO_DIR, f"img_{i+1:02}.jpg")
                if not os.path.exists(path):
                    # picsum.photos serves placeholder photos; a fixed seed
                    # makes the image for each index deterministic.
                    url = f"https://picsum.photos/seed/{i}/1024/768"
                    download_image_with_retry(url, path)
                try:
                    demo_images.append(Image.open(path).convert("RGB"))
                    demo_image_paths.append(path)
                except Exception:
                    # Skip images that failed to download or are corrupt.
                    continue
            embeddings, ids, metadatas = [], [], []
            for i, img in enumerate(demo_images):
                img_tensor = st.session_state.preprocess(img).unsqueeze(0).to(st.session_state.device)
                with torch.no_grad():
                    embedding = st.session_state.model.encode_image(img_tensor).cpu().numpy().flatten()
                embeddings.append(embedding.tolist())
                ids.append(str(i))
                metadatas.append({"path": demo_image_paths[i]})
            st.session_state.demo_collection.add(embeddings=embeddings, ids=ids, metadatas=metadatas)
            st.session_state.demo_images = demo_images
            st.session_state.dataset_loaded = True
        st.success("✅ Demo images loaded!")

    # ----- Upload User Images -----
    if st.session_state.dataset_name == "user" and not st.session_state.dataset_loaded:
        uploaded = st.file_uploader("Upload your images", type=["jpg", "jpeg", "png"], accept_multiple_files=True)
        if uploaded:
            # Drop stale entries from any previous upload session.
            existing = st.session_state.user_collection.get()["ids"]
            if existing:
                st.session_state.user_collection.delete(ids=existing)
            user_images = []
            for file in uploaded:
                try:
                    img = Image.open(file).convert("RGB")
                except Exception:
                    continue
                user_images.append(img)
                # The id must match the image's position in user_images so the
                # search results can index back into the list, even when some
                # uploads are skipped as unreadable.
                idx = len(user_images) - 1
                img_tensor = st.session_state.preprocess(img).unsqueeze(0).to(st.session_state.device)
                with torch.no_grad():
                    embedding = st.session_state.model.encode_image(img_tensor).cpu().numpy().flatten()
                st.session_state.user_collection.add(
                    embeddings=[embedding.tolist()], ids=[str(idx)], metadatas=[{"index": idx}]
                )
            st.session_state.user_images = user_images
            st.session_state.dataset_loaded = True
            st.success(f"✅ Uploaded {len(user_images)} images.")
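
    # NOTE: session_state is per-session and cleared on page refresh, while the
    # PersistentClient data survives on disk, so clearing the collection before
    # re-adding (above) keeps the index and the in-memory image list in sync.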

    # ----- Search Section -----
    if st.session_state.dataset_loaded:
        st.subheader("🔍 Search")
        query_type = st.radio("Search by:", ("Text", "Image"))
        query_embedding = None

        if query_type == "Text":
            text_query = st.text_input("Enter your search prompt:")
            if text_query:
                tokens = clip.tokenize([text_query]).to(st.session_state.device)
                with torch.no_grad():
                    query_embedding = st.session_state.model.encode_text(tokens).cpu().numpy().flatten()
        elif query_type == "Image":
            query_file = st.file_uploader("Upload query image", type=["jpg", "jpeg", "png"], key="query_image")
            if query_file:
                query_img = Image.open(query_file).convert("RGB")
                st.image(query_img, caption="Query Image", width=200)
                query_tensor = st.session_state.preprocess(query_img).unsqueeze(0).to(st.session_state.device)
                with torch.no_grad():
                    query_embedding = st.session_state.model.encode_image(query_tensor).cpu().numpy().flatten()
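
        # CLIP embeds text and images into the same vector space, so either
        # kind of query embedding can be matched against the image index with
        # a plain nearest-neighbor lookup.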

        # ----- Perform Search -----
        if query_embedding is not None:
            if st.session_state.dataset_name == "demo":
                collection = st.session_state.demo_collection
                images = st.session_state.demo_images
            else:
                collection = st.session_state.user_collection
                images = st.session_state.user_images
            if collection.count() > 0:
                results = collection.query(
                    query_embeddings=[query_embedding.tolist()],
                    n_results=min(5, collection.count())
                )
                ids = results["ids"][0]
                distances = results["distances"][0]
                # Chroma returns cosine distances; invert to similarities.
                similarities = [1 - d for d in distances]
                st.subheader("🎯 Top Matches")
                # use_container_width replaces the deprecated use_column_width
                # (assumes a recent Streamlit release).
                cols = st.columns(len(ids))
                for col, img_id, sim in zip(cols, ids, similarities):
                    with col:
                        st.image(images[int(img_id)], caption=f"Similarity: {sim:.3f}", use_container_width=True)
            else:
                st.warning("⚠️ No images available for search.")
    else:
        st.info("👈 Choose a dataset from the sidebar to get started.")

# ----- Right Panel: Show Current Dataset Images -----
with right:
    st.subheader("🖼️ Dataset Preview")
    image_list = st.session_state.demo_images if st.session_state.dataset_name == "demo" else st.session_state.user_images
    if st.session_state.dataset_loaded and image_list:
        st.caption(f"Showing {min(len(image_list), 20)} of {len(image_list)} images")
        for img in image_list[:20]:
            st.image(img, use_container_width=True)
    else:
        st.markdown("No images to preview yet.")