modularization

- app.py +4 -985
- config.py +25 -0
- conversation.py +41 -0
- image_utils.py +82 -0
- model_manager.py +220 -0
- models.py +25 -0
- query_processing.py +140 -0
- routes.py +475 -0
- wardrobe.py +113 -0
app.py CHANGED

@@ -1,36 +1,10 @@
-from fastapi import FastAPI
+from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
-from
-from
-from typing import List, Optional, AsyncGenerator
-import re
-import os
-import base64
-from PIL import Image
-import requests
-from io import BytesIO
-from rag import initialize_rag, retrieve_relevant_context, format_rag_context
-import json
-import asyncio
-import threading
-from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
-from qwen_vl_utils import process_vision_info
-import torch
+from rag import initialize_rag
+from routes import setup_routes

 app = FastAPI(title="Style GPT API", version="1.0.0")

-style_model = None
-style_processor = None
-model_lock = threading.Lock()
-
-class WardrobeItem(BaseModel):
-    id: Optional[int] = None
-    category: str
-    style: str
-    color: Optional[str] = None
-    brand: Optional[str] = None
-    name: Optional[str] = None
-
 app.add_middleware(
     CORSMiddleware,
     allow_origins=["*"],
@@ -39,969 +13,14 @@ app.add_middleware(
     allow_headers=["*"],
 )

-conversation_memory = {}
-
-COLOR_HARMONY = {
-    "green": ["beige", "white", "grey", "gray", "navy", "black", "brown"],
-    "blue": ["white", "beige", "khaki", "grey", "gray", "navy", "red"],
-    "navy": ["white", "beige", "khaki", "grey", "gray", "red", "yellow"],
-    "black": ["white", "beige", "grey", "gray", "navy", "red", "blue"],
-    "white": ["black", "navy", "grey", "gray", "beige", "any color"],
-    "grey": ["black", "white", "navy", "beige", "blue"],
-    "gray": ["black", "white", "navy", "beige", "blue"],
-    "beige": ["green", "navy", "brown", "black", "white", "blue"],
-    "red": ["white", "black", "navy", "blue", "beige"],
-    "brown": ["beige", "white", "navy", "green", "blue"],
-    "yellow": ["navy", "black", "white", "grey", "gray"],
-    "pink": ["white", "navy", "grey", "gray", "black"],
-    "purple": ["white", "grey", "gray", "black", "navy"],
-    "orange": ["navy", "white", "black", "beige"],
-}
-
-CLOTHING_TYPES = [
-    "hoodie", "sweatshirt", "t-shirt", "shirt", "blazer", "jacket",
-    "jeans", "pants", "chinos", "trousers", "shorts",
-    "sneakers", "boots", "shoes", "sandal", "heels",
-    "dress", "skirt", "top", "sweater", "cardigan",
-    "vest", "coat", "parka", "blouse", "polo"
-]
-
-class ChatRequest(BaseModel):
-    message: str
-    session_id: Optional[str] = "default"
-    wardrobe: Optional[List[WardrobeItem]] = None
-    images: Optional[List[str]] = None
-
-class ChatResponse(BaseModel):
-    response: str
-    session_id: str
-
-def extract_clothing_info(text: str) -> dict:
-    text_lower = text.lower()
-
-    colors = list(COLOR_HARMONY.keys())
-    found_color = None
-    for color in colors:
-        if color in text_lower:
-            found_color = color
-            break
-
-    found_types = []
-    for clothing_type in CLOTHING_TYPES:
-        if clothing_type in text_lower:
-            found_types.append(clothing_type)
-
-    existing_item = None
-    requested_item = None
-
-    question_patterns = ["what kind of", "what", "which", "suggest", "recommend"]
-    is_question = any(pattern in text_lower for pattern in question_patterns)
-
-    if is_question and len(found_types) > 0:
-        if "my" in text_lower or "i have" in text_lower or "i'm wearing" in text_lower:
-            for i, word in enumerate(text_lower.split()):
-                if word == "my" and i + 1 < len(text_lower.split()):
-                    next_words = " ".join(text_lower.split()[i+1:i+4])
-                    for ct in found_types:
-                        if ct in next_words:
-                            existing_item = ct
-                            break
-                    if existing_item:
-                        break
-
-        for ct in found_types:
-            if "match" in text_lower or "go with" in text_lower or "pair" in text_lower:
-                ct_pos = text_lower.find(ct)
-                match_pos = text_lower.find("match")
-                if ct_pos > match_pos or "what kind of" in text_lower[:ct_pos]:
-                    requested_item = ct
-                    break
-
-    if not existing_item and not requested_item and found_types:
-        if is_question:
-            requested_item = found_types[0]
-        else:
-            existing_item = found_types[0]
-
-    return {
-        "color": found_color,
-        "type": existing_item or requested_item,
-        "existing_item": existing_item,
-        "requested_item": requested_item,
-        "is_question": is_question,
-        "raw_text": text
-    }
-
-def get_color_matches(color: str) -> List[str]:
-    color_lower = color.lower()
-    return COLOR_HARMONY.get(color_lower, ["white", "black", "grey", "beige"])
-
-def ensure_model_loaded():
-    """Load model if not already loaded"""
-    global style_model, style_processor
-
-    if style_model is not None:
-        return
-
-    print("Loading model (lazy load on first request)...")
-    model_id = "Qwen/Qwen2.5-VL-7B-Instruct"
-    hf_token = os.getenv("HF_TOKEN")
-
-    try:
-        if hf_token:
-            style_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-                model_id,
-                token=hf_token,
-                torch_dtype=torch.float32,
-                device_map="auto",
-                low_cpu_mem_usage=True
-            )
-            style_processor = AutoProcessor.from_pretrained(model_id, token=hf_token)
-        else:
-            style_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-                model_id,
-                torch_dtype=torch.float32,
-                device_map="auto",
-                low_cpu_mem_usage=True
-            )
-            style_processor = AutoProcessor.from_pretrained(model_id)
-
-        print(f"Loaded {model_id}")
-    except Exception as e:
-        print(f"Error loading model: {e}")
-        raise
-
-def generate_chat_response(prompt: str, max_length: int = 512, temperature: float = 0.7, rag_context: Optional[str] = None, system_override: Optional[str] = None, images: Optional[List[str]] = None) -> str:
-    """Generate response using the model"""
-    ensure_model_loaded()
-
-    system_message = system_override if system_override else "You are StyleGPT, a friendly and helpful fashion stylist assistant. You give natural, conversational advice about clothing, colors, and outfit combinations. Always be warm, friendly, and advisory in your responses. When asked your name, say you're StyleGPT. When greeted, respond warmly and offer to help with fashion advice."
-
-    if rag_context:
-        system_message += f"\n\n{rag_context}\n\nUse this fashion knowledge to provide accurate and helpful advice. Reference this knowledge naturally in your responses."
-
-    # Build user content list for Qwen2.5-VL format
-    user_content = []
-    if images and len(images) > 0:
-        for img in images:
-            user_content.append({"type": "image", "image": img})
-    user_content.append({"type": "text", "text": prompt})
-
-    messages = [
-        {"role": "system", "content": system_message},
-        {"role": "user", "content": user_content}
-    ]
-
-    try:
-        # Prepare text for processing
-        text = style_processor.apply_chat_template(
-            messages, tokenize=False, add_generation_prompt=True
-        )
-
-        # Process vision info (images/videos)
-        image_inputs, video_inputs = process_vision_info(messages)
-
-        # Process inputs
-        inputs = style_processor(
-            text=[text],
-            images=image_inputs,
-            videos=video_inputs,
-            padding=True,
-            return_tensors="pt",
-        )
-
-        inputs = {k: v.to(style_model.device) for k, v in inputs.items()}
-
-        temperature = max(0.1, min(1.5, temperature))
-
-        with model_lock:
-            with torch.no_grad():
-                try:
-                    outputs = style_model.generate(
-                        **inputs,
-                        max_new_tokens=max_length,
-                        temperature=temperature,
-                        top_p=0.95,
-                        top_k=50,
-                        do_sample=True,
-                        eos_token_id=style_processor.tokenizer.eos_token_id,
-                        pad_token_id=style_processor.tokenizer.pad_token_id,
-                        repetition_penalty=1.1,
-                    )
-                except RuntimeError as e:
-                    if "probability tensor" in str(e) or "inf" in str(e) or "nan" in str(e):
-                        print(f"[GENERATE] Probability error, retrying with greedy decoding")
-                        outputs = style_model.generate(
-                            **inputs,
-                            max_new_tokens=max_length,
-                            do_sample=False,
-                            eos_token_id=style_processor.tokenizer.eos_token_id,
-                            pad_token_id=style_processor.tokenizer.pad_token_id,
-                            repetition_penalty=1.1,
-                        )
-                    else:
-                        raise
-
-        # Decode only the generated part
-        generated_ids_trimmed = [
-            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, outputs)
-        ]
-        generated_text = style_processor.batch_decode(
-            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
-        )[0]
-
-        generated_text = generated_text.strip()
-        return generated_text
-
-    except Exception as e:
-        print(f"[GENERATE] Error: {e}")
-        import traceback
-        traceback.print_exc()
-        raise HTTPException(status_code=500, detail=f"Generation error: {str(e)}")
-
-async def generate_chat_response_streaming(prompt: str, max_length: int = 512, temperature: float = 0.7, rag_context: Optional[str] = None, system_override: Optional[str] = None, images: Optional[List[str]] = None) -> AsyncGenerator[str, None]:
-    """Generate streaming response using the model"""
-    ensure_model_loaded()
-
-    system_message = system_override if system_override else "You are StyleGPT, a friendly and helpful fashion stylist assistant. You give natural, conversational advice about clothing, colors, and outfit combinations. Always be warm, friendly, and advisory in your responses. When asked your name, say you're StyleGPT. When greeted, respond warmly and offer to help with fashion advice."
-
-    if rag_context:
-        system_message += f"\n\n{rag_context}\n\nUse this fashion knowledge to provide accurate and helpful advice. Reference this knowledge naturally in your responses."
-
-    # Build user content list for Qwen2.5-VL format
-    user_content = []
-    if images and len(images) > 0:
-        for img in images:
-            user_content.append({"type": "image", "image": img})
-    user_content.append({"type": "text", "text": prompt})
-
-    messages = [
-        {"role": "system", "content": system_message},
-        {"role": "user", "content": user_content}
-    ]
-
-    try:
-        # Prepare text for processing
-        text = style_processor.apply_chat_template(
-            messages, tokenize=False, add_generation_prompt=True
-        )
-
-        # Process vision info (images/videos)
-        image_inputs, video_inputs = process_vision_info(messages)
-
-        # Process inputs
-        inputs = style_processor(
-            text=[text],
-            images=image_inputs,
-            videos=video_inputs,
-            padding=True,
-            return_tensors="pt",
-        )
-
-        inputs = {k: v.to(style_model.device) for k, v in inputs.items()}
-
-        temperature = max(0.1, min(1.5, temperature))
-
-        with model_lock:
-            with torch.no_grad():
-                try:
-                    outputs = style_model.generate(
-                        **inputs,
-                        max_new_tokens=max_length,
-                        temperature=temperature,
-                        top_p=0.95,
-                        top_k=50,
-                        do_sample=True,
-                        eos_token_id=style_processor.tokenizer.eos_token_id,
-                        pad_token_id=style_processor.tokenizer.pad_token_id,
-                        repetition_penalty=1.1,
-                    )
-                except RuntimeError as e:
-                    if "probability tensor" in str(e) or "inf" in str(e) or "nan" in str(e):
-                        print(f"[GENERATE STREAM] Probability error, retrying with greedy decoding")
-                        outputs = style_model.generate(
-                            **inputs,
-                            max_new_tokens=max_length,
-                            do_sample=False,
-                            eos_token_id=style_processor.tokenizer.eos_token_id,
-                            pad_token_id=style_processor.tokenizer.pad_token_id,
-                            repetition_penalty=1.1,
-                        )
-                    else:
-                        raise
-
-        # Decode only the generated part
-        generated_ids_trimmed = [
-            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, outputs)
-        ]
-        generated_text = style_processor.batch_decode(
-            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
-        )[0]
-
-        generated_text = generated_text.strip()
-
-        for char in generated_text:
-            yield char
-            await asyncio.sleep(0.01)
-    except Exception as e:
-        print(f"[GENERATE STREAM] Error: {e}")
-        import traceback
-        traceback.print_exc()
-        error_msg = f"I apologize, but I encountered an error generating a response. Please try again."
-        for char in error_msg:
-            yield char
-            await asyncio.sleep(0.01)
-
-def generate_outfit_suggestion(description: str, color: Optional[str], clothing_type: Optional[str], existing_item: Optional[str] = None, requested_item: Optional[str] = None, is_question: bool = False) -> dict:
-    if is_question and requested_item:
-        item_focus = requested_item
-        base_item = existing_item or clothing_type or 'outfit'
-        prompt = f"I have a {base_item}"
-        if color:
-            prompt += f" in {color} color"
-        prompt += f". What {requested_item} would match well? Suggest 2-3 specific items with brief explanations."
-    elif existing_item:
-        item_focus = None
-        prompt = f"I have a {existing_item}"
-        if color:
-            prompt += f" in {color} color"
-        prompt += ". What other clothing items would match well? Suggest 2-3 specific items."
-    else:
-        item_focus = requested_item or clothing_type
-        prompt = f"{description}. Suggest 2-3 specific clothing items that would match well and give a brief explanation."
-
-    generated_text = generate_chat_response(prompt, max_length=200, temperature=0.8)
-
-    items = []
-    explanation = ""
-
-    template_markers = ["[specific item name]", "[brief, friendly explanation]", "specific item name", "[item"]
-
-    lines = generated_text.split("\n")
-    for line in lines:
-        line = line.strip()
-        if re.match(r"^Item\s+\d+:", line, re.IGNORECASE) or re.match(r"^-?\s*Item\s+\d+:", line, re.IGNORECASE):
-            line_clean = re.sub(r"^-?\s*Item\s+\d+:\s*", "", line, flags=re.IGNORECASE).strip()
-
-            parts = re.split(r"\s+Item\s+\d+:\s*", line_clean, flags=re.IGNORECASE)
-            for part in parts:
-                item = part.strip()
-                if item and len(item) > 2 and not item.lower().endswith(":") and not any(marker in item.lower() for marker in template_markers):
-                    if not re.match(r"^\[.*\]$", item) and not item.lower().startswith("item"):
-                        items.append(item)
-        elif line.startswith("Explanation:"):
-            explanation = re.sub(r"^Explanation:\s*", "", line, flags=re.IGNORECASE).strip()
-            if any(marker in explanation.lower() for marker in template_markers):
-                explanation = ""
-
-    if "Explanation:" in generated_text and not explanation:
-        parts = generated_text.split("Explanation:")
-        if len(parts) > 1:
-            explanation = parts[-1].strip().split("\n")[0].strip()
-            if any(marker in explanation.lower() for marker in template_markers):
-                explanation = ""
-
-    if not items:
-        item_lines = re.findall(r"Item\s+\d+:\s*([^\n]+?)(?:\s+Item\s+\d+:|Explanation:|\s*$)", generated_text, re.IGNORECASE | re.DOTALL)
-        for item_match in item_lines:
-            item = item_match.strip()
-            if "Item " in item:
-                item = re.split(r"\s+Item\s+\d+:", item, flags=re.IGNORECASE)[0].strip()
-            if item and len(item) > 2 and not any(marker in item.lower() for marker in template_markers):
-                if not re.match(r"^\[.*\]$", item):
-                    items.append(item)
-
-    items = [item for item in items if not any(item2 in item for item2 in items if item2 != item and len(item2) > 5)]
-
-    items = [item for item in items if not any(marker in item.lower() for marker in template_markers) and len(item) > 3]
-
-    items = [item for item in items if not item.lower().startswith("item")]
-
-    if not items:
-        if item_focus:
-            if color:
-                matching_colors = get_color_matches(color)[:3]
-                items = [f"{c.title()} {item_focus}" for c in matching_colors if c != color]
-            else:
-                neutrals = ["beige", "white", "navy", "black"]
-                items = [f"{n.title()} {item_focus}" for n in neutrals[:3]]
-        elif color:
-            matching_colors = get_color_matches(color)[:3]
-            item_type = requested_item or clothing_type or "trousers" if "trousers" in description.lower() else "item"
-            items = [f"{c.title()} {item_type}" for c in matching_colors if c != color]
-        else:
-            if "trousers" in description.lower() or "pants" in description.lower():
-                items = ["Beige chinos", "Navy trousers", "Black jeans"]
-            elif "hoodie" in description.lower():
-                items = ["Beige chinos", "White sneakers", "Navy jacket"]
-            else:
-                items = ["White sneakers", "Beige chinos", "Black jacket"]
-
-    items = [item for item in items if len(item) > 3 and not item.lower().endswith(":") and not any(marker in item.lower() for marker in template_markers)]
-
-    items = list(dict.fromkeys(items))
-
-    while len(items) < 2:
-        if item_focus:
-            neutrals = ["beige", "white", "navy", "black", "grey"]
-            for n in neutrals:
-                new_item = f"{n.title()} {item_focus}"
-                if new_item not in items and (not color or n != color):
-                    items.append(new_item)
-                    break
-        elif color:
-            matching_colors = get_color_matches(color)
-            for mc in matching_colors:
-                if mc != color:
-                    item_type = requested_item or clothing_type or "item"
-                    new_item = f"{mc.title()} {item_type}"
-                    if new_item not in items:
-                        items.append(new_item)
-                        break
-        else:
-            items.append("Classic white sneakers")
-
-    items = items[:3]
-
-    if not explanation or len(explanation) < 10:
-        if color and item_focus:
-            explanation = f"These {item_focus} options complement your {color} {existing_item or 'outfit'} beautifully, creating a balanced and stylish look."
-        elif color:
-            explanation = f"These items pair well with {color}, offering a versatile and fashionable combination."
-        elif item_focus:
-            explanation = f"These {item_focus} options will create a stylish and cohesive outfit."
-        else:
-            explanation = "These items pair well together for a stylish, balanced look."
-
-    return {
-        "items": items,
-        "explanation": explanation
-    }
-
-def load_image_from_url(url_or_base64: str) -> Image.Image:
-    try:
-        if url_or_base64.startswith("data:image"):
-            header, encoded = url_or_base64.split(",", 1)
-            image_data = base64.b64decode(encoded)
-            return Image.open(BytesIO(image_data)).convert("RGB")
-        else:
-            response = requests.get(url_or_base64, timeout=10)
-            response.raise_for_status()
-            return Image.open(BytesIO(response.content)).convert("RGB")
-    except Exception as e:
-        raise ValueError(f"Failed to load image: {str(e)}")
-
 @app.on_event("startup")
 async def startup_event():
     print("Initializing RAG system...")
     initialize_rag()
     print("RAG ready for requests! Model will load on first request.")

-
-async def root():
-    return {
-        "message": "Style GPT API - Milestone 1",
-        "version": "1.0.0",
-        "endpoint": "/chat - POST - Conversational fashion assistant"
-    }
-
-@app.get("/health")
-async def health_check():
-    return {
-        "status": "healthy" if style_model is not None else "loading",
-        "model_loaded": style_model is not None,
-        "model_name": "Qwen/Qwen2.5-VL-7B-Instruct"
-    }
-
-def extract_colors_from_query(query: str) -> tuple:
-    query_lower = query.lower()
-    colors = list(COLOR_HARMONY.keys())
-
-    extended_colors = ["navy blue", "wine", "burgundy", "maroon", "crimson", "scarlet", "mauve", "taupe", "olive", "teal", "turquoise", "indigo", "cobalt"]
-    all_colors = extended_colors + colors
-
-    color_mapping = {
-        "wine": "red", "burgundy": "red", "maroon": "red",
-        "mauve": "purple", "taupe": "beige", "olive": "green",
-        "teal": "blue", "turquoise": "blue", "indigo": "navy", "cobalt": "blue",
-        "navy blue": "navy"
-    }
-
-    found_colors = []
-    seen_mapped = set()
-    for color in all_colors:
-        if color in query_lower:
-            mapped = color_mapping.get(color, color)
-            if mapped not in seen_mapped:
-                found_colors.append((color, mapped))
-                seen_mapped.add(mapped)
-
-    return found_colors
-
-def detect_query_type(message: str) -> str:
-    message_lower = message.lower()
-
-    color_comparison_patterns = ["does", "do", "will", "can"]
-    comparison_keywords = ["go with", "match", "work with", "pair with", "combine with"]
-    color_suggestion_patterns = ["what color", "which color", "colors go with", "colors match", "better color", "color to match"]
-
-    has_two_colors = len(extract_colors_from_query(message)) >= 2
-
-    if any(pattern in message_lower for pattern in color_comparison_patterns) and any(kw in message_lower for kw in comparison_keywords) and has_two_colors:
-        return "color_compatibility"
-
-    outfit_request_patterns = ["suggest", "recommend", "outfit", "wear", "dress", "thinking of", "what should i wear", "what can i wear", "what to wear"]
-    what_matches_patterns = ["what will go", "what goes with", "what matches", "what will match", "what can go"]
-
-    if any(pattern in message_lower for pattern in what_matches_patterns) and any(item in message_lower for item in CLOTHING_TYPES):
-        return "outfit_suggestion"
-
-    if any(pattern in message_lower for pattern in outfit_request_patterns):
-        return "outfit_suggestion"
-
-    if any(pattern in message_lower for pattern in color_suggestion_patterns):
-        return "color_suggestion"
-
-    if any(ct in message_lower for ct in CLOTHING_TYPES) and any(c in message_lower for c in ["with", "match", "go", "pair", "style", "stylish", "look"]):
-        return "outfit_suggestion"
-
-    if len(extract_colors_from_query(message)) > 0 and any(word in message_lower for word in ["match", "with", "go", "pair"]):
-        return "color_suggestion"
-
-    if any(ct in message_lower for ct in CLOTHING_TYPES):
-        return "outfit_suggestion"
-
-    return "outfit_suggestion"
-
-def get_conversation_context(session_id: str) -> dict:
-    if session_id not in conversation_memory:
-        conversation_memory[session_id] = {
-            "messages": [],
-            "context": {}
-        }
-    return conversation_memory[session_id]
-
-def update_context(session_id: str, message: str, response_data: dict):
-    conv = conversation_memory[session_id]
-    conv["messages"].append({"user": message, "assistant": response_data.get("response", "")})
-
-    if len(conv["messages"]) > 10:
-        conv["messages"] = conv["messages"][-10:]
-
-    if "color" in response_data:
-        conv["context"]["last_color"] = response_data["color"]
-    if "item" in response_data:
-        conv["context"]["last_item"] = response_data["item"]
-    if "colors" in response_data:
-        conv["context"]["last_colors"] = response_data["colors"]
-    if "items" in response_data:
-        conv["context"]["last_items"] = response_data["items"]
-
-def enhance_message_with_context(message: str, context: dict) -> str:
-    message_lower = message.lower()
-
-    if "what about" in message_lower or "how about" in message_lower or "and" in message_lower:
-        if context.get("last_color"):
-            if not any(c in message_lower for c in ["color", "red", "blue", "green", "black", "white", "brown"]):
-                message = message + f" with {context['last_color']}"
-        if context.get("last_item"):
-            if not any(item in message_lower for item in CLOTHING_TYPES):
-                message = message + f" {context['last_item']}"
-
-    return message
-
-def is_greeting(message: str) -> bool:
-    message_lower = message.lower().strip()
-    greetings = [
-        "hello", "hi", "hey", "good morning", "good afternoon", "good evening",
-        "greetings", "howdy", "what's up", "whats up", "sup", "yo"
-    ]
-    return any(message_lower.startswith(g) or message_lower == g for g in greetings)
-
-def is_name_question(message: str) -> bool:
-    message_lower = message.lower().strip()
-    name_patterns = [
-        "what is your name", "what's your name", "whats your name",
-        "who are you", "what are you", "tell me your name", "your name"
-    ]
-    return any(pattern in message_lower for pattern in name_patterns)
-
-def clean_wardrobe_response(text: str) -> str:
-    import re
-
-    text = re.sub(r'\b(\w+)\s+\1\b', r'\1', text)
-
-    text = re.sub(r'\b(navy|grey|gray|black|white|brown|blue|red|green|beige|tan|charcoal|rolex|fossil|hermes|zara|nike)\s+\1\b', r'\1', text, flags=re.IGNORECASE)
-
-    if "protein" in text.lower():
-        text = re.sub(r'[Pp]rotein\s*', '', text)
-
-    lines = text.split('\n')
-    cleaned_lines = []
-
-    for line in lines:
-        line = line.strip()
-        if not line or len(line) < 3:
-            continue
-
-        if "protein" in line.lower():
-            continue
-
-        cleaned_lines.append(line)
-
-    result = '\n'.join(cleaned_lines).strip()
-
-    if len(result) > 1000:
-        sentences = result.split('. ')
-        result = '. '.join(sentences[:8]) + '.'
-
-    return result
-
-def format_wardrobe_for_prompt(wardrobe: List[WardrobeItem]) -> str:
-    wardrobe_by_category = {}
-    for item in wardrobe:
-        category = item.category.lower()
-        if category not in wardrobe_by_category:
-            wardrobe_by_category[category] = []
-        wardrobe_by_category[category].append(item)
-
-    wardrobe_details = []
-    for idx, item in enumerate(wardrobe, 1):
-        parts = []
-        if item.brand:
-            parts.append(item.brand)
-        if item.color:
-            parts.append(item.color)
-        if item.name:
-            parts.append(item.name)
-        elif item.category:
-            parts.append(item.category)
-
-        item_name = " ".join(parts) if parts else item.category
-        wardrobe_details.append(f'{idx}. {item_name} ({item.category}, {item.style})')
-
-    categories_list = ", ".join(wardrobe_by_category.keys())
-
-    return f"""Available items ({len(wardrobe)} total):
-{chr(10).join(wardrobe_details)}
-
-Categories: {categories_list}"""
-
-async def handle_wardrobe_chat(message: str, wardrobe: List[WardrobeItem], session_id: str, images: Optional[List[str]] = None) -> ChatResponse:
-    conv_context = get_conversation_context(session_id)
-    enhanced_message = enhance_message_with_context(message, conv_context["context"])
-
-    wardrobe_context = format_wardrobe_for_prompt(wardrobe)
-
-    wardrobe_by_category = {}
-    for item in wardrobe:
-        category = item.category.lower()
-        if category not in wardrobe_by_category:
-            wardrobe_by_category[category] = []
-        wardrobe_by_category[category].append(item)
-
-    rag_chunks = retrieve_relevant_context(enhanced_message, top_k=3)
-    rag_context = format_rag_context(rag_chunks)
-
-    occasion_keywords = ["defense", "project", "presentation", "meeting", "interview", "formal", "casual", "party", "wedding", "dinner", "date", "work", "office"]
-    occasion = next((word for word in occasion_keywords if word in enhanced_message.lower()), None)
-
-    context_info = f"Available wardrobe categories: {', '.join(wardrobe_by_category.keys())}. "
-    if occasion:
-        context_info += f"Occasion: {occasion}. "
-
-    system_override = "You are a friendly and helpful fashion stylist. Suggest complete outfits conversationally and warmly. Include accessories when available. Be natural and friendly in your responses."
-
-    prompt = f"""{wardrobe_context}
-
-User request: {enhanced_message}
-
-Suggest a complete outfit using ONLY the items listed above. Reference items by their exact names as shown (e.g., if item is "zara black pants", say "zara black pants", not "black zara pants"). Include accessories (watches, bags, jewelry, belts, glasses) if available. Be friendly and conversational. Suggest: one top/shirt, one bottom (pants/shorts), shoes, and accessories. Explain briefly why it works."""
-
-    if context_info.strip():
-        prompt += f"\n\nContext: {context_info.strip()}"
-
-    response_text = generate_chat_response(prompt, max_length=512, temperature=0.8, rag_context=rag_context, system_override=system_override, images=images)
-
-    response_text = clean_wardrobe_response(response_text)
-
-    update_context(session_id, message, {
-        "response": response_text,
-        "wardrobe_count": len(wardrobe),
-        "categories": list(wardrobe_by_category.keys())
-    })
-
-    return ChatResponse(
-        response=response_text,
-        session_id=session_id
-    )
-
-@app.post("/chat", response_model=ChatResponse)
-async def chat(request: ChatRequest):
-    try:
-        message = request.message.strip()
-        session_id = request.session_id
-
-        if not message:
-            raise HTTPException(status_code=400, detail="Message cannot be empty")
-
-        if request.wardrobe and len(request.wardrobe) > 0:
-            print(f"[WARDROBE CHAT] ===== WARDROBE REQUEST DETECTED =====")
-            return await handle_wardrobe_chat(message, request.wardrobe, session_id, images=request.images)
-
-        conv_context = get_conversation_context(session_id)
-
-        if is_name_question(message):
-            prompt = "What is your name? Respond naturally and friendly."
-            rag_chunks = retrieve_relevant_context(message, top_k=2)
-            rag_context = format_rag_context(rag_chunks)
-            response_text = generate_chat_response(prompt, max_length=100, temperature=0.8, rag_context=rag_context, images=request.images)
-            update_context(session_id, message, {"response": response_text})
-            return ChatResponse(response=response_text, session_id=session_id)
-
-        if is_greeting(message):
-            prompt = f"{message} Respond warmly and offer to help with fashion advice."
-            rag_chunks = retrieve_relevant_context(message, top_k=2)
-            rag_context = format_rag_context(rag_chunks)
-            response_text = generate_chat_response(prompt, max_length=150, temperature=0.8, rag_context=rag_context, images=request.images)
-            update_context(session_id, message, {"response": response_text})
-            return ChatResponse(response=response_text, session_id=session_id)
-
-        enhanced_message = enhance_message_with_context(message, conv_context["context"])
-        query_type = detect_query_type(enhanced_message)
-        rag_chunks = retrieve_relevant_context(enhanced_message, top_k=3)
-        rag_context = format_rag_context(rag_chunks)
-
-        if query_type == "color_compatibility":
-            found_colors = extract_colors_from_query(enhanced_message)
-
-            if len(found_colors) >= 2:
-                color1_mapped = found_colors[0][1]
-                color2_mapped = found_colors[1][1]
-                color1_original = found_colors[0][0]
-                color2_original = found_colors[1][0]
-
-                compatible = False
-                if color1_mapped in COLOR_HARMONY:
-                    compatible = color2_mapped in COLOR_HARMONY[color1_mapped]
-                elif color2_mapped in COLOR_HARMONY:
-                    compatible = color1_mapped in COLOR_HARMONY[color2_mapped]
-
-                neutrals = ["white", "black", "grey", "gray", "beige", "navy"]
-                if color1_mapped in neutrals or color2_mapped in neutrals:
-                    compatible = True
-
-                if compatible:
-                    response_text = f"Yes, {color1_original.title()} will go well with {color2_original.title()}. They create a balanced and stylish combination that works great together!"
-                else:
-                    response_text = f"{color1_original.title()} and {color2_original.title()} can work together, though you might want to add some neutral pieces to balance the look."
-
-                prompt = f"Does {color1_original} go well with {color2_original}? Answer naturally and conversationally."
-                ai_response = generate_chat_response(prompt, max_length=150, temperature=0.8, rag_context=rag_context, images=request.images)
-                if len(ai_response) > 15:
-                    response_text = ai_response
-
-                update_context(session_id, message, {
-                    "response": response_text,
-                    "color": color1_original,
-                    "colors": [color1_original, color2_original]
-                })
-
-                return ChatResponse(
-                    response=response_text,
-                    session_id=session_id
-                )
-
-        elif query_type == "color_suggestion":
-            clothing_info = extract_clothing_info(enhanced_message)
-            base_color = clothing_info.get("color")
-
-            if not base_color:
-                found_colors = extract_colors_from_query(enhanced_message)
-                if found_colors:
-                    base_color = found_colors[0][1]
-                elif conv_context["context"].get("last_color"):
-                    base_color = conv_context["context"]["last_color"]
-
-            if not base_color:
-                return ChatResponse(
-                    response="I'd love to help you with colors! Could you tell me which color you're working with? For example, 'what colors go with red?'",
-                    session_id=session_id
-                )
-
-            matching_colors = get_color_matches(base_color)
-            clothing_item = clothing_info.get("existing_item") or clothing_info.get("type") or conv_context["context"].get("last_item", "outfit")
-
-            suggested_colors = [c.title() for c in matching_colors[:4]]
-
-            message_lower_for_style = message.lower()
-            style_keywords = []
-            if "stylish" in message_lower_for_style or "standout" in message_lower_for_style or "stand out" in message_lower_for_style:
-                style_keywords.append("stylish and eye-catching")
-            if "professional" in message_lower_for_style or "formal" in message_lower_for_style:
-                style_keywords.append("professional")
-            if "casual" in message_lower_for_style:
-                style_keywords.append("casual")
-
-            style_note = ""
-            if style_keywords:
-                style_note = f" The user wants something {', '.join(style_keywords)}."
-
-            prompt = f"What colors go well with {base_color} {clothing_item}?{style_note} Give me a natural, conversational answer with specific color suggestions."
-            ai_response = generate_chat_response(prompt, max_length=300, temperature=0.8, rag_context=rag_context, images=request.images)
-            if len(ai_response) > 30:
-                response_text = ai_response
-            else:
-                response_text = f"For your {base_color} {clothing_item}, I'd suggest pairing it with {', '.join(suggested_colors[:3])}, or {suggested_colors[3] if len(suggested_colors) > 3 else 'other neutrals'}. These colors complement each other beautifully!"
-
-            update_context(session_id, message, {
-                "response": response_text,
-                "color": base_color,
-                "item": clothing_item,
-                "colors": suggested_colors
-            })
-
-            return ChatResponse(
-                response=response_text,
-                session_id=session_id
-            )
-
-        else:
-            clothing_info = extract_clothing_info(enhanced_message)
-
-            if not clothing_info.get("color") and conv_context["context"].get("last_color"):
-                enhanced_message = f"{enhanced_message} {conv_context['context']['last_color']}"
-                clothing_info = extract_clothing_info(enhanced_message)
-
-            context_info = ""
-            if clothing_info.get("color"):
-                context_info += f"Color preference: {clothing_info.get('color')}. "
-            if clothing_info.get("type"):
-                context_info += f"Item type: {clothing_info.get('type')}. "
-            if clothing_info.get("existing_item"):
-                context_info += f"User has: {clothing_info.get('existing_item')}. "
-
-            occasion_keywords = ["defense", "project", "presentation", "meeting", "interview", "formal", "casual", "party", "wedding"]
-            occasion = next((word for word in occasion_keywords if word in enhanced_message.lower()), None)
-            if occasion:
-                context_info += f"Occasion: {occasion}. "
-
-            prompt = f"{enhanced_message}"
-            if context_info:
-                prompt += f"\n\nContext: {context_info.strip()}"
-            prompt += "\n\nGive helpful, detailed outfit suggestions that are practical and stylish. Be specific about item combinations and explain why they work well."
-
-            response_text = generate_chat_response(prompt, max_length=1024, temperature=0.8, rag_context=rag_context, images=request.images)
-
-            update_context(session_id, message, {
-                "response": response_text,
-                "color": clothing_info.get("color"),
-                "item": clothing_info.get("type") or clothing_info.get("requested_item"),
-                "items": clothing_info.get("items", [])
-            })
-
-            return ChatResponse(
-                response=response_text,
-                session_id=session_id
-            )
-
-    except Exception as e:
-        raise HTTPException(status_code=500, detail=f"Error processing chat message: {str(e)}")
-
-@app.post("/chat/upload", response_model=ChatResponse)
-async def chat_with_upload(
-    message: str = Form(...),
-    session_id: str = Form(default="default"),
-    wardrobe: Optional[str] = Form(default=None),
-    images: List[UploadFile] = File(default=[])
-):
-    try:
-        wardrobe_items = []
-        if wardrobe and wardrobe.strip() and wardrobe.strip() not in ["[]", "", "string"]:
-            try:
-                wardrobe_data = json.loads(wardrobe)
-                if isinstance(wardrobe_data, list):
-                    wardrobe_items = [WardrobeItem(**item) for item in wardrobe_data]
-            except json.JSONDecodeError:
-                print(f"[UPLOAD] Ignoring invalid wardrobe value: {wardrobe[:50]}")
-
-        image_data_urls = []
-        for img_file in images:
-            if img_file.filename:
-                content = await img_file.read()
-                content_type = img_file.content_type or "image/jpeg"
-                base64_data = base64.b64encode(content).decode("utf-8")
-                data_url = f"data:{content_type};base64,{base64_data}"
-                image_data_urls.append(data_url)
-                print(f"[UPLOAD] Processed image: {img_file.filename} ({len(content)} bytes)")
-
-        request = ChatRequest(
-            message=message,
-            session_id=session_id,
-            wardrobe=wardrobe_items if wardrobe_items else None,
-            images=image_data_urls if image_data_urls else None
-        )
-
-        print(f"[UPLOAD] Processing chat request: message='{message[:50]}...', images={len(image_data_urls)}, wardrobe={len(wardrobe_items)}")
-        result = await chat(request)
-        print(f"[UPLOAD] Response generated: {len(result.response)} chars")
-        return result
-
-    except Exception as e:
-        print(f"[UPLOAD] Error: {e}")
-        raise HTTPException(status_code=500, detail=f"Error processing upload: {str(e)}")
-
-@app.post("/chat/upload/stream")
-async def chat_with_upload_stream(
-    message: str = Form(...),
-    session_id: str = Form(default="default"),
-    wardrobe: Optional[str] = Form(default=None),
-    images: List[UploadFile] = File(default=[])
-):
-    image_data_urls = []
-    for img_file in images:
-        if img_file.filename:
-            content = await img_file.read()
-            content_type = img_file.content_type or "image/jpeg"
-            base64_data = base64.b64encode(content).decode("utf-8")
-            data_url = f"data:{content_type};base64,{base64_data}"
-            image_data_urls.append(data_url)
-            print(f"[STREAM UPLOAD] Processed image: {img_file.filename} ({len(content)} bytes)")
-
-    rag_chunks = retrieve_relevant_context(message, top_k=3)
-    rag_context = format_rag_context(rag_chunks)
-
-    print(f"[STREAM UPLOAD] Starting streaming response for: {message[:50]}...")
-
-    async def generate():
-        yield f"data: {json.dumps({'type': 'start', 'session_id': session_id})}\n\n"
-
-        full_response = ""
-        async for chunk in generate_chat_response_streaming(
-            prompt=message,
-            max_length=512,
-            temperature=0.7,
-            rag_context=rag_context,
-            images=image_data_urls if image_data_urls else None
-        ):
-            full_response += chunk
-            yield f"data: {json.dumps({'type': 'chunk', 'content': chunk})}\n\n"
-
-        yield f"data: {json.dumps({'type': 'end', 'full_response': full_response, 'session_id': session_id})}\n\n"
-        print(f"[STREAM UPLOAD] Streaming complete: {len(full_response)} chars")
-        print(f"[STREAM RESPONSE] {full_response}")
-
-    return StreamingResponse(
-        generate(),
-        media_type="text/event-stream",
-        headers={
-            "Cache-Control": "no-cache",
-            "Connection": "keep-alive",
-            "X-Accel-Buffering": "no",
-        }
-    )
+setup_routes(app)

 if __name__ == "__main__":
     import uvicorn
     uvicorn.run(app, host="0.0.0.0", port=7860)
-
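The new app.py delegates all endpoint registration to routes.py, which is added as a new 475-line file in the same commit but is not shown in this section. As orientation only, here is a minimal sketch of the shape the setup_routes(app) call implies; the endpoint bodies below are illustrative placeholders, not the actual contents of routes.py.

# Hypothetical sketch of routes.py's setup_routes(); the real file is not
# shown in this diff, so both endpoints here are illustrative only.
from fastapi import FastAPI

def setup_routes(app: FastAPI) -> None:
    # Re-register the endpoints that the diff above removes from app.py.
    @app.get("/")
    async def root():
        return {"message": "Style GPT API - Milestone 1", "version": "1.0.0"}

    @app.get("/health")
    async def health_check():
        return {"status": "healthy"}

    # /chat, /chat/upload and /chat/upload/stream would be attached the same
    # way, delegating to the new wardrobe and query_processing modules.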
config.py ADDED

@@ -0,0 +1,25 @@
+COLOR_HARMONY = {
+    "green": ["beige", "white", "grey", "gray", "navy", "black", "brown"],
+    "blue": ["white", "beige", "khaki", "grey", "gray", "navy", "red"],
+    "navy": ["white", "beige", "khaki", "grey", "gray", "red", "yellow"],
+    "black": ["white", "beige", "grey", "gray", "navy", "red", "blue"],
+    "white": ["black", "navy", "grey", "gray", "beige", "any color"],
+    "grey": ["black", "white", "navy", "beige", "blue"],
+    "gray": ["black", "white", "navy", "beige", "blue"],
+    "beige": ["green", "navy", "brown", "black", "white", "blue"],
+    "red": ["white", "black", "navy", "blue", "beige"],
+    "brown": ["beige", "white", "navy", "green", "blue"],
+    "yellow": ["navy", "black", "white", "grey", "gray"],
+    "pink": ["white", "navy", "grey", "gray", "black"],
+    "purple": ["white", "grey", "gray", "black", "navy"],
+    "orange": ["navy", "white", "black", "beige"],
+}
+
+CLOTHING_TYPES = [
+    "hoodie", "sweatshirt", "t-shirt", "shirt", "blazer", "jacket",
+    "jeans", "pants", "chinos", "trousers", "shorts",
+    "sneakers", "boots", "shoes", "sandal", "heels",
+    "dress", "skirt", "top", "sweater", "cardigan",
+    "vest", "coat", "parka", "blouse", "polo"
+]
+
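config.py now owns the shared constants. A small usage sketch: get_color_matches() is the lookup helper from the old app.py that consumes COLOR_HARMONY; where it landed after the split (presumably query_processing.py) is not shown in this section.

from typing import List

from config import COLOR_HARMONY

def get_color_matches(color: str) -> List[str]:
    # Unknown colors fall back to safe neutrals (same behavior as before).
    return COLOR_HARMONY.get(color.lower(), ["white", "black", "grey", "beige"])

print(get_color_matches("Green"))
# ['beige', 'white', 'grey', 'gray', 'navy', 'black', 'brown']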
conversation.py ADDED
@@ -0,0 +1,41 @@
+from config import CLOTHING_TYPES
+
+conversation_memory = {}
+
+def get_conversation_context(session_id: str) -> dict:
+    if session_id not in conversation_memory:
+        conversation_memory[session_id] = {
+            "messages": [],
+            "context": {}
+        }
+    return conversation_memory[session_id]
+
+def update_context(session_id: str, message: str, response_data: dict):
+    conv = conversation_memory[session_id]
+    conv["messages"].append({"user": message, "assistant": response_data.get("response", "")})
+
+    if len(conv["messages"]) > 10:
+        conv["messages"] = conv["messages"][-10:]
+
+    if "color" in response_data:
+        conv["context"]["last_color"] = response_data["color"]
+    if "item" in response_data:
+        conv["context"]["last_item"] = response_data["item"]
+    if "colors" in response_data:
+        conv["context"]["last_colors"] = response_data["colors"]
+    if "items" in response_data:
+        conv["context"]["last_items"] = response_data["items"]
+
+def enhance_message_with_context(message: str, context: dict) -> str:
+    message_lower = message.lower()
+
+    # note: these are substring checks, so "and" also fires on words like
+    # "brand" or "stand"; the carry-over below is deliberately permissive
+    if "what about" in message_lower or "how about" in message_lower or "and" in message_lower:
+        if context.get("last_color"):
+            if not any(c in message_lower for c in ["color", "red", "blue", "green", "black", "white", "brown"]):
+                message = message + f" with {context['last_color']}"
+        if context.get("last_item"):
+            if not any(item in message_lower for item in CLOTHING_TYPES):
+                message = message + f" {context['last_item']}"
+
+    return message

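A short round trip through the session store, with outputs traced by hand from the code above (the session id "demo" is arbitrary):

    from conversation import (
        get_conversation_context, update_context, enhance_message_with_context
    )

    ctx = get_conversation_context("demo")   # first call creates the session
    update_context("demo", "what goes with a green hoodie?",
                   {"response": "Try beige or navy.", "color": "green", "item": "hoodie"})

    # "what about" triggers the carry-over; "shoes" is already a clothing
    # type, so only the remembered color is appended
    print(enhance_message_with_context("what about shoes?", ctx["context"]))
    # -> "what about shoes? with green"

The store is a module-level dict, so sessions are per-process and lost on restart.
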
image_utils.py ADDED
@@ -0,0 +1,82 @@
+import base64
+from PIL import Image
+import requests
+from io import BytesIO
+import numpy as np
+
+def is_placeholder_image(image: Image.Image) -> bool:
+    img_array = np.array(image)
+
+    if len(img_array.shape) != 3:
+        return True
+
+    height, width = img_array.shape[:2]
+
+    gray = np.mean(img_array, axis=2)
+
+    unique_colors = len(np.unique(gray))
+
+    if unique_colors < 10:
+        return True
+
+    black_white_ratio = np.sum((gray < 20) | (gray > 235)) / (height * width)
+
+    if black_white_ratio > 0.8:
+        return True
+
+    std_dev = np.std(gray)
+    if std_dev < 15:
+        return True
+
+    sample_size = min(100, height // 10, width // 10)
+    if sample_size < 2:
+        return False
+
+    step_h = height // sample_size
+    step_w = width // sample_size
+
+    grid_pattern = True
+    for i in range(0, height - step_h, step_h):
+        for j in range(0, width - step_w, step_w):
+            block = gray[i:i+step_h, j:j+step_w]
+            block_std = np.std(block)
+            if block_std > 30:
+                grid_pattern = False
+                break
+        if not grid_pattern:
+            break
+
+    if grid_pattern and black_white_ratio > 0.5:
+        return True
+
+    return False
+
+def load_image_from_url(url_or_base64: str) -> Image.Image:
+    try:
+        if url_or_base64.startswith("data:image"):
+            header, encoded = url_or_base64.split(",", 1)
+            image_data = base64.b64decode(encoded)
+            return Image.open(BytesIO(image_data)).convert("RGB")
+        else:
+            response = requests.get(url_or_base64, timeout=10)
+            response.raise_for_status()
+            return Image.open(BytesIO(response.content)).convert("RGB")
+    except Exception as e:
+        raise ValueError(f"Failed to load image: {str(e)}")
+
+def filter_valid_images(images: list) -> list:
+    valid_images = []
+    for img in images:
+        if not img or not isinstance(img, str) or img.strip() in ["", "string", "null", "undefined"]:
+            continue
+        try:
+            pil_image = load_image_from_url(img)
+            if not is_placeholder_image(pil_image):
+                valid_images.append(pil_image)
+            else:
+                print("[IMAGE FILTER] Ignoring placeholder/empty image")
+        except Exception as e:
+            print(f"[IMAGE FILTER] Warning: Failed to load image: {e}, skipping")
+            continue
+    return valid_images

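The placeholder heuristics are easiest to sanity-check with a synthetic image: a flat single-colour canvas trips both the unique-colour and standard-deviation tests. A small probe, using only Pillow and the standard library already imported above:

    import base64
    from io import BytesIO
    from PIL import Image

    from image_utils import filter_valid_images

    buf = BytesIO()
    Image.new("RGB", (64, 64), "white").save(buf, format="PNG")
    data_url = "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode()

    # the flat image is rejected as a placeholder; junk strings are skipped outright
    print(filter_valid_images([data_url, "", "string"]))  # -> []
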
model_manager.py ADDED
@@ -0,0 +1,220 @@
+import os
+import threading
+import asyncio
+import torch
+from typing import List, Optional, AsyncGenerator
+from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
+from qwen_vl_utils import process_vision_info
+from fastapi import HTTPException
+from image_utils import filter_valid_images
+
+style_model = None
+style_processor = None
+model_lock = threading.Lock()
+
+def ensure_model_loaded():
+    global style_model, style_processor
+
+    if style_model is not None:
+        return
+
+    print("Loading model (lazy load on first request)...")
+    model_id = "Qwen/Qwen2.5-VL-7B-Instruct"
+    hf_token = os.getenv("HF_TOKEN")
+
+    try:
+        if hf_token:
+            style_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+                model_id,
+                token=hf_token,
+                torch_dtype=torch.float32,
+                device_map="auto",
+                low_cpu_mem_usage=True
+            )
+            style_processor = AutoProcessor.from_pretrained(model_id, token=hf_token)
+        else:
+            style_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+                model_id,
+                torch_dtype=torch.float32,
+                device_map="auto",
+                low_cpu_mem_usage=True
+            )
+            style_processor = AutoProcessor.from_pretrained(model_id)
+
+        print(f"Loaded {model_id}")
+    except Exception as e:
+        print(f"Error loading model: {e}")
+        raise
+
+def generate_chat_response(prompt: str, max_length: int = 512, temperature: float = 0.7, rag_context: Optional[str] = None, system_override: Optional[str] = None, images: Optional[List[str]] = None) -> str:
+    ensure_model_loaded()
+
+    system_message = system_override if system_override else "You are StyleGPT, a friendly and helpful fashion stylist assistant. You give natural, conversational advice about clothing, colors, and outfit combinations. Always be warm, friendly, and advisory in your responses. When asked your name, say you're StyleGPT. When greeted, respond warmly and offer to help with fashion advice."
+
+    if rag_context:
+        system_message += f"\n\n{rag_context}\n\nUse this fashion knowledge to provide accurate and helpful advice. Reference this knowledge naturally in your responses."
+
+    user_content = []
+    if images:
+        valid_images = filter_valid_images(images)
+        for pil_image in valid_images:
+            user_content.append({"type": "image", "image": pil_image})
+    user_content.append({"type": "text", "text": prompt})
+
+    messages = [
+        {"role": "system", "content": system_message},
+        {"role": "user", "content": user_content}
+    ]
+
+    try:
+        text = style_processor.apply_chat_template(
+            messages, tokenize=False, add_generation_prompt=True
+        )
+
+        image_inputs, video_inputs = process_vision_info(messages)
+
+        inputs = style_processor(
+            text=[text],
+            images=image_inputs,
+            videos=video_inputs,
+            padding=True,
+            return_tensors="pt",
+        )
+
+        inputs = {k: v.to(style_model.device) for k, v in inputs.items()}
+
+        temperature = max(0.1, min(1.5, temperature))  # clamp to a sane sampling range
+
+        with model_lock:
+            with torch.no_grad():
+                try:
+                    outputs = style_model.generate(
+                        **inputs,
+                        max_new_tokens=max_length,
+                        temperature=temperature,
+                        top_p=0.95,
+                        top_k=50,
+                        do_sample=True,
+                        eos_token_id=style_processor.tokenizer.eos_token_id,
+                        pad_token_id=style_processor.tokenizer.pad_token_id,
+                        repetition_penalty=1.1,
+                    )
+                except RuntimeError as e:
+                    if "probability tensor" in str(e) or "inf" in str(e) or "nan" in str(e):
+                        print("[GENERATE] Probability error, retrying with greedy decoding")
+                        outputs = style_model.generate(
+                            **inputs,
+                            max_new_tokens=max_length,
+                            do_sample=False,
+                            eos_token_id=style_processor.tokenizer.eos_token_id,
+                            pad_token_id=style_processor.tokenizer.pad_token_id,
+                            repetition_penalty=1.1,
+                        )
+                    else:
+                        raise
+
+        # inputs is a plain dict after the .to() comprehension above, so index
+        # by key; attribute access (inputs.input_ids) would raise AttributeError
+        generated_ids_trimmed = [
+            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs["input_ids"], outputs)
+        ]
+        generated_text = style_processor.batch_decode(
+            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+        )[0]
+
+        generated_text = generated_text.strip()
+        return generated_text
+
+    except Exception as e:
+        print(f"[GENERATE] Error: {e}")
+        import traceback
+        traceback.print_exc()
+        raise HTTPException(status_code=500, detail=f"Generation error: {str(e)}")
+
+async def generate_chat_response_streaming(prompt: str, max_length: int = 512, temperature: float = 0.7, rag_context: Optional[str] = None, system_override: Optional[str] = None, images: Optional[List[str]] = None) -> AsyncGenerator[str, None]:
+    ensure_model_loaded()
+
+    system_message = system_override if system_override else "You are StyleGPT, a friendly and helpful fashion stylist assistant. You give natural, conversational advice about clothing, colors, and outfit combinations. Always be warm, friendly, and advisory in your responses. When asked your name, say you're StyleGPT. When greeted, respond warmly and offer to help with fashion advice."
+
+    if rag_context:
+        system_message += f"\n\n{rag_context}\n\nUse this fashion knowledge to provide accurate and helpful advice. Reference this knowledge naturally in your responses."
+
+    user_content = []
+    if images:
+        valid_images = filter_valid_images(images)
+        for pil_image in valid_images:
+            user_content.append({"type": "image", "image": pil_image})
+    user_content.append({"type": "text", "text": prompt})
+
+    messages = [
+        {"role": "system", "content": system_message},
+        {"role": "user", "content": user_content}
+    ]
+
+    try:
+        text = style_processor.apply_chat_template(
+            messages, tokenize=False, add_generation_prompt=True
+        )
+
+        image_inputs, video_inputs = process_vision_info(messages)
+
+        inputs = style_processor(
+            text=[text],
+            images=image_inputs,
+            videos=video_inputs,
+            padding=True,
+            return_tensors="pt",
+        )
+
+        inputs = {k: v.to(style_model.device) for k, v in inputs.items()}
+
+        temperature = max(0.1, min(1.5, temperature))
+
+        with model_lock:
+            with torch.no_grad():
+                try:
+                    outputs = style_model.generate(
+                        **inputs,
+                        max_new_tokens=max_length,
+                        temperature=temperature,
+                        top_p=0.95,
+                        top_k=50,
+                        do_sample=True,
+                        eos_token_id=style_processor.tokenizer.eos_token_id,
+                        pad_token_id=style_processor.tokenizer.pad_token_id,
+                        repetition_penalty=1.1,
+                    )
+                except RuntimeError as e:
+                    if "probability tensor" in str(e) or "inf" in str(e) or "nan" in str(e):
+                        print("[GENERATE STREAM] Probability error, retrying with greedy decoding")
+                        outputs = style_model.generate(
+                            **inputs,
+                            max_new_tokens=max_length,
+                            do_sample=False,
+                            eos_token_id=style_processor.tokenizer.eos_token_id,
+                            pad_token_id=style_processor.tokenizer.pad_token_id,
+                            repetition_penalty=1.1,
+                        )
+                    else:
+                        raise
+
+        # same dict-indexing fix as in generate_chat_response above
+        generated_ids_trimmed = [
+            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs["input_ids"], outputs)
+        ]
+        generated_text = style_processor.batch_decode(
+            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+        )[0]
+
+        generated_text = generated_text.strip()
+
+        # "streaming" here replays the completed text one character at a time
+        for char in generated_text:
+            yield char
+            await asyncio.sleep(0.01)
+    except Exception as e:
+        print(f"[GENERATE STREAM] Error: {e}")
+        import traceback
+        traceback.print_exc()
+        error_msg = "I apologize, but I encountered an error generating a response. Please try again."
+        for char in error_msg:
+            yield char
+            await asyncio.sleep(0.01)

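Note that generate_chat_response_streaming completes the whole generation under the lock and only then replays the text character by character, so the first byte arrives no sooner than in the non-streaming path. If truly incremental output is wanted, transformers ships TextIteratorStreamer, which yields decoded text as tokens are produced. A sketch under the same module globals (this is an alternative, not what the commit implements; the lock, retries, and error handling are omitted):

    import threading
    from transformers import TextIteratorStreamer

    def stream_generate(inputs: dict, max_new_tokens: int = 512):
        # yields decoded text chunks as the model produces them
        streamer = TextIteratorStreamer(
            style_processor.tokenizer, skip_prompt=True, skip_special_tokens=True
        )
        gen_kwargs = dict(**inputs, max_new_tokens=max_new_tokens, streamer=streamer)
        threading.Thread(target=style_model.generate, kwargs=gen_kwargs).start()
        for chunk in streamer:  # blocks until the next tokens are decoded
            yield chunk
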
models.py ADDED
@@ -0,0 +1,25 @@
+from pydantic import BaseModel
+from typing import List, Optional
+
+class WardrobeItem(BaseModel):
+    id: Optional[int] = None
+    category: str
+    style: str
+    color: Optional[str] = None
+    brand: Optional[str] = None
+    name: Optional[str] = None
+
+class ChatRequest(BaseModel):
+    message: str
+    session_id: Optional[str] = "default"
+    wardrobe: Optional[List[WardrobeItem]] = None
+    images: Optional[List[str]] = None
+
+class ChatResponse(BaseModel):
+    response: str
+    session_id: str
+
+class TextRequest(BaseModel):
+    message: str
+    session_id: Optional[str] = "default"

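For reference, a request that validates against these models; pydantic coerces the nested dicts into WardrobeItem instances:

    from models import ChatRequest

    req = ChatRequest(
        message="what goes with my green hoodie?",
        session_id="demo",
        wardrobe=[{"category": "top", "style": "casual", "color": "green", "name": "hoodie"}],
    )
    print(req.wardrobe[0].color)  # -> "green"
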
query_processing.py ADDED
@@ -0,0 +1,140 @@
+from typing import List
+from config import COLOR_HARMONY, CLOTHING_TYPES
+
+def extract_clothing_info(text: str) -> dict:
+    text_lower = text.lower()
+
+    colors = list(COLOR_HARMONY.keys())
+    found_color = None
+    for color in colors:
+        if color in text_lower:
+            found_color = color
+            break
+
+    found_types = []
+    for clothing_type in CLOTHING_TYPES:
+        if clothing_type in text_lower:
+            found_types.append(clothing_type)
+
+    existing_item = None
+    requested_item = None
+
+    question_patterns = ["what kind of", "what", "which", "suggest", "recommend"]
+    is_question = any(pattern in text_lower for pattern in question_patterns)
+
+    if is_question and len(found_types) > 0:
+        if "my" in text_lower or "i have" in text_lower or "i'm wearing" in text_lower:
+            for i, word in enumerate(text_lower.split()):
+                if word == "my" and i + 1 < len(text_lower.split()):
+                    next_words = " ".join(text_lower.split()[i+1:i+4])
+                    for ct in found_types:
+                        if ct in next_words:
+                            existing_item = ct
+                            break
+                    if existing_item:
+                        break
+
+        for ct in found_types:
+            if "match" in text_lower or "go with" in text_lower or "pair" in text_lower:
+                ct_pos = text_lower.find(ct)
+                match_pos = text_lower.find("match")
+                if ct_pos > match_pos or "what kind of" in text_lower[:ct_pos]:
+                    requested_item = ct
+                    break
+
+    if not existing_item and not requested_item and found_types:
+        if is_question:
+            requested_item = found_types[0]
+        else:
+            existing_item = found_types[0]
+
+    return {
+        "color": found_color,
+        "type": existing_item or requested_item,
+        "existing_item": existing_item,
+        "requested_item": requested_item,
+        "is_question": is_question,
+        "raw_text": text
+    }
+
+def get_color_matches(color: str) -> List[str]:
+    color_lower = color.lower()
+    return COLOR_HARMONY.get(color_lower, ["white", "black", "grey", "beige"])
+
+def extract_colors_from_query(query: str) -> List[tuple]:
+    query_lower = query.lower()
+    colors = list(COLOR_HARMONY.keys())
+
+    extended_colors = ["navy blue", "wine", "burgundy", "maroon", "crimson", "scarlet", "mauve", "taupe", "olive", "teal", "turquoise", "indigo", "cobalt"]
+    all_colors = extended_colors + colors
+
+    color_mapping = {
+        "wine": "red", "burgundy": "red", "maroon": "red",
+        "mauve": "purple", "taupe": "beige", "olive": "green",
+        "teal": "blue", "turquoise": "blue", "indigo": "navy", "cobalt": "blue",
+        "navy blue": "navy"
+    }
+
+    found_colors = []
+    seen_mapped = set()
+    for color in all_colors:
+        if color in query_lower:
+            mapped = color_mapping.get(color, color)
+            if mapped not in seen_mapped:
+                found_colors.append((color, mapped))
+                seen_mapped.add(mapped)
+
+    return found_colors
+
+def detect_query_type(message: str) -> str:
+    message_lower = message.lower()
+
+    color_comparison_patterns = ["does", "do", "will", "can"]
+    comparison_keywords = ["go with", "match", "work with", "pair with", "combine with"]
+    color_suggestion_patterns = ["what color", "which color", "colors go with", "colors match", "better color", "color to match"]
+
+    has_two_colors = len(extract_colors_from_query(message)) >= 2
+
+    if any(pattern in message_lower for pattern in color_comparison_patterns) and any(kw in message_lower for kw in comparison_keywords) and has_two_colors:
+        return "color_compatibility"
+
+    outfit_request_patterns = ["suggest", "recommend", "outfit", "wear", "dress", "thinking of", "what should i wear", "what can i wear", "what to wear"]
+    what_matches_patterns = ["what will go", "what goes with", "what matches", "what will match", "what can go"]
+
+    if any(pattern in message_lower for pattern in what_matches_patterns) and any(item in message_lower for item in CLOTHING_TYPES):
+        return "outfit_suggestion"
+
+    if any(pattern in message_lower for pattern in outfit_request_patterns):
+        return "outfit_suggestion"
+
+    if any(pattern in message_lower for pattern in color_suggestion_patterns):
+        return "color_suggestion"
+
+    if any(ct in message_lower for ct in CLOTHING_TYPES) and any(c in message_lower for c in ["with", "match", "go", "pair", "style", "stylish", "look"]):
+        return "outfit_suggestion"
+
+    if len(extract_colors_from_query(message)) > 0 and any(word in message_lower for word in ["match", "with", "go", "pair"]):
+        return "color_suggestion"
+
+    if any(ct in message_lower for ct in CLOTHING_TYPES):
+        return "outfit_suggestion"
+
+    return "outfit_suggestion"
+
+def is_greeting(message: str) -> bool:
+    message_lower = message.lower().strip()
+    greetings = [
+        "hello", "hi", "hey", "good morning", "good afternoon", "good evening",
+        "greetings", "howdy", "what's up", "whats up", "sup", "yo"
+    ]
+    return any(message_lower.startswith(g) or message_lower == g for g in greetings)
+
+def is_name_question(message: str) -> bool:
+    message_lower = message.lower().strip()
+    name_patterns = [
+        "what is your name", "what's your name", "whats your name",
+        "who are you", "what are you", "tell me your name", "your name"
+    ]
+    return any(pattern in message_lower for pattern in name_patterns)

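The routing is keyword-driven, so a few probes show where messages land (outputs traced by hand from the rules above):

    from query_processing import detect_query_type, extract_colors_from_query

    print(extract_colors_from_query("does wine go with beige?"))
    # -> [('wine', 'red'), ('beige', 'beige')]   extended names map to base colors

    print(detect_query_type("does wine go with beige?"))      # -> "color_compatibility"
    print(detect_query_type("what color matches my jeans?"))  # -> "color_suggestion"
    print(detect_query_type("tell me a joke"))                # -> "outfit_suggestion" (the fallback)
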
routes.py ADDED
@@ -0,0 +1,475 @@
+import json
+import base64
+from typing import List, Optional
+from fastapi import HTTPException, UploadFile, File, Form
+from fastapi.responses import StreamingResponse
+from models import ChatRequest, ChatResponse, WardrobeItem, TextRequest
+from query_processing import (
+    extract_clothing_info, extract_colors_from_query, detect_query_type,
+    get_color_matches, is_greeting, is_name_question
+)
+from conversation import (
+    get_conversation_context, enhance_message_with_context, update_context
+)
+from model_manager import generate_chat_response, generate_chat_response_streaming
+from wardrobe import handle_wardrobe_chat
+from rag import retrieve_relevant_context, format_rag_context
+from config import COLOR_HARMONY
+# import the module, not the name: style_model is rebound on lazy load, so
+# "from model_manager import style_model" would freeze the initial None here
+import model_manager
+
+def setup_routes(app):
+    @app.get("/")
+    async def root():
+        return {
+            "message": "Style GPT API - Milestone 1",
+            "version": "1.0.0",
+            "endpoints": {
+                "/text": "POST - Text-only chat",
+                "/chat": "POST - Chat with optional images",
+                "/chat/upload": "POST - Chat with file upload",
+                "/chat/upload/stream": "POST - Streaming chat with file upload",
+                "/health": "GET - Health check"
+            }
+        }
+
+    @app.get("/health")
+    async def health_check():
+        return {
+            "status": "healthy" if model_manager.style_model is not None else "loading",
+            "model_loaded": model_manager.style_model is not None,
+            "model_name": "Qwen/Qwen2.5-VL-7B-Instruct"
+        }
+
+    @app.post("/text", response_model=ChatResponse)
+    async def text_only(request: TextRequest):
+        try:
+            message = request.message.strip()
+            session_id = request.session_id
+
+            if not message:
+                raise HTTPException(status_code=400, detail="Message cannot be empty")
+
+            conv_context = get_conversation_context(session_id)
+
+            if is_name_question(message):
+                prompt = "What is your name? Respond naturally and friendly."
+                rag_chunks = retrieve_relevant_context(message, top_k=2)
+                rag_context = format_rag_context(rag_chunks)
+                response_text = generate_chat_response(prompt, max_length=100, temperature=0.8, rag_context=rag_context, images=None)
+                update_context(session_id, message, {"response": response_text})
+                return ChatResponse(response=response_text, session_id=session_id)
+
+            if is_greeting(message):
+                prompt = f"{message} Respond warmly and offer to help with fashion advice."
+                rag_chunks = retrieve_relevant_context(message, top_k=2)
+                rag_context = format_rag_context(rag_chunks)
+                response_text = generate_chat_response(prompt, max_length=150, temperature=0.8, rag_context=rag_context, images=None)
+                update_context(session_id, message, {"response": response_text})
+                return ChatResponse(response=response_text, session_id=session_id)
+
+            enhanced_message = enhance_message_with_context(message, conv_context["context"])
+            query_type = detect_query_type(enhanced_message)
+            rag_chunks = retrieve_relevant_context(enhanced_message, top_k=3)
+            rag_context = format_rag_context(rag_chunks)
+
+            if query_type == "color_compatibility":
+                found_colors = extract_colors_from_query(enhanced_message)
+
+                if len(found_colors) >= 2:
+                    color1_mapped = found_colors[0][1]
+                    color2_mapped = found_colors[1][1]
+                    color1_original = found_colors[0][0]
+                    color2_original = found_colors[1][0]
+
+                    compatible = False
+                    if color1_mapped in COLOR_HARMONY:
+                        compatible = color2_mapped in COLOR_HARMONY[color1_mapped]
+                    elif color2_mapped in COLOR_HARMONY:
+                        compatible = color1_mapped in COLOR_HARMONY[color2_mapped]
+
+                    neutrals = ["white", "black", "grey", "gray", "beige", "navy"]
+                    if color1_mapped in neutrals or color2_mapped in neutrals:
+                        compatible = True
+
+                    if compatible:
+                        response_text = f"Yes, {color1_original.title()} will go well with {color2_original.title()}. They create a balanced and stylish combination that works great together!"
+                    else:
+                        response_text = f"{color1_original.title()} and {color2_original.title()} can work together, though you might want to add some neutral pieces to balance the look."
+
+                    prompt = f"Does {color1_original} go well with {color2_original}? Answer naturally and conversationally."
+                    ai_response = generate_chat_response(prompt, max_length=150, temperature=0.8, rag_context=rag_context, images=None)
+                    if len(ai_response) > 15:
+                        response_text = ai_response
+
+                    update_context(session_id, message, {
+                        "response": response_text,
+                        "color": color1_original,
+                        "colors": [color1_original, color2_original]
+                    })
+
+                    return ChatResponse(
+                        response=response_text,
+                        session_id=session_id
+                    )
+
+            elif query_type == "color_suggestion":
+                clothing_info = extract_clothing_info(enhanced_message)
+                base_color = clothing_info.get("color")
+
+                if not base_color:
+                    found_colors = extract_colors_from_query(enhanced_message)
+                    if found_colors:
+                        base_color = found_colors[0][1]
+                    elif conv_context["context"].get("last_color"):
+                        base_color = conv_context["context"]["last_color"]
+
+                if not base_color:
+                    return ChatResponse(
+                        response="I'd love to help you with colors! Could you tell me which color you're working with? For example, 'what colors go with red?'",
+                        session_id=session_id
+                    )
+
+                matching_colors = get_color_matches(base_color)
+                clothing_item = clothing_info.get("existing_item") or clothing_info.get("type") or conv_context["context"].get("last_item", "outfit")
+
+                suggested_colors = [c.title() for c in matching_colors[:4]]
+
+                message_lower_for_style = message.lower()
+                style_keywords = []
+                if "stylish" in message_lower_for_style or "standout" in message_lower_for_style or "stand out" in message_lower_for_style:
+                    style_keywords.append("stylish and eye-catching")
+                if "professional" in message_lower_for_style or "formal" in message_lower_for_style:
+                    style_keywords.append("professional")
+                if "casual" in message_lower_for_style:
+                    style_keywords.append("casual")
+
+                style_note = ""
+                if style_keywords:
+                    style_note = f" The user wants something {', '.join(style_keywords)}."
+
+                prompt = f"What colors go well with {base_color} {clothing_item}?{style_note} Give me a natural, conversational answer with specific color suggestions."
+                ai_response = generate_chat_response(prompt, max_length=300, temperature=0.8, rag_context=rag_context, images=None)
+                if len(ai_response) > 30:
+                    response_text = ai_response
+                else:
+                    response_text = f"For your {base_color} {clothing_item}, I'd suggest pairing it with {', '.join(suggested_colors[:3])}, or {suggested_colors[3] if len(suggested_colors) > 3 else 'other neutrals'}. These colors complement each other beautifully!"
+
+                update_context(session_id, message, {
+                    "response": response_text,
+                    "color": base_color,
+                    "item": clothing_item,
+                    "colors": suggested_colors
+                })
+
+                return ChatResponse(
+                    response=response_text,
+                    session_id=session_id
+                )
+
+            else:
+                clothing_info = extract_clothing_info(enhanced_message)
+
+                if not clothing_info.get("color") and conv_context["context"].get("last_color"):
+                    enhanced_message = f"{enhanced_message} {conv_context['context']['last_color']}"
+                    clothing_info = extract_clothing_info(enhanced_message)
+
+                context_info = ""
+                if clothing_info.get("color"):
+                    context_info += f"Color preference: {clothing_info.get('color')}. "
+                if clothing_info.get("type"):
+                    context_info += f"Item type: {clothing_info.get('type')}. "
+                if clothing_info.get("existing_item"):
+                    context_info += f"User has: {clothing_info.get('existing_item')}. "
+
+                occasion_keywords = ["defense", "project", "presentation", "meeting", "interview", "formal", "casual", "party", "wedding"]
+                occasion = next((word for word in occasion_keywords if word in enhanced_message.lower()), None)
+                if occasion:
+                    context_info += f"Occasion: {occasion}. "
+
+                prompt = enhanced_message
+                if context_info:
+                    prompt += f"\n\nContext: {context_info.strip()}"
+                prompt += "\n\nGive helpful, detailed outfit suggestions that are practical and stylish. Be specific about item combinations and explain why they work well."
+
+                response_text = generate_chat_response(prompt, max_length=1024, temperature=0.8, rag_context=rag_context, images=None)
+
+                update_context(session_id, message, {
+                    "response": response_text,
+                    "color": clothing_info.get("color"),
+                    "item": clothing_info.get("type") or clothing_info.get("requested_item"),
+                    "items": clothing_info.get("items", [])
+                })
+
+                return ChatResponse(
+                    response=response_text,
+                    session_id=session_id
+                )
+
+        except HTTPException:
+            raise  # keep intended status codes (e.g. the 400 above) instead of wrapping as 500
+        except Exception as e:
+            raise HTTPException(status_code=500, detail=f"Error processing text message: {str(e)}")
+
+    @app.post("/chat", response_model=ChatResponse)
+    async def chat(request: ChatRequest):
+        try:
+            message = request.message.strip()
+            session_id = request.session_id
+
+            if not message:
+                raise HTTPException(status_code=400, detail="Message cannot be empty")
+
+            if request.wardrobe and len(request.wardrobe) > 0:
+                print("[WARDROBE CHAT] ===== WARDROBE REQUEST DETECTED =====")
+                return await handle_wardrobe_chat(message, request.wardrobe, session_id, images=request.images)
+
+            conv_context = get_conversation_context(session_id)
+
+            if is_name_question(message):
+                prompt = "What is your name? Respond naturally and friendly."
+                rag_chunks = retrieve_relevant_context(message, top_k=2)
+                rag_context = format_rag_context(rag_chunks)
+                response_text = generate_chat_response(prompt, max_length=100, temperature=0.8, rag_context=rag_context, images=request.images)
+                update_context(session_id, message, {"response": response_text})
+                return ChatResponse(response=response_text, session_id=session_id)
+
+            if is_greeting(message):
+                prompt = f"{message} Respond warmly and offer to help with fashion advice."
+                rag_chunks = retrieve_relevant_context(message, top_k=2)
+                rag_context = format_rag_context(rag_chunks)
+                response_text = generate_chat_response(prompt, max_length=150, temperature=0.8, rag_context=rag_context, images=request.images)
+                update_context(session_id, message, {"response": response_text})
+                return ChatResponse(response=response_text, session_id=session_id)
+
+            enhanced_message = enhance_message_with_context(message, conv_context["context"])
+            query_type = detect_query_type(enhanced_message)
+            rag_chunks = retrieve_relevant_context(enhanced_message, top_k=3)
+            rag_context = format_rag_context(rag_chunks)
+
+            if query_type == "color_compatibility":
+                found_colors = extract_colors_from_query(enhanced_message)
+
+                if len(found_colors) >= 2:
+                    color1_mapped = found_colors[0][1]
+                    color2_mapped = found_colors[1][1]
+                    color1_original = found_colors[0][0]
+                    color2_original = found_colors[1][0]
+
+                    compatible = False
+                    if color1_mapped in COLOR_HARMONY:
+                        compatible = color2_mapped in COLOR_HARMONY[color1_mapped]
+                    elif color2_mapped in COLOR_HARMONY:
+                        compatible = color1_mapped in COLOR_HARMONY[color2_mapped]
+
+                    neutrals = ["white", "black", "grey", "gray", "beige", "navy"]
+                    if color1_mapped in neutrals or color2_mapped in neutrals:
+                        compatible = True
+
+                    if compatible:
+                        response_text = f"Yes, {color1_original.title()} will go well with {color2_original.title()}. They create a balanced and stylish combination that works great together!"
+                    else:
+                        response_text = f"{color1_original.title()} and {color2_original.title()} can work together, though you might want to add some neutral pieces to balance the look."
+
+                    prompt = f"Does {color1_original} go well with {color2_original}? Answer naturally and conversationally."
+                    ai_response = generate_chat_response(prompt, max_length=150, temperature=0.8, rag_context=rag_context, images=request.images)
+                    if len(ai_response) > 15:
+                        response_text = ai_response
+
+                    update_context(session_id, message, {
+                        "response": response_text,
+                        "color": color1_original,
+                        "colors": [color1_original, color2_original]
+                    })
+
+                    return ChatResponse(
+                        response=response_text,
+                        session_id=session_id
+                    )
+
+            elif query_type == "color_suggestion":
+                clothing_info = extract_clothing_info(enhanced_message)
+                base_color = clothing_info.get("color")
+
+                if not base_color:
+                    found_colors = extract_colors_from_query(enhanced_message)
+                    if found_colors:
+                        base_color = found_colors[0][1]
+                    elif conv_context["context"].get("last_color"):
+                        base_color = conv_context["context"]["last_color"]
+
+                if not base_color:
+                    return ChatResponse(
+                        response="I'd love to help you with colors! Could you tell me which color you're working with? For example, 'what colors go with red?'",
+                        session_id=session_id
+                    )
+
+                matching_colors = get_color_matches(base_color)
+                clothing_item = clothing_info.get("existing_item") or clothing_info.get("type") or conv_context["context"].get("last_item", "outfit")
+
+                suggested_colors = [c.title() for c in matching_colors[:4]]
+
+                message_lower_for_style = message.lower()
+                style_keywords = []
+                if "stylish" in message_lower_for_style or "standout" in message_lower_for_style or "stand out" in message_lower_for_style:
+                    style_keywords.append("stylish and eye-catching")
+                if "professional" in message_lower_for_style or "formal" in message_lower_for_style:
+                    style_keywords.append("professional")
+                if "casual" in message_lower_for_style:
+                    style_keywords.append("casual")
+
+                style_note = ""
+                if style_keywords:
+                    style_note = f" The user wants something {', '.join(style_keywords)}."
+
+                prompt = f"What colors go well with {base_color} {clothing_item}?{style_note} Give me a natural, conversational answer with specific color suggestions."
+                ai_response = generate_chat_response(prompt, max_length=300, temperature=0.8, rag_context=rag_context, images=request.images)
+                if len(ai_response) > 30:
+                    response_text = ai_response
+                else:
+                    response_text = f"For your {base_color} {clothing_item}, I'd suggest pairing it with {', '.join(suggested_colors[:3])}, or {suggested_colors[3] if len(suggested_colors) > 3 else 'other neutrals'}. These colors complement each other beautifully!"
+
+                update_context(session_id, message, {
+                    "response": response_text,
+                    "color": base_color,
+                    "item": clothing_item,
+                    "colors": suggested_colors
+                })
+
+                return ChatResponse(
+                    response=response_text,
+                    session_id=session_id
+                )
+
+            else:
+                clothing_info = extract_clothing_info(enhanced_message)
+
+                if not clothing_info.get("color") and conv_context["context"].get("last_color"):
+                    enhanced_message = f"{enhanced_message} {conv_context['context']['last_color']}"
+                    clothing_info = extract_clothing_info(enhanced_message)
+
+                context_info = ""
+                if clothing_info.get("color"):
+                    context_info += f"Color preference: {clothing_info.get('color')}. "
+                if clothing_info.get("type"):
+                    context_info += f"Item type: {clothing_info.get('type')}. "
+                if clothing_info.get("existing_item"):
+                    context_info += f"User has: {clothing_info.get('existing_item')}. "
+
+                occasion_keywords = ["defense", "project", "presentation", "meeting", "interview", "formal", "casual", "party", "wedding"]
+                occasion = next((word for word in occasion_keywords if word in enhanced_message.lower()), None)
+                if occasion:
+                    context_info += f"Occasion: {occasion}. "
+
+                prompt = enhanced_message
+                if context_info:
+                    prompt += f"\n\nContext: {context_info.strip()}"
+                prompt += "\n\nGive helpful, detailed outfit suggestions that are practical and stylish. Be specific about item combinations and explain why they work well."
+
+                response_text = generate_chat_response(prompt, max_length=1024, temperature=0.8, rag_context=rag_context, images=request.images)
+
+                update_context(session_id, message, {
+                    "response": response_text,
+                    "color": clothing_info.get("color"),
+                    "item": clothing_info.get("type") or clothing_info.get("requested_item"),
+                    "items": clothing_info.get("items", [])
+                })
+
+                return ChatResponse(
+                    response=response_text,
+                    session_id=session_id
+                )
+
+        except HTTPException:
+            raise
+        except Exception as e:
+            raise HTTPException(status_code=500, detail=f"Error processing chat message: {str(e)}")
+
+    @app.post("/chat/upload", response_model=ChatResponse)
+    async def chat_with_upload(
+        message: str = Form(...),
+        session_id: str = Form(default="default"),
+        wardrobe: Optional[str] = Form(default=None),
+        images: List[UploadFile] = File(default=[])
+    ):
+        try:
+            wardrobe_items = []
+            if wardrobe and wardrobe.strip() and wardrobe.strip() not in ["[]", "", "string"]:
+                try:
+                    wardrobe_data = json.loads(wardrobe)
+                    if isinstance(wardrobe_data, list):
+                        wardrobe_items = [WardrobeItem(**item) for item in wardrobe_data]
+                except json.JSONDecodeError:
+                    print(f"[UPLOAD] Ignoring invalid wardrobe value: {wardrobe[:50]}")
+
+            image_data_urls = []
+            for img_file in images:
+                if img_file.filename:
+                    content = await img_file.read()
+                    content_type = img_file.content_type or "image/jpeg"
+                    base64_data = base64.b64encode(content).decode("utf-8")
+                    data_url = f"data:{content_type};base64,{base64_data}"
+                    image_data_urls.append(data_url)
+                    print(f"[UPLOAD] Processed image: {img_file.filename} ({len(content)} bytes)")
+
+            request = ChatRequest(
+                message=message,
+                session_id=session_id,
+                wardrobe=wardrobe_items if wardrobe_items else None,
+                images=image_data_urls if image_data_urls else None
+            )
+
+            print(f"[UPLOAD] Processing chat request: message='{message[:50]}...', images={len(image_data_urls)}, wardrobe={len(wardrobe_items)}")
+            result = await chat(request)
+            print(f"[UPLOAD] Response generated: {len(result.response)} chars")
+            return result
+
+        except HTTPException:
+            raise
+        except Exception as e:
+            print(f"[UPLOAD] Error: {e}")
+            raise HTTPException(status_code=500, detail=f"Error processing upload: {str(e)}")
+
+    @app.post("/chat/upload/stream")
+    async def chat_with_upload_stream(
+        message: str = Form(...),
+        session_id: str = Form(default="default"),
+        wardrobe: Optional[str] = Form(default=None),
+        images: List[UploadFile] = File(default=[])
+    ):
+        image_data_urls = []
+        for img_file in images:
+            if img_file.filename:
+                content = await img_file.read()
+                content_type = img_file.content_type or "image/jpeg"
+                base64_data = base64.b64encode(content).decode("utf-8")
+                data_url = f"data:{content_type};base64,{base64_data}"
+                image_data_urls.append(data_url)
+                print(f"[STREAM UPLOAD] Processed image: {img_file.filename} ({len(content)} bytes)")
+
+        rag_chunks = retrieve_relevant_context(message, top_k=3)
+        rag_context = format_rag_context(rag_chunks)
+
+        print(f"[STREAM UPLOAD] Starting streaming response for: {message[:50]}...")
+
+        async def generate():
+            yield f"data: {json.dumps({'type': 'start', 'session_id': session_id})}\n\n"
+
+            full_response = ""
+            async for chunk in generate_chat_response_streaming(
+                prompt=message,
+                max_length=512,
+                temperature=0.7,
+                rag_context=rag_context,
+                images=image_data_urls if image_data_urls else None
+            ):
+                full_response += chunk
+                yield f"data: {json.dumps({'type': 'chunk', 'content': chunk})}\n\n"
+
+            yield f"data: {json.dumps({'type': 'end', 'full_response': full_response, 'session_id': session_id})}\n\n"
+            print(f"[STREAM UPLOAD] Streaming complete: {len(full_response)} chars")
+            print(f"[STREAM RESPONSE] {full_response}")
+
+        return StreamingResponse(
+            generate(),
+            media_type="text/event-stream",
+            headers={
+                "Cache-Control": "no-cache",
+                "Connection": "keep-alive",
+                "X-Accel-Buffering": "no",
+            }
+        )

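The stream endpoint frames every event as "data: {json}\n\n" with a type of start, chunk, or end. A client sketch using requests (assumed to be installed; host and port match the uvicorn.run call in app.py):

    import json
    import requests

    with requests.post(
        "http://localhost:7860/chat/upload/stream",
        data={"message": "suggest an outfit for a dinner", "session_id": "demo"},
        stream=True,
    ) as resp:
        for line in resp.iter_lines(decode_unicode=True):
            if line and line.startswith("data: "):
                event = json.loads(line[len("data: "):])
                if event["type"] == "chunk":
                    print(event["content"], end="", flush=True)
                elif event["type"] == "end":
                    print()  # full_response and session_id also arrive here
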
wardrobe.py ADDED
@@ -0,0 +1,113 @@
+import re
+from typing import List, Optional
+from models import WardrobeItem, ChatResponse
+from query_processing import get_color_matches
+from model_manager import generate_chat_response
+from conversation import get_conversation_context, enhance_message_with_context, update_context
+from rag import retrieve_relevant_context, format_rag_context
+
+def clean_wardrobe_response(text: str) -> str:
+    text = re.sub(r'\b(\w+)\s+\1\b', r'\1', text)
+    text = re.sub(r'\b(navy|grey|gray|black|white|brown|blue|red|green|beige|tan|charcoal|rolex|fossil|hermes|zara|nike)\s+\1\b', r'\1', text, flags=re.IGNORECASE)
+
+    if "protein" in text.lower():
+        text = re.sub(r'[Pp]rotein\s*', '', text)
+
+    lines = text.split('\n')
+    cleaned_lines = []
+
+    for line in lines:
+        line = line.strip()
+        if not line or len(line) < 3:
+            continue
+        if "protein" in line.lower():
+            continue
+        cleaned_lines.append(line)
+
+    result = '\n'.join(cleaned_lines).strip()
+
+    if len(result) > 1000:
+        sentences = result.split('. ')
+        result = '. '.join(sentences[:8]) + '.'
+
+    return result
+
+def format_wardrobe_for_prompt(wardrobe: List[WardrobeItem]) -> str:
+    wardrobe_by_category = {}
+    for item in wardrobe:
+        category = item.category.lower()
+        if category not in wardrobe_by_category:
+            wardrobe_by_category[category] = []
+        wardrobe_by_category[category].append(item)
+
+    wardrobe_details = []
+    for idx, item in enumerate(wardrobe, 1):
+        parts = []
+        if item.brand:
+            parts.append(item.brand)
+        if item.color:
+            parts.append(item.color)
+        if item.name:
+            parts.append(item.name)
+        elif item.category:
+            parts.append(item.category)
+
+        item_name = " ".join(parts) if parts else item.category
+        wardrobe_details.append(f'{idx}. {item_name} ({item.category}, {item.style})')
+
+    categories_list = ", ".join(wardrobe_by_category.keys())
+
+    return f"""Available items ({len(wardrobe)} total):
+{chr(10).join(wardrobe_details)}
+
+Categories: {categories_list}"""
+
+async def handle_wardrobe_chat(message: str, wardrobe: List[WardrobeItem], session_id: str, images: Optional[List[str]] = None) -> ChatResponse:
+    conv_context = get_conversation_context(session_id)
+    enhanced_message = enhance_message_with_context(message, conv_context["context"])
+
+    wardrobe_context = format_wardrobe_for_prompt(wardrobe)
+
+    wardrobe_by_category = {}
+    for item in wardrobe:
+        category = item.category.lower()
+        if category not in wardrobe_by_category:
+            wardrobe_by_category[category] = []
+        wardrobe_by_category[category].append(item)
+
+    rag_chunks = retrieve_relevant_context(enhanced_message, top_k=3)
+    rag_context = format_rag_context(rag_chunks)
+
+    occasion_keywords = ["defense", "project", "presentation", "meeting", "interview", "formal", "casual", "party", "wedding", "dinner", "date", "work", "office"]
+    occasion = next((word for word in occasion_keywords if word in enhanced_message.lower()), None)
+
+    context_info = f"Available wardrobe categories: {', '.join(wardrobe_by_category.keys())}. "
+    if occasion:
+        context_info += f"Occasion: {occasion}. "
+
+    system_override = "You are a friendly and helpful fashion stylist. Suggest complete outfits conversationally and warmly. Include accessories when available. Be natural and friendly in your responses."
+
+    prompt = f"""{wardrobe_context}
+
+User request: {enhanced_message}
+
+Suggest a complete outfit using ONLY the items listed above. Reference items by their exact names as shown (e.g., if item is "zara black pants", say "zara black pants", not "black zara pants"). Include accessories (watches, bags, jewelry, belts, glasses) if available. Be friendly and conversational. Suggest: one top/shirt, one bottom (pants/shorts), shoes, and accessories. Explain briefly why it works."""
+
+    if context_info.strip():
+        prompt += f"\n\nContext: {context_info.strip()}"
+
+    response_text = generate_chat_response(prompt, max_length=512, temperature=0.8, rag_context=rag_context, system_override=system_override, images=images)
+
+    response_text = clean_wardrobe_response(response_text)
+
+    update_context(session_id, message, {
+        "response": response_text,
+        "wardrobe_count": len(wardrobe),
+        "categories": list(wardrobe_by_category.keys())
+    })
+
+    return ChatResponse(
+        response=response_text,
+        session_id=session_id
+    )

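To make the prompt shape concrete, here is what format_wardrobe_for_prompt produces for a two-item wardrobe (output traced by hand from the code above):

    from models import WardrobeItem
    from wardrobe import format_wardrobe_for_prompt

    items = [
        WardrobeItem(category="top", style="casual", color="green", name="hoodie", brand="nike"),
        WardrobeItem(category="shoes", style="casual", color="white", name="sneakers"),
    ]
    print(format_wardrobe_for_prompt(items))
    # Available items (2 total):
    # 1. nike green hoodie (top, casual)
    # 2. white sneakers (shoes, casual)
    #
    # Categories: top, shoes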