roy214 committed on
Commit
e399303
·
verified ·
1 Parent(s): 63b530b

Upload 2 files

Browse files
Files changed (2) hide show
  1. src/streamlit_app.py +204 -0
  2. src/styles.csv +0 -0
src/streamlit_app.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import boto3
3
+ from botocore.exceptions import NoCredentialsError
4
+ from io import BytesIO
5
+ from PIL import Image
6
+ import pandas as pd
7
+ import matplotlib.pyplot as plt
8
+ import os
9
+ import faiss
10
+ import pickle
11
+ import torch
12
+ from transformers import CLIPModel, CLIPProcessor
13
+ from huggingface_hub import hf_hub_download, snapshot_download
14
+ import json
15
+ import requests
16
+
17
+ # Khởi tạo client S3 với thông tin cấu hình từ secrets
18
+ s3 = boto3.client('s3')
19
+
20
+
21
def get_image_from_s3(bucket_name, img_id):
    """Build the HTTPS URL of an image object in an S3 bucket.

    No network request is made here; the URL is only assembled and points
    at the object key ``<img_id>.jpg`` on the bucket's ``s3.amazonaws.com``
    host.

    Args:
        bucket_name: Name of the S3 bucket.
        img_id: Image identifier, used as the object key without ``.jpg``.

    Returns:
        The URL string, or ``None`` if formatting the URL raised (the
        error is then reported through ``st.error``).
    """
    try:
        return "https://{}.s3.amazonaws.com/{}.jpg".format(bucket_name, img_id)
    except Exception as exc:
        st.error(f"Error constructing image URL: {exc}")
        return None
29
+
30
def show_img(img_id, score=None, col=None):
    """Download one product image from S3 and render it with a caption.

    The caption is assembled from the row of the module-level ``style``
    DataFrame whose ``id`` equals ``int(img_id)``; the image is fetched
    from the module-level ``bucket_name`` bucket. Failures are reported
    in the app with ``st.error`` instead of being raised.

    Args:
        img_id: Image/product id; must be convertible with ``int()``.
        score: Optional similarity score appended to the caption. A score
            of ``0`` is shown; only ``None`` suppresses it.
        col: Streamlit container (for example one column) to draw into.
            When omitted the image is drawn on the main page.
    """
    img_url = get_image_from_s3(bucket_name, img_id)
    if not img_url:
        return

    try:
        # A timeout keeps a stalled S3 request from blocking the script
        # run indefinitely; a timeout is reported like any request error.
        response = requests.get(img_url, timeout=10)
        response.raise_for_status()

        # Decode the image straight from the downloaded bytes.
        img = Image.open(BytesIO(response.content))

        # Look up the metadata row for this image id.
        img_style = style[style['id'] == int(img_id)]
        if img_style.empty:
            st.write("img_style is empty")
            return

        caption_fields = (
            'gender',
            'masterCategory',
            'subCategory',
            'articleType',
            'baseColour',
            'year',
            'usage',
            'productDisplayName',
        )
        parts = [str(img_style[field].values[0]) for field in caption_fields]

        text = '- '.join(parts)
        # Compare against None so that a legitimate score of 0.0 is shown.
        if score is not None:
            text += f'\n\n Score: {score:.2f}'

        # Draw into the given container, or the main page if none was given.
        target = col if col is not None else st
        target.image(img, caption=text, use_container_width=True)

    except requests.exceptions.RequestException as e:
        st.error(f"Error fetching image: {e}")
    except Exception as e:
        st.error(f"Error processing image: {e}")
