# FinalProject / app.py - Hugging Face Space by itayitay123
# Last change: commit a75b08d ("Update app.py")
# -*- coding: utf-8 -*-
"""FinalProject.ipynb
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/1_wYfP0IRdb9fpc2zvbg8IqdXGx1dTo7X
"""
from datasets import load_dataset
from PIL import Image, ImageChops
from transformers import CLIPProcessor, CLIPModel
from sklearn.metrics.pairwise import cosine_similarity
import torch
import numpy as np
import gradio as gr
from diffusers import StableDiffusionImg2ImgPipeline
# Run on CUDA when torch can see a GPU; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# CLIP ViT-B/32: the preprocessor turns PIL images into model inputs, and the
# model (placed on `device`) produces the image features used for retrieval.
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
# Load dataset
dataset = load_dataset("lirus18/deepfashion", split="train")

# Embed a subset of dataset images with CLIP so uploads can be ranked against them.
# N is clamped to the split length so a smaller dataset does not raise IndexError.
N = min(500, len(dataset))
# image_indices[k] is the dataset row that produced row k of image_vectors.
image_indices = list(range(N))


def _embed_dataset_rows(indices, batch_size=32):
    """Return CLIP image features for the given dataset rows as a 2-D NumPy array.

    Each image is converted to RGB, and ``batch_size`` images are encoded per
    forward pass instead of one at a time. Row ``k`` of the result corresponds
    to ``indices[k]``. An empty ``indices`` yields an array with zero rows.
    """
    chunks = []
    for start in range(0, len(indices), batch_size):
        batch = [dataset[i]["image"].convert("RGB") for i in indices[start:start + batch_size]]
        inputs = processor(images=batch, return_tensors="pt").to(device)
        with torch.no_grad():
            chunks.append(model.get_image_features(**inputs).cpu().numpy())
    if not chunks:
        # np.concatenate rejects an empty list; keep the 2-D (0, dim) shape instead.
        return np.empty((0, model.config.projection_dim), dtype=np.float32)
    return np.concatenate(chunks, axis=0)


image_vectors = _embed_dataset_rows(image_indices)
# Similarity search
def find_similar(user_image, top_k=3, exclude_index=None):
    """Return the indexed dataset images most similar to ``user_image`` under CLIP.

    Args:
        user_image: PIL image used as the query; it is converted to RGB first.
        top_k: Maximum number of neighbours to return; values <= 0 give [].
        exclude_index: Optional row of ``image_vectors`` to leave out of the
            ranking (e.g. the row of an exact duplicate of the query).

    Returns:
        ``(images, query_vec)``: ``images`` is a list of at most ``top_k``
        dataset PIL images, most similar first; ``query_vec`` is the query's
        CLIP image feature as a 2-D NumPy array with a single row.
    """
    inputs = processor(images=user_image.convert("RGB"), return_tensors="pt").to(device)
    with torch.no_grad():
        query_vec = model.get_image_features(**inputs).cpu().numpy()
    sims = cosine_similarity(query_vec, image_vectors)[0]

    # Rank every indexed row by descending similarity.
    ranked = np.argsort(-sims, kind="stable")
    if exclude_index is not None:
        # Drop the excluded row outright instead of overwriting its score with a
        # sentinel: -1 is a legitimate cosine value, so a sentinel can still rank.
        # range(...)[idx] normalises negative indices and raises IndexError when
        # out of range, just like indexing the similarity array would.
        excluded = range(len(sims))[exclude_index]
        ranked = ranked[ranked != excluded]
    # Slice from the front; a tail slice like [-top_k:] selects every row when top_k == 0.
    top_idx = ranked[:max(top_k, 0)]
    return [dataset[image_indices[i]]["image"] for i in top_idx], query_vec
