roy214 commited on
Commit
391d821
·
verified ·
1 Parent(s): 3afd1d7

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +16 -9
src/streamlit_app.py CHANGED
@@ -137,13 +137,13 @@ model_dir = snapshot_download(
137
  local_dir_use_symlinks=False # Tránh tạo symlink vào /.cache
138
  )
139
 
140
- print("Model directory:", model_dir)
141
-
142
  # Load model using the local path + token
143
  model = CLIPModel.from_pretrained(
144
  model_dir,
145
- use_auth_token=hf_token
146
- ).to("cpu").eval()
 
 
147
 
148
 
149
  index_path = os.path.join(model_dir, "faiss_index.bin")
@@ -190,8 +190,15 @@ for idx, img_id in enumerate(example):
190
 
191
  # Chạy ví dụ với prompt
192
  st.subheader("Example usage: enter a prompt to retrieve related images")
193
- prompt_input = st.text_input("Enter a prompt", "a red Apparel dress")
194
- top_k_input = st.number_input("Enter the number of results (top_k)", min_value=1, max_value=10, value=5)
195
-
196
- if st.button(f"Find top {top_k_input} related images"):
197
- running(prompt_input, top_k_input)
 
 
 
 
 
 
 
 
137
  local_dir_use_symlinks=False # Tránh tạo symlink vào /.cache
138
  )
139
 
 
 
140
  # Load model using the local path + token
141
  model = CLIPModel.from_pretrained(
142
  model_dir,
143
+ use_auth_token=hf_token,
144
+ device_map="auto", # Tự động phân phối weights lên CPU/GPU
145
+ low_cpu_mem_usage=True, # Giảm RAM khi load
146
+ ).eval()
147
 
148
 
149
  index_path = os.path.join(model_dir, "faiss_index.bin")
 
190
 
191
  # Chạy ví dụ với prompt
192
  st.subheader("Example usage: enter a prompt to retrieve related images")
193
+ with st.form(key="retrieval_form"):
194
+ prompt_input = st.text_input("Enter a prompt", placeholder="e.g., a red Apparel dress")
195
+ top_k_input = st.number_input("Enter the number of results (top_k)", min_value=1, max_value=10, value=5, step=1)
196
+
197
+ submitted = st.form_submit_button(label="Find Related Images")
198
+
199
+ # Khi người dùng nhấn nút Submit
200
+ if submitted:
201
+ if prompt_input.strip() and top_k_input > 0:
202
+ running(prompt_input, top_k_input)
203
+ else:
204
+ st.warning("Please enter a valid prompt and top_k.")