Spaces:
Sleeping
Sleeping
Update src/streamlit_app.py
Browse files- src/streamlit_app.py +9 -1
src/streamlit_app.py
CHANGED
|
@@ -68,16 +68,24 @@ def show_img(img_id, score=None, col=None):
|
|
| 68 |
st.error(f"Error processing image: {e}")
|
| 69 |
|
| 70 |
def search_faiss(model, processor, index, id_map, prompt, top_k=5, device='cpu'):
|
|
|
|
| 71 |
inputs = processor(text=[prompt], return_tensors='pt', padding=True).to(device)
|
|
|
|
|
|
|
| 72 |
with torch.no_grad():
|
| 73 |
txt_emb = model.get_text_features(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
|
| 74 |
txt_emb = txt_emb / txt_emb.norm(p=2, dim=-1, keepdim=True)
|
|
|
|
|
|
|
| 75 |
q = txt_emb.cpu().numpy().astype('float32')
|
| 76 |
|
| 77 |
D, I = index.search(q, top_k)
|
|
|
|
|
|
|
| 78 |
return [(id_map[i], float(D[0][j])) for j, i in enumerate(I[0])]
|
| 79 |
|
| 80 |
def running(prompt, top_k=5):
|
|
|
|
| 81 |
results = search_faiss(
|
| 82 |
model, processor,
|
| 83 |
index, id_map,
|
|
@@ -175,5 +183,5 @@ st.subheader("Example Usage: Enter a Prompt to Retrieve Related Images")
|
|
| 175 |
prompt_input = st.text_input("Enter a prompt", "a red Apparel dress")
|
| 176 |
top_k_input = st.number_input("Enter the number of results (top_k)", min_value=1, max_value=10, value=5)
|
| 177 |
|
| 178 |
-
if st.button("Find top
|
| 179 |
running(prompt_input, top_k_input)
|
|
|
|
| 68 |
st.error(f"Error processing image: {e}")
|
| 69 |
|
| 70 |
def search_faiss(model, processor, index, id_map, prompt, top_k=5, device='cpu'):
|
| 71 |
+
st.write(f"Running FAISS search for prompt: '{prompt}' with top_k={top_k}")
|
| 72 |
inputs = processor(text=[prompt], return_tensors='pt', padding=True).to(device)
|
| 73 |
+
st.write("Prompt processed by tokenizer.")
|
| 74 |
+
|
| 75 |
with torch.no_grad():
|
| 76 |
txt_emb = model.get_text_features(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
|
| 77 |
txt_emb = txt_emb / txt_emb.norm(p=2, dim=-1, keepdim=True)
|
| 78 |
+
st.write("Text embedding computed.")
|
| 79 |
+
|
| 80 |
q = txt_emb.cpu().numpy().astype('float32')
|
| 81 |
|
| 82 |
D, I = index.search(q, top_k)
|
| 83 |
+
|
| 84 |
+
st.write("FAISS search completed.")
|
| 85 |
return [(id_map[i], float(D[0][j])) for j, i in enumerate(I[0])]
|
| 86 |
|
| 87 |
def running(prompt, top_k=5):
|
| 88 |
+
st.write("Starting image retrieval...")
|
| 89 |
results = search_faiss(
|
| 90 |
model, processor,
|
| 91 |
index, id_map,
|
|
|
|
| 183 |
prompt_input = st.text_input("Enter a prompt", "a red Apparel dress")
|
| 184 |
top_k_input = st.number_input("Enter the number of results (top_k)", min_value=1, max_value=10, value=5)
|
| 185 |
|
| 186 |
+
if st.button(f"Find top {top_k_input} related images"):
|
| 187 |
running(prompt_input, top_k_input)
|