leedoming commited on
Commit
4531600
1 Parent(s): 4f6b336

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +196 -0
app.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import open_clip
3
+ import torch
4
+ import requests
5
+ from PIL import Image
6
+ from io import BytesIO
7
+ import time
8
+ import json
9
+ import numpy as np
10
+ import cv2
11
+ from inference_sdk import InferenceHTTPClient
12
+ import matplotlib.pyplot as plt
13
+ import base64
14
+
15
+ # Load model and tokenizer
16
+ @st.cache_resource
17
+ def load_model():
18
+ model, preprocess_val, tokenizer = open_clip.create_model_and_transforms('hf-hub:Marqo/marqo-fashionSigLIP')
19
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
+ model.to(device)
21
+ return model, preprocess_val, tokenizer, device
22
+
23
+ model, preprocess_val, tokenizer, device = load_model()
24
+
25
+ # Load and process data
26
+ @st.cache_data
27
+ def load_data():
28
+ with open('musinsa-final.json', 'r', encoding='utf-8') as f:
29
+ return json.load(f)
30
+
31
+ data = load_data()
32
+
33
+ # Helper functions
34
+ @st.cache_data
35
+ def download_and_process_image(image_url):
36
+ try:
37
+ response = requests.get(image_url)
38
+ response.raise_for_status() # Raises an HTTPError for bad responses
39
+ image = Image.open(BytesIO(response.content))
40
+
41
+ # Convert image to RGB mode if it's in RGBA mode
42
+ if image.mode == 'RGBA':
43
+ image = image.convert('RGB')
44
+
45
+ return image
46
+ except requests.RequestException as e:
47
+ st.error(f"Error downloading image: {e}")
48
+ return None
49
+ except Exception as e:
50
+ st.error(f"Error processing image: {e}")
51
+ return None
52
+
53
+ def get_image_embedding(image):
54
+ image_tensor = preprocess_val(image).unsqueeze(0).to(device)
55
+ with torch.no_grad():
56
+ image_features = model.encode_image(image_tensor)
57
+ image_features /= image_features.norm(dim=-1, keepdim=True)
58
+ return image_features.cpu().numpy()
59
+
60
+ def setup_roboflow_client(api_key):
61
+ return InferenceHTTPClient(
62
+ api_url="https://outline.roboflow.com",
63
+ api_key=api_key
64
+ )
65
+
66
+ def segment_image(image_path, client):
67
+ try:
68
+ # 이미지 파일 읽기
69
+ with open(image_path, "rb") as image_file:
70
+ image_data = image_file.read()
71
+
72
+ # 이미지를 base64로 인코딩
73
+ encoded_image = base64.b64encode(image_data).decode('utf-8')
74
+
75
+ # 원본 이미지 로드
76
+ image = cv2.imread(image_path)
77
+ image = cv2.resize(image, (800, 600))
78
+ mask = np.zeros(image.shape, dtype=np.uint8)
79
+
80
+ # Roboflow API 호출
81
+ results = client.infer(encoded_image, model_id="closet/1")
82
+
83
+ # 결과가 이미 딕셔너리인 경우 JSON 파싱 단계 제거
84
+ if isinstance(results, dict):
85
+ predictions = results.get('predictions', [])
86
+ else:
87
+ # 문자열인 경우에만 JSON 파싱
88
+ predictions = json.loads(results).get('predictions', [])
89
+
90
+ if predictions:
91
+ for prediction in predictions:
92
+ points = prediction['points']
93
+ pts = np.array([[p['x'], p['y']] for p in points], np.int32)
94
+ scale_x = image.shape[1] / results['image']['width']
95
+ scale_y = image.shape[0] / results['image']['height']
96
+ pts = pts * [scale_x, scale_y]
97
+ pts = pts.astype(np.int32)
98
+ pts = pts.reshape((-1, 1, 2))
99
+ cv2.fillPoly(mask, [pts], color=(255, 255, 255)) # White mask
100
+
101
+ segmented_image = cv2.bitwise_and(image, mask)
102
+ else:
103
+ st.warning("No predictions found in the image. Returning original image.")
104
+ segmented_image = image
105
+
106
+ return Image.fromarray(cv2.cvtColor(segmented_image, cv2.COLOR_BGR2RGB))
107
+ except Exception as e:
108
+ st.error(f"Error in segmentation: {str(e)}")
109
+ # 원본 이미지를 다시 읽어 반환
110
+ return Image.open(image_path)
111
+
112
+ @st.cache_data
113
+ def process_database_cached(data):
114
+ database_embeddings = []
115
+ database_info = []
116
+ for item in data:
117
+ image_url = item['이미지 링크'][0]
118
+ product_id = item.get('\ufeff상품 ID') or item.get('상품 ID')
119
+
120
+ image = download_and_process_image(image_url)
121
+ if image is None:
122
+ continue
123
+
124
+ # Save the image temporarily
125
+ temp_path = f"temp_{product_id}.jpg"
126
+ image.save(temp_path, 'JPEG')
127
+
128
+ database_info.append({
129
+ 'id': product_id,
130
+ 'category': item['카테고리'],
131
+ 'brand': item['브랜드명'],
132
+ 'name': item['제품명'],
133
+ 'price': item['정가'],
134
+ 'discount': item['할인율'],
135
+ 'image_url': image_url,
136
+ 'temp_path': temp_path
137
+ })
138
+
139
+ return database_info
140
+
141
+ def process_database(client, data):
142
+ database_info = process_database_cached(data)
143
+ database_embeddings = []
144
+
145
+ for item in database_info:
146
+ segmented_image = segment_image(item['temp_path'], client)
147
+ embedding = get_image_embedding(segmented_image)
148
+ database_embeddings.append(embedding)
149
+
150
+ return np.vstack(database_embeddings), database_info
151
+
152
+ # Streamlit app
153
+ st.title("Fashion Search App with Segmentation")
154
+
155
+ # API Key input
156
+ api_key = st.text_input("Enter your Roboflow API Key", type="password")
157
+
158
+ if api_key:
159
+ CLIENT = setup_roboflow_client(api_key)
160
+
161
+ # Initialize database_embeddings and database_info
162
+ database_embeddings, database_info = process_database(CLIENT, data)
163
+
164
+ uploaded_file = st.file_uploader("Choose an image...", type="jpg")
165
+ if uploaded_file is not None:
166
+ image = Image.open(uploaded_file)
167
+ st.image(image, caption='Uploaded Image', use_column_width=True)
168
+
169
+ if st.button('Find Similar Items'):
170
+ with st.spinner('Processing...'):
171
+ # Save uploaded image temporarily
172
+ temp_path = "temp_upload.jpg"
173
+ image.save(temp_path)
174
+
175
+ # Segment the uploaded image
176
+ segmented_image = segment_image(temp_path, CLIENT)
177
+ st.image(segmented_image, caption='Segmented Image', use_column_width=True)
178
+
179
+ # Get embedding for segmented image
180
+ query_embedding = get_image_embedding(segmented_image)
181
+ similar_images = find_similar_images(query_embedding)
182
+
183
+ st.subheader("Similar Items:")
184
+ for img in similar_images:
185
+ col1, col2 = st.columns(2)
186
+ with col1:
187
+ st.image(img['info']['image_url'], use_column_width=True)
188
+ with col2:
189
+ st.write(f"Name: {img['info']['name']}")
190
+ st.write(f"Brand: {img['info']['brand']}")
191
+ st.write(f"Category: {img['info']['category']}")
192
+ st.write(f"Price: {img['info']['price']}")
193
+ st.write(f"Discount: {img['info']['discount']}%")
194
+ st.write(f"Similarity: {img['similarity']:.2f}")
195
+ else:
196
+ st.warning("Please enter your Roboflow API Key to use the app.")