Spaces:
				
			
			
	
			
			
		Sleeping
		
	
	
	
			
			
	
	
	
	
		
		
		Sleeping
		
	Commit 
							
							·
						
						8ed71e7
	
1
								Parent(s):
							
							ac6bc7f
								
fix
Browse files- PIL_images.pkl +3 -0
 - app.py +117 -0
 - image_encodings_sample.pkl +3 -0
 - requirements.txt +2 -0
 - valid_images_sample.pkl +3 -0
 
    	
        PIL_images.pkl
    ADDED
    
    | 
         @@ -0,0 +1,3 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            version https://git-lfs.github.com/spec/v1
         
     | 
| 2 | 
         
            +
            oid sha256:c2a1399bf61a38aa1921f7d49cf94a6668a7df4c42f85d5d120e437406455540
         
     | 
| 3 | 
         
            +
            size 62989209
         
     | 
    	
        app.py
    ADDED
    
    | 
         @@ -0,0 +1,117 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
import os
import pickle
from io import BytesIO

import gradio as gr
import numpy as np
import pandas as pd
import requests
import torch
from datasets import load_dataset
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
         
     | 
| 12 | 
         
            +
             
     | 
| 13 | 
         
            +
# Load the FashionCLIP model and its paired processor from the Hugging Face Hub.
model = CLIPModel.from_pretrained("patrickjohncyh/fashion-clip")
processor = CLIPProcessor.from_pretrained("patrickjohncyh/fashion-clip")

# hf_token = os.environ.get("HF_API_TOKEN")
# dataset = load_dataset('pratyush19/cyborg', use_auth_token=hf_token, split='train')

# dir_path = "train/"
# print (dataset)
# print (dataset[0].keys())

# NOTE(review): pickle.load is only safe because these files ship with the app;
# never load pickles from untrusted sources.
with open('valid_images_sample.pkl','rb') as f:
    valid_images = pickle.load(f)   # sample image identifiers (exact contents not visible here)

with open('image_encodings_sample.pkl','rb') as f:
    image_encodings = pickle.load(f)   # precomputed image embeddings, one row per image
valid_images = np.array(valid_images)  # ndarray form allows fancy indexing by rank

with open('PIL_images.pkl','rb') as f:
    PIL_images = pickle.load(f)   # pre-decoded images, presumably index-aligned with image_encodings — TODO confirm
     | 
| 32 | 
         
            +
             
     | 
| 33 | 
         
            +
             
     | 
| 34 | 
         
            +
            def softmax(x):
         
     | 
| 35 | 
         
            +
                e_x = np.exp(x - np.max(x, axis=1, keepdims=True))
         
     | 
| 36 | 
         
            +
                return e_x / e_x.sum(axis=1, keepdims=True)
         
     | 
| 37 | 
         
            +
             
     | 
| 38 | 
         
            +
            def find_similar_images(caption, image_encodings):
         
     | 
| 39 | 
         
            +
                inputs = processor(text=[caption], return_tensors="pt")
         
     | 
| 40 | 
         
            +
                text_features = model.get_text_features(**inputs)
         
     | 
| 41 | 
         
            +
                text_features = text_features.detach().numpy()
         
     | 
| 42 | 
         
            +
                logits_per_image = softmax(np.dot(text_features, image_encodings.T))
         
     | 
| 43 | 
         
            +
                return logits_per_image
         
     | 
| 44 | 
         
            +
             
     | 
| 45 | 
         
            +
            def find_relevant_images(caption):
         
     | 
| 46 | 
         
            +
                similarity_scores = find_similar_images(caption, image_encodings)[0]
         
     | 
| 47 | 
         
            +
                top_indices = np.argsort(similarity_scores)[::-1][:16]
         
     | 
| 48 | 
         
            +
            #     top_path = valid_images[top_indices]
         
     | 
| 49 | 
         
            +
                images = []
         
     | 
| 50 | 
         
            +
                for idx in top_indices:
         
     | 
| 51 | 
         
            +
                    images.append(PIL_images[idx])
         
     | 
| 52 | 
         
            +
                return images
         
     | 
| 53 | 
         
            +
             
     | 
| 54 | 
         
            +
            def gradio_interface(input_text):
         
     | 
| 55 | 
         
            +
                # with open("user_inputs.txt", "a") as file:
         
     | 
| 56 | 
         
            +
                #     file.write(input_text + "\n")
         
     | 
| 57 | 
         
            +
                images = find_relevant_images(input_text)
         
     | 
| 58 | 
         
            +
                return images
         
     | 
| 59 | 
         
            +
             
     | 
| 60 | 
         
            +
             
     | 
| 61 | 
         
            +
            def clear_inputs():
         
     | 
| 62 | 
         
            +
                return [None, None, None, None, None, None, None, None, None,
         
     | 
| 63 | 
         
            +
                        None, None, None, None, None, None, None, None, None]
         
     | 
| 64 | 
         
            +
             
     | 
| 65 | 
         
            +
            outputs = [None]*16
         
     | 
| 66 | 
         
            +
             
     | 
| 67 | 
         
            +
            with gr.Blocks(title="MirrAI") as demo:
         
     | 