# Stable Diffusion v1.5 image-to-image pipeline. Weights are loaded in half
# precision only when running on CUDA, full precision otherwise.
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
)
pipe = pipe.to(device)
# Compute attention in slices to reduce peak memory use during generation.
pipe.enable_attention_slicing()
# Generate outfit images from the uploaded garment with the img2img pipeline
def generate_outfits(input_image, n=1, *,
                     prompt="fashion outfit design inspired by the clothing item",
                     strength=0.7, guidance_scale=7.5):
    """Run the Stable Diffusion img2img pipeline ``n`` times on ``input_image``.

    Args:
        input_image: PIL image used as the init image. It is converted to RGB
            and resized to a fixed 512x512 (aspect ratio is not preserved).
        n: Number of images to generate; ``n <= 0`` returns an empty list.
        prompt: Text prompt passed through to ``pipe``.
        strength: img2img ``strength`` passed through to ``pipe``.
        guidance_scale: ``guidance_scale`` passed through to ``pipe``.

    Returns:
        A list with one image per pipeline call (``images[0]`` of each result).
    """
    # Convert before resizing so non-RGB uploads (e.g. RGBA, palette) work too.
    init_image = input_image.convert("RGB").resize((512, 512))
    generated_images = []
    for _ in range(n):
        result = pipe(
            prompt=prompt,
            image=init_image,
            strength=strength,
            guidance_scale=guidance_scale,
        )
        generated_images.append(result.images[0])
    return generated_images
# Main function
def _find_exact_duplicate(image):
    """Return the row of ``image_vectors`` whose dataset image equals ``image`` pixel-for-pixel, else None.

    ``image`` must already be RGB. Sizes are compared first: ImageChops.difference
    only compares the overlapping region of differently sized images, so without
    this check a crop of a dataset image could be reported as a duplicate.
    """
    for row, dataset_idx in enumerate(image_indices):
        candidate = dataset[dataset_idx]["image"]
        if candidate.size != image.size:
            continue
        if ImageChops.difference(candidate.convert("RGB"), image).getbbox() is None:
            return row
    return None


def _pick_best_generated(generated_imgs, query_vec):
    """Return the generated image whose CLIP feature is most cosine-similar to ``query_vec``.

    Returns None when ``generated_imgs`` is empty.
    """
    best_score, best_img = -np.inf, None
    for img in generated_imgs:
        inputs = processor(images=img, return_tensors="pt").to(device)
        with torch.no_grad():
            emb = model.get_image_features(**inputs).cpu().numpy()
        score = cosine_similarity(query_vec, emb)[0][0]
        if score > best_score:
            best_score, best_img = score, img
    return best_img


def recommend_from_upload(uploaded_image):
    """Build the five images shown in the UI for one uploaded clothing photo.

    Returns:
        ``[uploaded image (RGB), 3 most similar dataset images, best generated outfit]``.

    Raises:
        gr.Error: if no image was provided (the image input is None).
    """
    if uploaded_image is None:
        raise gr.Error("Please upload a clothing image first.")
    uploaded_image = uploaded_image.convert("RGB")

    # If the upload is itself one of the indexed dataset images, exclude it so the
    # "similar items" are not just the input echoed back.
    duplicate_row = _find_exact_duplicate(uploaded_image)
    similar_imgs, query_vec = find_similar(uploaded_image, top_k=3, exclude_index=duplicate_row)

    # Generate one outfit (n=1) and keep the candidate closest to the upload in CLIP space.
    generated_imgs = generate_outfits(uploaded_image, n=1)
    best_img = _pick_best_generated(generated_imgs, query_vec)
    return [uploaded_image] + similar_imgs + [best_img]
# Example inputs for gr.Examples: one single-element row (the image path) per example.
example_paths = [
    [filename]
    for filename in (
        "example1.jpg",
        "example2.jpg",
        "example3.jpg",
        "example4.jpg",
        "example5.jpg",
    )
]
# Gradio UI: an upload slot and a button on top, five result panels below,
# and clickable example inputs.
with gr.Blocks() as demo:
    gr.Markdown("## 👗 Fashion Outfit Recommender")
    gr.Markdown("Upload a clothing image to get 3 similar items from the dataset and 1 AI-generated outfit design.")

    # Input row: the image to match and the button that starts the pipeline.
    with gr.Row():
        image_input = gr.Image(label="Upload a clothing item", type="pil")
        generate_btn = gr.Button("Generate Recommendations")

    # Output row, in the same order recommend_from_upload returns its images.
    with gr.Row():
        output1, output2, output3, output4, output5 = (
            gr.Image(label=caption, height=512, width=384)
            for caption in (
                "Your Input",
                "Similar Item 1",
                "Similar Item 2",
                "Similar Item 3",
                "AI-Generated Outfit",
            )
        )

    # No fn/outputs are given, so picking an example only fills image_input.
    examples = gr.Examples(
        label="Try an Example",
        inputs=image_input,
        examples=example_paths,
    )

    generate_btn.click(
        fn=recommend_from_upload,
        inputs=image_input,
        outputs=[output1, output2, output3, output4, output5],
    )

if __name__ == "__main__":
    demo.launch()