71
+
72
def search_faiss(model, processor, index, id_map, prompt, top_k=5, device='cpu'):
    """Embed a text prompt with CLIP and look up its nearest images in FAISS.

    Progress messages are written to the Streamlit page along the way.

    Args:
        model: Model exposing ``get_text_features`` (a CLIP model here).
        processor: Callable that tokenizes ``text=[prompt]`` into tensors.
        index: FAISS index that is queried with the text embedding.
        id_map: Mapping (or sequence) from FAISS row index to image id.
        prompt: Query text.
        top_k: Maximum number of neighbours to request.
        device: Device the tokenized inputs are moved to.

    Returns:
        A list of ``(image_id, score)`` tuples in the order FAISS returned
        them. It may hold fewer than ``top_k`` entries. ``score`` is the
        raw value FAISS reports; whether higher or lower is better depends
        on the index metric, which is not visible here.
    """
    st.write(f"Running FAISS search for prompt: '{prompt}' with top_k={top_k}")
    inputs = processor(text=[prompt], return_tensors='pt', padding=True).to(device)
    st.write("Prompt processed by tokenizer.")

    with torch.no_grad():
        txt_emb = model.get_text_features(
            input_ids=inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
        )
        # Scale the embedding to unit L2 norm before searching.
        txt_emb = txt_emb / txt_emb.norm(p=2, dim=-1, keepdim=True)
    st.write("Text embedding computed.")

    q = txt_emb.cpu().numpy().astype('float32')

    D, I = index.search(q, int(top_k))

    st.write("FAISS search completed.")

    # FAISS pads the result with label -1 when it finds fewer than top_k
    # neighbours; skip those so id_map is never indexed with -1.
    return [
        (id_map[idx], float(dist))
        for dist, idx in zip(D[0], I[0])
        if idx >= 0
    ]
91
+
92
def running(prompt, top_k=5):
    """Run a text query and lay the retrieved images out in five columns.

    Uses the module-level ``model``, ``processor``, ``index`` and
    ``id_map``. Results fill the columns left to right and wrap back to
    the first column after every fifth image. A warning is shown when the
    search returns nothing.

    Args:
        prompt: Query text forwarded to ``search_faiss``.
        top_k: Number of results requested from ``search_faiss``.
    """
    st.write("Starting image retrieval...")
    results = search_faiss(model, processor, index, id_map, prompt=prompt, top_k=top_k)

    # Five images per row: the column is the result position modulo 5.
    n_cols = 5
    grid = st.columns(n_cols)
    for position, (img_id, score) in enumerate(results):
        show_img(img_id, score, col=grid[position % n_cols])

    if not results:
        st.warning("No results were returned from FAISS. Check your prompt or embedding.")
113
+
114
+
115
# Load the product metadata table from styles.csv, which sits next to this
# script.
current_dir = os.path.dirname(__file__)
csv_path = os.path.join(current_dir, 'styles.csv')


@st.cache_data
def _load_styles(path):
    """Read the first 10 columns of the styles CSV.

    Cached so the file is not re-parsed on every Streamlit rerun.
    """
    return pd.read_csv(path, usecols=range(10))


style = _load_styles(csv_path)
bucket_name = "image-text-retrieval"  # S3 bucket that holds the images
your_username = "roy214"

# Point the Hugging Face caches at a writable directory.
# NOTE(review): huggingface_hub and transformers are imported at the top of
# the file, before these assignments. Those libraries appear to read their
# cache locations at import time, so these settings may have no effect.
# Confirm, and if so move them above the imports.
os.environ["HF_HOME"] = "/tmp/huggingface"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface"
os.environ["HF_HUB_CACHE"] = "/tmp/huggingface"

hf_token = os.environ["HUGGINGFACE_TOKEN"]
model_repo = f"{your_username}/clip-finetuned-fashion"


