Spaces:
Sleeping
Sleeping
Update src/streamlit_app.py
Browse files- src/streamlit_app.py +16 -9
src/streamlit_app.py
CHANGED
|
@@ -137,13 +137,13 @@ model_dir = snapshot_download(
|
|
| 137 |
local_dir_use_symlinks=False # Tránh tạo symlink vào /.cache
|
| 138 |
)
|
| 139 |
|
| 140 |
-
print("Model directory:", model_dir)
|
| 141 |
-
|
| 142 |
# Load model using the local path + token
|
| 143 |
model = CLIPModel.from_pretrained(
|
| 144 |
model_dir,
|
| 145 |
-
use_auth_token=hf_token
|
| 146 |
-
|
|
|
|
|
|
|
| 147 |
|
| 148 |
|
| 149 |
index_path = os.path.join(model_dir, "faiss_index.bin")
|
|
@@ -190,8 +190,15 @@ for idx, img_id in enumerate(example):
|
|
| 190 |
|
| 191 |
# Chạy ví dụ với prompt
|
| 192 |
st.subheader("Example usage: enter a prompt to retrieve related images")
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 137 |
local_dir_use_symlinks=False # Tránh tạo symlink vào /.cache
|
| 138 |
)
|
| 139 |
|
|
|
|
|
|
|
| 140 |
# Load model using the local path + token
|
| 141 |
model = CLIPModel.from_pretrained(
|
| 142 |
model_dir,
|
| 143 |
+
use_auth_token=hf_token,
|
| 144 |
+
device_map="auto", # Tự động phân phối weights lên CPU/GPU
|
| 145 |
+
low_cpu_mem_usage=True, # Giảm RAM khi load
|
| 146 |
+
).eval()
|
| 147 |
|
| 148 |
|
| 149 |
index_path = os.path.join(model_dir, "faiss_index.bin")
|
|
|
|
| 190 |
|
| 191 |
# Chạy ví dụ với prompt
|
| 192 |
st.subheader("Example usage: enter a prompt to retrieve related images")
|
| 193 |
+
with st.form(key="retrieval_form"):
|
| 194 |
+
prompt_input = st.text_input("Enter a prompt", placeholder="e.g., a red Apparel dress")
|
| 195 |
+
top_k_input = st.number_input("Enter the number of results (top_k)", min_value=1, max_value=10, value=5, step=1)
|
| 196 |
+
|
| 197 |
+
submitted = st.form_submit_button(label="Find Related Images")
|
| 198 |
+
|
| 199 |
+
# Khi người dùng nhấn nút Submit
|
| 200 |
+
if submitted:
|
| 201 |
+
if prompt_input.strip() and top_k_input > 0:
|
| 202 |
+
running(prompt_input, top_k_input)
|
| 203 |
+
else:
|
| 204 |
+
st.warning("Please enter a valid prompt and top_k.")
|