Spaces:
Sleeping
Sleeping
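# -*- coding: utf-8 -*-
"""FinalProject.ipynb

Fashion outfit recommender: CLIP image-similarity search over a subset of the
DeepFashion dataset plus a Stable Diffusion img2img outfit generator, served
through a Gradio UI.

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1_wYfP0IRdb9fpc2zvbg8IqdXGx1dTo7X
"""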
| from datasets import load_dataset | |
| from PIL import Image, ImageChops | |
| from transformers import CLIPProcessor, CLIPModel | |
| from sklearn.metrics.pairwise import cosine_similarity | |
| import torch | |
| import numpy as np | |
| import gradio as gr | |
| from diffusers import StableDiffusionImg2ImgPipeline | |
# Select the compute device: CUDA when torch can see a GPU, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# CLIP ViT-B/32 supplies the image embeddings used both for the dataset
# similarity search and for scoring generated outfits; the processor turns PIL
# images into model inputs.
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
model = model.to(device)
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
# Load the DeepFashion "train" split from the Hugging Face Hub.
dataset = load_dataset("lirus18/deepfashion", split="train")

# Embed the first N dataset images with CLIP so uploads can be matched against
# them. Row k of `image_vectors` is the embedding of dataset[image_indices[k]].
N = 500
EMBED_BATCH_SIZE = 32  # images per CLIP forward pass

# Clamp to the split size: indexing past the end of the split would fail.
_num_embedded = min(N, len(dataset))
image_indices = list(range(_num_embedded))
if not image_indices:
    raise RuntimeError("The dataset split is empty; nothing to embed for similarity search.")

_embedding_batches = []
with torch.no_grad():
    for _start in range(0, _num_embedded, EMBED_BATCH_SIZE):
        _batch_images = [
            dataset[i]["image"].convert("RGB")
            for i in image_indices[_start:_start + EMBED_BATCH_SIZE]
        ]
        _inputs = processor(images=_batch_images, return_tensors="pt").to(device)
        _embedding_batches.append(model.get_image_features(**_inputs).cpu().numpy())

# One row per embedded image, in the same order as image_indices.
image_vectors = np.concatenate(_embedding_batches, axis=0)
# Similarity search
def find_similar(user_image, top_k=3, exclude_index=None):
    """Return the ``top_k`` embedded dataset images most similar to ``user_image``.

    Similarity is the cosine similarity between the CLIP embedding of the query
    and each row of the module-level ``image_vectors``.

    Args:
        user_image: PIL image used as the query; converted to RGB before embedding.
        top_k: Number of matches to return. Values <= 0 give an empty list.
        exclude_index: Optional row position in ``image_vectors`` (a position in
            ``image_indices``, not a dataset row id) that must never be returned,
            e.g. an exact duplicate of the query.

    Returns:
        Tuple ``(images, query_vec)``: the matching dataset images ordered from
        most to least similar, and the query embedding as a one-row numpy array
        so callers can reuse it without re-running CLIP.
    """
    inputs = processor(images=user_image.convert("RGB"), return_tensors="pt").to(device)
    with torch.no_grad():
        query_vec = model.get_image_features(**inputs).cpu().numpy()
    sims = cosine_similarity(query_vec, image_vectors)[0]

    # Rank by descending similarity; a stable sort keeps dataset order on ties.
    ranked = np.argsort(-sims, kind="stable")
    if exclude_index is not None:
        # Remove the excluded row outright. A sentinel score of -1 would tie a real
        # minimum cosine score and could still be returned.
        ranked = ranked[ranked != exclude_index]
    # Slice from the front so top_k == 0 yields nothing; a ``[-top_k:]`` slice
    # would return the whole array in that case.
    top_idx = ranked[:max(top_k, 0)]
    return [dataset[image_indices[i]]["image"] for i in top_idx], query_vec
# Load the Stable Diffusion v1.5 image-to-image pipeline. Half precision is only
# requested when CUDA is available; the CPU path stays in float32.
# NOTE(review): confirm this Hub repo id still resolves before redeploying.
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=(torch.float16 if torch.cuda.is_available() else torch.float32),
    low_cpu_mem_usage=True,
)
pipe = pipe.to(device)
# Attention slicing trades some speed for lower peak memory (per the diffusers docs).
pipe.enable_attention_slicing()
# Generate AI outfit designs with Stable Diffusion img2img
def generate_outfits(
    input_image,
    n=1,
    *,
    prompt="fashion outfit design inspired by the clothing item",
    strength=0.7,
    guidance_scale=7.5,
):
    """Generate ``n`` outfit images conditioned on ``input_image``.

    Each image comes from a separate call to the module-level img2img ``pipe``.
    Every call uses the same prompt and a 512x512 copy of the input.

    Args:
        input_image: PIL image of the clothing item. It is converted to RGB (as
            ``find_similar`` does for its query) and resized to 512x512; the
            aspect ratio is not preserved.
        n: Number of images to generate. Values <= 0 return an empty list
            without touching the pipeline.
        prompt: Text prompt passed to the pipeline.
        strength: img2img strength passed to the pipeline.
        guidance_scale: Guidance scale passed to the pipeline.

    Returns:
        List of ``n`` generated images (``result.images[0]`` from each call).
    """
    if n <= 0:
        return []
    # Normalize the mode so non-RGB uploads (e.g. RGBA PNGs) reach the pipeline
    # as 3-channel images.
    init_image = input_image.convert("RGB").resize((512, 512))
    return [
        pipe(
            prompt=prompt,
            image=init_image,
            strength=strength,
            guidance_scale=guidance_scale,
        ).images[0]
        for _ in range(n)
    ]
def _find_exact_duplicate(query_image):
    """Return the position in ``image_indices`` of a dataset image pixel-identical
    to ``query_image`` (an RGB PIL image), or None if there is no such image."""
    for pos, row in enumerate(image_indices):
        candidate = dataset[row]["image"]
        # ImageChops.difference only compares the overlapping region of two
        # differently sized images, so rule out a size mismatch first. Checking
        # size before convert() also skips needless conversions.
        if candidate.size != query_image.size:
            continue
        if ImageChops.difference(candidate.convert("RGB"), query_image).getbbox() is None:
            return pos
    return None


def _pick_closest(candidates, query_vec):
    """Return the candidate image whose CLIP embedding is most cosine-similar to
    ``query_vec``, or None when ``candidates`` is empty."""
    best_img, best_score = None, -np.inf
    for img in candidates:
        inputs = processor(images=img, return_tensors="pt").to(device)
        with torch.no_grad():
            emb = model.get_image_features(**inputs).cpu().numpy()
        score = cosine_similarity(query_vec, emb)[0][0]
        # Start from -inf: a cosine score can equal -1, which a -1 start never beats.
        if score > best_score:
            best_img, best_score = img, score
    return best_img


# Main function
def recommend_from_upload(uploaded_image):
    """Build the five images shown in the Gradio UI for one uploaded item.

    Returns ``[uploaded_image, similar_1, similar_2, similar_3, generated]``:
    - the upload converted to RGB;
    - the three most similar embedded dataset images, skipping a pixel-identical
      copy of the upload if the dataset contains one;
    - the generated outfit closest to the upload in CLIP space.

    If fewer than three similar images come back (a very small embedded set),
    the gaps are filled with None so the list always has five entries.

    Raises:
        gr.Error: when nothing was uploaded (the input is None), so the user sees
            a message instead of an AttributeError traceback.
    """
    if uploaded_image is None:
        raise gr.Error("Please upload a clothing image first.")
    uploaded_image = uploaded_image.convert("RGB")

    # Avoid recommending the uploaded image back to the user.
    duplicate_pos = _find_exact_duplicate(uploaded_image)
    similar_imgs, query_vec = find_similar(uploaded_image, top_k=3, exclude_index=duplicate_pos)
    # The UI has exactly three "similar" slots; pad so the output count matches.
    similar_imgs = list(similar_imgs) + [None] * (3 - len(similar_imgs))

    # Generate one outfit and keep the candidate closest to the upload.
    generated_imgs = generate_outfits(uploaded_image, n=1)
    best_img = _pick_closest(generated_imgs, query_vec)

    return [uploaded_image] + similar_imgs + [best_img]
# Example inputs for gr.Examples: one row per example, each row holding the
# single value for the lone image input component.
example_paths = [
    [file_name]
    for file_name in (
        "example1.jpg",
        "example2.jpg",
        "example3.jpg",
        "example4.jpg",
        "example5.jpg",
    )
]
# Gradio UI: one upload + button row, one row of five result images, a set of
# clickable examples, and the button wired to recommend_from_upload.
with gr.Blocks() as demo:
    gr.Markdown("## 👗 Fashion Outfit Recommender")
    gr.Markdown("Upload a clothing image to get 3 similar items from the dataset and 1 AI-generated outfit design.")

    with gr.Row():
        image_input = gr.Image(type="pil", label="Upload a clothing item")
        generate_btn = gr.Button("Generate Recommendations")

    # Result slots, created left-to-right in the same order that
    # recommend_from_upload returns its images.
    with gr.Row():
        output1, output2, output3, output4, output5 = (
            gr.Image(label=caption, height=512, width=384)
            for caption in (
                "Your Input",
                "Similar Item 1",
                "Similar Item 2",
                "Similar Item 3",
                "AI-Generated Outfit",
            )
        )

    examples = gr.Examples(
        examples=example_paths,
        inputs=image_input,
        label="Try an Example",
    )

    generate_btn.click(
        fn=recommend_from_upload,
        inputs=image_input,
        outputs=[output1, output2, output3, output4, output5],
    )

if __name__ == "__main__":
    demo.launch()