| 68 | 
         
            +
                gr.Markdown("<h1 style='text-align: center;'>MirrAI: GenAI-based Fashion Search</h1>")
         
     | 
| 69 | 
         
            +
                gr.Markdown("Enter a text to find the most relevant images from our dataset.")
         
     | 
| 70 | 
         
            +
                    
         
     | 
| 71 | 
         
            +
                text_input = gr.Textbox(lines=1, label="Input Text", placeholder="Enter your text here...")
         
     | 
| 72 | 
         
            +
                with gr.Row():
         
     | 
| 73 | 
         
            +
                    cancel_button = gr.Button("Cancel")
         
     | 
| 74 | 
         
            +
                    submit_button = gr.Button("Submit")
         
     | 
| 75 | 
         
            +
                examples = gr.Examples(["high-rise flare jean",
         
     | 
| 76 | 
         
            +
                                        "a-line dress with floral",
         
     | 
| 77 | 
         
            +
                                        "men colorful blazers",
         
     | 
| 78 | 
         
            +
                                        "jumpsuit with puffed sleeve",
         
     | 
| 79 | 
         
            +
                                        "sleeveless sweater",
         
     | 
| 80 | 
         
            +
                                        "floral shirt",
         
     | 
| 81 | 
         
            +
                                        "blue asymmetrical wedding dress with one sleeve",
         
     | 
| 82 | 
         
            +
                                        "women long coat",
         
     | 
| 83 | 
         
            +
                                        "cardigan sweater"], inputs=[text_input])
         
     | 
| 84 | 
         
            +
             
     | 
| 85 | 
         
            +
                with gr.Row():
         
     | 
| 86 | 
         
            +
                    outputs[0] = gr.Image()
         
     | 
| 87 | 
         
            +
                    outputs[1] = gr.Image()
         
     | 
| 88 | 
         
            +
                    outputs[2] = gr.Image()
         
     | 
| 89 | 
         
            +
                    outputs[3] = gr.Image()
         
     | 
| 90 | 
         
            +
                with gr.Row():
         
     | 
| 91 | 
         
            +
                    outputs[4] = gr.Image()
         
     | 
| 92 | 
         
            +
                    outputs[5] = gr.Image()
         
     | 
| 93 | 
         
            +
                    outputs[6] = gr.Image()
         
     | 
| 94 | 
         
            +
                    outputs[7] = gr.Image()
         
     | 
| 95 | 
         
            +
                with gr.Row():
         
     | 
| 96 | 
         
            +
                    outputs[8] = gr.Image()
         
     | 
| 97 | 
         
            +
                    outputs[9] = gr.Image()
         
     | 
| 98 | 
         
            +
                    outputs[10] = gr.Image()
         
     | 
| 99 | 
         
            +
                    outputs[11] = gr.Image()
         
     | 
| 100 | 
         
            +
                with gr.Row():
         
     | 
| 101 | 
         
            +
                    outputs[12] = gr.Image()
         
     | 
| 102 | 
         
            +
                    outputs[13] = gr.Image()
         
     | 
| 103 | 
         
            +
                    outputs[14] = gr.Image()
         
     | 
| 104 | 
         
            +
                    outputs[15] = gr.Image()
         
     | 
| 105 | 
         
            +
                
         
     | 
| 106 | 
         
            +
                submit_button.click(
         
     | 
| 107 | 
         
            +
                    fn=gradio_interface,
         
     | 
| 108 | 
         
            +
                    inputs=text_input,
         
     | 
| 109 | 
         
            +
                    outputs=outputs
         
     | 
| 110 | 
         
            +
                )
         
     | 
| 111 | 
         
            +
                cancel_button.click(
         
     | 
| 112 | 
         
            +
                    fn=clear_inputs,
         
     | 
| 113 | 
         
            +
                    inputs=None,
         
     | 
| 114 | 
         
            +
                    outputs=[text_input] + outputs
         
     | 
| 115 | 
         
            +
                )
         
     | 
| 116 | 
         
            +
             
     | 
| 117 | 
         
            +
            demo.launch(share=True)
         
     | 
    	
        image_encodings_sample.pkl
    ADDED
    
    | 
         @@ -0,0 +1,3 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            version https://git-lfs.github.com/spec/v1
         
     | 
| 2 | 
         
            +
            oid sha256:c24e3a6233e299af2a3e6dd671919d4f6d35c1106893801dcf0d4eb29743869b
         
     | 
| 3 | 
         
            +
            size 110754
         
     | 
    	
        requirements.txt
    ADDED
    
    | 
         @@ -0,0 +1,2 @@ 
     | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            transformers==4.33.3
         
     | 
| 2 | 
         
            +
            torch==2.0.1
         
     | 
    	
        valid_images_sample.pkl
    ADDED
    
    | 
         @@ -0,0 +1,3 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            version https://git-lfs.github.com/spec/v1
         
     | 
| 2 | 
         
            +
            oid sha256:290351b793138815b5d85d6cc7c14a1a70caa97a3d84d641e99025a9dabcb87c
         
     | 
| 3 | 
         
            +
            size 13223
         
     |