leedoming JoJosmin commited on
Commit
717d9b3
1 Parent(s): f388a49

Update app.py (#4)

Browse files

- Update app.py (4be1017e99d1bb63b66cf583fb47ef5a2cbdcda3)


Co-authored-by: Jo sungmin <JoJosmin@users.noreply.huggingface.co>

Files changed (1) hide show
  1. app.py +36 -37
app.py CHANGED
@@ -7,9 +7,9 @@ from io import BytesIO
7
  import time
8
  import json
9
  import numpy as np
 
10
  import cv2
11
  import chromadb
12
- from ultralytics import YOLO
13
 
14
  # Load CLIP model and tokenizer
15
  @st.cache_resource
@@ -22,16 +22,13 @@ def load_clip_model():
22
 
23
  clip_model, preprocess_val, tokenizer, device = load_clip_model()
24
 
25
- # Load YOLOS model
26
  @st.cache_resource
27
  def load_yolo_model():
28
  return YOLO("./best.pt")
29
 
30
  yolo_model = load_yolo_model()
31
 
32
- # Define the categories
33
- #CATS = ['shirt, blouse', 'top, t-shirt, sweatshirt', 'sweater', 'cardigan', 'jacket', 'vest', 'pants', 'shorts', 'skirt', 'coat', 'dress', 'jumpsuit', 'cape', 'glasses', 'hat', 'glove', 'shoe', 'bag', 'wallet', 'umbrella', 'hood', 'collar', 'lapel', 'epaulette', 'applique', 'bead', 'bow', 'flower', 'fringe', 'ribbon', 'rivet', 'ruffle', 'sequin', 'tassel']
34
-
35
  # Helper functions
36
  def load_image_from_url(url, max_retries=3):
37
  for attempt in range(max_retries):
@@ -45,10 +42,9 @@ def load_image_from_url(url, max_retries=3):
45
  time.sleep(1)
46
  else:
47
  return None
48
-
49
  #Load chromaDB
50
  client = chromadb.PersistentClient(path="./clothesDB")
51
- collection = client.get_collection(name="fashion_items_ver2")
52
 
53
  def get_image_embedding(image):
54
  image_tensor = preprocess_val(image).unsqueeze(0).to(device)
@@ -89,14 +85,11 @@ def find_similar_images(query_embedding, collection, top_k=5):
89
  })
90
  return results
91
 
 
 
92
  def detect_clothing(image):
93
- #inputs = yolos_processor(images=image, return_tensors="pt")
94
- #outputs = yolos_model(**inputs)
95
-
96
- #target_sizes = torch.tensor([image.size[::-1]])
97
  results = yolo_model(image)
98
  detections = results[0].boxes.data.cpu().numpy()
99
-
100
  categories = []
101
  for detection in detections:
102
  x1, y1, x2, y2, conf, cls = detection
@@ -112,10 +105,7 @@ def detect_clothing(image):
112
  def crop_image(image, bbox):
113
  return image.crop((bbox[0], bbox[1], bbox[2], bbox[3]))
114
 
115
- # Streamlit app
116
- st.title("Advanced Fashion Search App")
117
-
118
- # Initialize session state
119
  if 'step' not in st.session_state:
120
  st.session_state.step = 'input'
121
  if 'query_image_url' not in st.session_state:
@@ -125,7 +115,10 @@ if 'detections' not in st.session_state:
125
  if 'selected_category' not in st.session_state:
126
  st.session_state.selected_category = None
127
 
128
- # Step-by-step processing
 
 
 
129
  if st.session_state.step == 'input':
130
  st.session_state.query_image_url = st.text_input("Enter image URL:", st.session_state.query_image_url)
131
  if st.button("Detect Clothing"):
@@ -143,6 +136,7 @@ if st.session_state.step == 'input':
143
  else:
144
  st.warning("Please enter an image URL.")
145
 
 
146
  elif st.session_state.step == 'select_category':
147
  st.image(st.session_state.query_image, caption="Query Image", use_column_width=True)
148
  st.subheader("Detected Clothing Items:")
@@ -179,7 +173,9 @@ elif st.session_state.step == 'show_results':
179
  with col2:
180
  st.write(f"Name: {img['info']['name']}")
181
  st.write(f"Brand: {img['info']['brand']}")
182
- st.write(f"Category: {img['info']['category']}")
 
 
183
  st.write(f"Price: {img['info']['price']}")
184
  st.write(f"Discount: {img['info']['discount']}%")
185
  st.write(f"Similarity: {img['similarity']:.2f}")
@@ -190,22 +186,25 @@ elif st.session_state.step == 'show_results':
190
  st.session_state.detections = []
191
  st.session_state.selected_category = None
192
 