@st.cache_resource
def _load_retrieval_assets(repo_id, token):
    """Download the model repo and load everything needed for retrieval.

    Streamlit re-executes this script on every interaction; caching keeps
    the download, the model load and the index read from being repeated
    on each rerun.

    Args:
        repo_id: Hugging Face repo containing the model, ``faiss_index.bin``
            and ``id_map.json``.
        token: Hugging Face access token.

    Returns:
        Tuple ``(model_dir, model, processor, index, id_map)``.

    Raises:
        FileNotFoundError: If the FAISS index or the id map is missing from
            the downloaded repo.
    """
    # Download the whole repo into /tmp, copying files instead of
    # symlinking into the cache directory.
    model_dir = snapshot_download(
        repo_id=repo_id,
        token=token,
        local_dir="/tmp/model",
        local_dir_use_symlinks=False,
    )

    index_path = os.path.join(model_dir, "faiss_index.bin")
    mapping_path = os.path.join(model_dir, "id_map.json")

    # Explicit checks instead of `assert`, which is stripped under -O.
    # Done before the model load so a broken repo fails fast.
    for required in (index_path, mapping_path):
        if not os.path.isfile(required):
            raise FileNotFoundError(f"Không tìm thấy {required}")

    # Load the model from the local snapshot.
    model = CLIPModel.from_pretrained(
        model_dir,
        use_auth_token=token,
        device_map="auto",  # let the library place weights on CPU/GPU
        low_cpu_mem_usage=True,  # reduce RAM usage while loading
    ).eval()

    # The processor is loaded from the same local directory.
    processor = CLIPProcessor.from_pretrained(
        model_dir,
        use_auth_token=token,
    )

    index = faiss.read_index(index_path)

    # SECURITY NOTE: pickle.load can execute arbitrary code from the file it
    # reads, so this is only safe while the repo contents are trusted.
    # NOTE(review): the file is named .json but is read with pickle. Confirm
    # the on-disk format before switching to json.load.
    with open(mapping_path, "rb") as f:
        id_map = pickle.load(f)

    return model_dir, model, processor, index, id_map


model_dir, model, processor, index, id_map = _load_retrieval_assets(model_repo, hf_token)
index_path = os.path.join(model_dir, "faiss_index.bin")
mapping_path = os.path.join(model_dir, "id_map.json")
167
+
168
# ---- Page header and project overview ----
st.title("Fashion Product Image Retrieval")

overview_md = """
### **Overview**

In this project, I demonstrate an **Image Retrieval** system for fashion products. The system uses a fine-tuned **CLIP model** (`clip-vit-base-patch32`) to match images with relevant text descriptions. We have a dataset of **1000 fashion product images**, stored on **Amazon S3**. Each image is associated with detailed product descriptions, such as **product type**, **color**, **category**, and **brand**.

The goal of this system is to retrieve the most relevant fashion images based on a given text prompt (e.g., "red dress") and vice versa. With this system, users can search for fashion products in a more intuitive, text-based manner.

#### Key Features:
- **Dataset**: 1000 fashion product images with descriptive text.
- **Storage**: Images are stored on **Amazon S3**.
- **Model**: Fine-tuned **OpenAI CLIP model** (`clip-vit-base-patch32`) on the dataset.
- **Objective**: Given a prompt like "red dress", the system retrieves the most relevant images.
"""
st.markdown(overview_md)

# ---- Preview row: a fixed set of image ids, one per column ----
st.subheader("Some sample images and their captions:")
example = [13422, 10037, 38246, 23273, 2008]
example_cols = st.columns(5)
for img_id, sample_col in zip(example, example_cols):
    show_img(img_id, None, sample_col)

# ---- Query form: a prompt and the number of results to retrieve ----
st.subheader("Example usage: enter a prompt to retrieve related images")
with st.form(key="retrieval_form"):
    prompt_input = st.text_input("Enter a prompt", placeholder="e.g., a red Apparel dress")
    top_k_input = st.number_input("Enter the number of results (top_k)", min_value=1, max_value=10, value=5, step=1)

    submitted = st.form_submit_button(label="Find Related Images")

# Run the search only after the submit button was pressed; reject a blank
# prompt or a non-positive top_k with a warning.
if submitted:
    request_is_valid = bool(prompt_input.strip()) and top_k_input > 0
    if not request_is_valid:
        st.warning("Please enter a valid prompt and top_k.")
    else:
        running(prompt_input, top_k_input)
src/styles.csv ADDED
The diff for this file is too large to render. See raw diff