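"""Streamlit demo: text- and image-based search over a small image set,
using CLIP embeddings stored in a ChromaDB index.

Assumed dependencies (not pinned in the source): streamlit, torch, pillow,
requests, chromadb, and OpenAI's CLIP (pip install git+https://github.com/openai/CLIP.git).
"""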
import streamlit as st
import torch
import clip
from PIL import Image
import os
import numpy as np
import chromadb
import requests
import tempfile
import time
# ----- Setup -----
st.set_page_config(page_title="CLIP Image Search", layout="wide")
CACHE_DIR = tempfile.gettempdir()
CHROMA_PATH = os.path.join(CACHE_DIR, "chroma_db")
DEMO_DIR = os.path.join(CACHE_DIR, "demo_images")
os.makedirs(DEMO_DIR, exist_ok=True)
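# Everything is cached under the OS temp directory, so model weights, demo
# images, and the Chroma index survive Streamlit reruns but not a host restart.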
# ----- Session State Init -----
if 'dataset_loaded' not in st.session_state:
    st.session_state.dataset_loaded = False
if 'dataset_name' not in st.session_state:
    st.session_state.dataset_name = None
if 'demo_images' not in st.session_state:
    st.session_state.demo_images = []
if 'user_images' not in st.session_state:
    st.session_state.user_images = []
# ----- Load CLIP Model -----
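# Load ViT-B/32 once and keep it in session_state, so Streamlit reruns
# (which re-execute this script top to bottom) don't reload the weights.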
if 'model' not in st.session_state:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, preprocess = clip.load("ViT-B/32", device=device, download_root=CACHE_DIR)
    st.session_state.model = model
    st.session_state.preprocess = preprocess
    st.session_state.device = device
# ----- Initialize ChromaDB -----
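# Separate collections keep demo and user embeddings isolated. With
# hnsw:space set to "cosine", query() returns distances d = 1 - cosine similarity.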
if 'chroma_client' not in st.session_state:
    st.session_state.chroma_client = chromadb.PersistentClient(path=CHROMA_PATH)
    st.session_state.demo_collection = st.session_state.chroma_client.get_or_create_collection(
        name="demo_images", metadata={"hnsw:space": "cosine"}
    )
    st.session_state.user_collection = st.session_state.chroma_client.get_or_create_collection(
        name="user_images", metadata={"hnsw:space": "cosine"}
    )
# ----- Sidebar -----
with st.sidebar:
    st.title("🧠 CLIP Search App")
    st.markdown("Choose a dataset to begin:")
    if st.button("📦 Load Demo Images"):
        st.session_state.dataset_name = "demo"
        st.session_state.dataset_loaded = False
    if st.button("📤 Upload Your Images"):
        st.session_state.dataset_name = "user"
        st.session_state.dataset_loaded = False
# ----- Helper -----
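# Image hosts occasionally time out; retry a few times with a short delay.
# If all attempts fail, the file is never written, and the indexing loop
# below skips it when Image.open raises.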
def download_image_with_retry(url, path, retries=3, delay=1.0):
    for attempt in range(retries):
        try:
            r = requests.get(url, timeout=10)
            if r.status_code == 200:
                with open(path, 'wb') as f:
                    f.write(r.content)
                return True
        except Exception:
            time.sleep(delay)
    return False
# ----- Main App -----
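# Two-column layout: the left column hosts loading and search, the right
# column previews whichever dataset is currently active.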
left, right = st.columns([2, 1])
with left:
    st.title("🔍 CLIP-Based Image Search")
    # ----- Load Demo -----
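    # Fetch 50 seeded placeholder photos from picsum.photos (stable per seed),
    # embed each with CLIP's image encoder, and index the embeddings in Chroma.
    # The delete() call clears any stale entries from a previous run.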
if st.session_state.dataset_name == "demo" and not st.session_state.dataset_loaded:
with st.spinner("Downloading and indexing demo images..."):
st.session_state.demo_collection.delete(ids=[str(i) for i in range(50)])
demo_image_paths, demo_images = [], []
for i in range(50):
path = os.path.join(DEMO_DIR, f"img_{i+1:02}.jpg")
if not os.path.exists(path):
url = f"https://picsum.photos/seed/{i}/1024/768"
download_image_with_retry(url, path)
try:
demo_images.append(Image.open(path).convert("RGB"))
demo_image_paths.append(path)
except:
continue
embeddings, ids, metadatas = [], [], []
for i, img in enumerate(demo_images):
img_tensor = st.session_state.preprocess(img).unsqueeze(0).to(st.session_state.device)
with torch.no_grad():
embedding = st.session_state.model.encode_image(img_tensor).cpu().numpy().flatten()
embeddings.append(embedding)
ids.append(str(i))
metadatas.append({"path": demo_image_paths[i]})
st.session_state.demo_collection.add(embeddings=embeddings, ids=ids, metadatas=metadatas)
st.session_state.demo_images = demo_images
st.session_state.dataset_loaded = True
st.success("β
Demo images loaded!")
    # ----- Upload User Images -----
    if st.session_state.dataset_name == "user" and not st.session_state.dataset_loaded:
        uploaded = st.file_uploader("Upload your images", type=["jpg", "jpeg", "png"], accept_multiple_files=True)
        if uploaded:
            st.session_state.user_collection.delete(ids=[
                str(i) for i in range(st.session_state.user_collection.count())
            ])
            user_images = []
            for file in uploaded:
                try:
                    img = Image.open(file).convert("RGB")
                except Exception:
                    continue
                user_images.append(img)
                img_tensor = st.session_state.preprocess(img).unsqueeze(0).to(st.session_state.device)
                with torch.no_grad():
                    embedding = st.session_state.model.encode_image(img_tensor).cpu().numpy().flatten()
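                # ID = position in user_images (not the raw upload index), so that
                # search results, which index into user_images, stay aligned even
                # when an unreadable file is skipped above.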
                idx = len(user_images) - 1
                st.session_state.user_collection.add(
                    embeddings=[embedding], ids=[str(idx)], metadatas=[{"index": idx}]
                )
            st.session_state.user_images = user_images
            st.session_state.dataset_loaded = True
            st.success(f"✅ Uploaded {len(user_images)} images.")
    # ----- Search Section -----
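    # CLIP embeds text and images into the same vector space, so a text prompt
    # or a query image can both be matched against the stored image embeddings.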
    if st.session_state.dataset_loaded:
        st.subheader("🔍 Search")
        query_type = st.radio("Search by:", ("Text", "Image"))
        query_embedding = None
        if query_type == "Text":
            text_query = st.text_input("Enter your search prompt:")
            if text_query:
                tokens = clip.tokenize([text_query]).to(st.session_state.device)
                with torch.no_grad():
                    query_embedding = st.session_state.model.encode_text(tokens).cpu().numpy().flatten()
        elif query_type == "Image":
            query_file = st.file_uploader("Upload query image", type=["jpg", "jpeg", "png"], key="query_image")
            if query_file:
                query_img = Image.open(query_file).convert("RGB")
                st.image(query_img, caption="Query Image", width=200)
                query_tensor = st.session_state.preprocess(query_img).unsqueeze(0).to(st.session_state.device)
                with torch.no_grad():
                    query_embedding = st.session_state.model.encode_image(query_tensor).cpu().numpy().flatten()
        # ----- Perform Search -----
        if query_embedding is not None:
            if st.session_state.dataset_name == "demo":
                collection = st.session_state.demo_collection
                images = st.session_state.demo_images
            else:
                collection = st.session_state.user_collection
                images = st.session_state.user_images
            if collection.count() > 0:
                results = collection.query(
                    query_embeddings=[query_embedding],
                    n_results=min(5, collection.count())
                )
                ids = results["ids"][0]
                distances = results["distances"][0]
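                # Convert Chroma's cosine distance (d = 1 - cos_sim) back to a similarity score.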
                similarities = [1 - d for d in distances]
                st.subheader("🎯 Top Matches")
                cols = st.columns(len(ids))
                for i, (img_id, sim) in enumerate(zip(ids, similarities)):
                    with cols[i]:
                        st.image(images[int(img_id)], caption=f"Similarity: {sim:.3f}", use_column_width=True)
            else:
                st.warning("⚠️ No images available for search.")
    else:
        st.info("👈 Choose a dataset from the sidebar to get started.")
# ----- Right Panel: Show Current Dataset Images -----
with right:
    st.subheader("🖼️ Dataset Preview")
    image_list = st.session_state.demo_images if st.session_state.dataset_name == "demo" else st.session_state.user_images
    if st.session_state.dataset_loaded and image_list:
        st.caption(f"Showing {len(image_list)} images")
        for img in image_list[:20]:
            st.image(img, use_column_width=True)
    else:
        st.markdown("No images to preview yet.")
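# To run locally (assuming this file is saved as app.py):
#   streamlit run app.py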