193
- # Text search
194
- st.sidebar.title("Text Search")
195
- query_text = st.sidebar.text_input("Enter search text:")
196
- if st.sidebar.button("Search by Text"):
197
- if query_text:
198
- text_embedding = get_text_embedding(query_text)
199
- similar_images = find_similar_images(text_embedding, collection)
200
- st.sidebar.subheader("Similar Items:")
201
- for img in similar_images:
202
- st.sidebar.image(img['info']['image_url'], use_column_width=True)
203
- st.sidebar.write(f"Name: {img['info']['name']}")
204
- st.sidebar.write(f"Brand: {img['info']['brand']}")
205
- st.sidebar.write(f"Category: {img['info']['category']}")
206
- st.sidebar.write(f"Price: {img['info']['price']}")
207
- st.sidebar.write(f"Discount: {img['info']['discount']}%")
208
- st.sidebar.write(f"Similarity: {img['similarity']:.2f}")
209
- st.sidebar.write("---")
210
- else:
211
- st.sidebar.warning("Please enter a search text.")
 
 
 
 
7
  import time
8
  import json
9
  import numpy as np
10
+ from ultralytics import YOLO
11
  import cv2
12
  import chromadb
 
13
 
14
  # Load CLIP model and tokenizer
15
  @st.cache_resource
 
22
 
23
  clip_model, preprocess_val, tokenizer, device = load_clip_model()
24
 
25
+ # Load YOLOv8 model
26
  @st.cache_resource
27
  def load_yolo_model():
28
  return YOLO("./best.pt")
29
 
30
  yolo_model = load_yolo_model()
31
 
 
 
 
32
  # Helper functions
33
  def load_image_from_url(url, max_retries=3):
34
  for attempt in range(max_retries):
 
42
  time.sleep(1)
43
  else:
44
  return None
 
45
  #Load chromaDB
46
  client = chromadb.PersistentClient(path="./clothesDB")
47
+ collection = client.get_collection(name="clothes_items_ver3")
48
 
49
  def get_image_embedding(image):
50
  image_tensor = preprocess_val(image).unsqueeze(0).to(device)
 
85
  })
86
  return results
87
 
88
+
89
+
90
  def detect_clothing(image):
 
 
 
 
91
  results = yolo_model(image)
92
  detections = results[0].boxes.data.cpu().numpy()
 
93
  categories = []
94
  for detection in detections:
95
  x1, y1, x2, y2, conf, cls = detection
 
105
  def crop_image(image, bbox):
106
  return image.crop((bbox[0], bbox[1], bbox[2], bbox[3]))
107
 
108
+ # 세션 상태 초기화
 
 
 
109
  if 'step' not in st.session_state:
110
  st.session_state.step = 'input'
111
  if 'query_image_url' not in st.session_state:
 
115
  if 'selected_category' not in st.session_state:
116
  st.session_state.selected_category = None
117
 
118
+ # Streamlit app
119
+ st.title("Advanced Fashion Search App")
120
+
121
+ # 단계별 처리
122
  if st.session_state.step == 'input':
123
  st.session_state.query_image_url = st.text_input("Enter image URL:", st.session_state.query_image_url)
124
  if st.button("Detect Clothing"):
 
136
  else:
137
  st.warning("Please enter an image URL.")
138
 
139
+ # Update the 'select_category' step
140
  elif st.session_state.step == 'select_category':
141
  st.image(st.session_state.query_image, caption="Query Image", use_column_width=True)
142
  st.subheader("Detected Clothing Items:")
 
173
  with col2:
174
  st.write(f"Name: {img['info']['name']}")
175
  st.write(f"Brand: {img['info']['brand']}")
176
+ category = img['info'].get('category')
177
+ if category:
178
+ st.write(f"Category: {category}")
179
  st.write(f"Price: {img['info']['price']}")
180
  st.write(f"Discount: {img['info']['discount']}%")
181
  st.write(f"Similarity: {img['similarity']:.2f}")
 
186
  st.session_state.detections = []
187
  st.session_state.selected_category = None
188
 
189
+ else: # Text search
190
+ query_text = st.text_input("Enter search text:")
191
+ if st.button("Search by Text"):
192
+ if query_text:
193
+ text_embedding = get_text_embedding(query_text)
194
+ similar_images = find_similar_images(text_embedding, collection)
195
+ st.subheader("Similar Items:")
196
+ for img in similar_images:
197
+ col1, col2 = st.columns(2)
198
+ with col1:
199
+ st.image(img['info']['image_url'], use_column_width=True)
200
+ with col2:
201
+ st.write(f"Name: {img['info']['name']}")
202
+ st.write(f"Brand: {img['info']['brand']}")
203
+ category = img['info'].get('category')
204
+ if category:
205
+ st.write(f"Category: {category}")
206
+ st.write(f"Price: {img['info']['price']}")
207
+ st.write(f"Discount: {img['info']['discount']}%")
208
+ st.write(f"Similarity: {img['similarity']:.2f}")
209
+ else:
210
+ st.warning("Please enter a search text.")