pratyush19 committed
Commit 8ed71e7 · 1 Parent(s): ac6bc7f
PIL_images.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c2a1399bf61a38aa1921f7d49cf94a6668a7df4c42f85d5d120e437406455540
+ size 62989209
app.py ADDED
@@ -0,0 +1,117 @@
+ import gradio as gr
+ import pickle, os
+ import pandas as pd
+ import numpy as np
+ import os
+ from transformers import CLIPProcessor, CLIPModel
+ from datasets import load_dataset
+
+ from PIL import Image
+ import requests
+ from io import BytesIO
+
+ model = CLIPModel.from_pretrained("patrickjohncyh/fashion-clip")
+ processor = CLIPProcessor.from_pretrained("patrickjohncyh/fashion-clip")
+
+ # hf_token = os.environ.get("HF_API_TOKEN")
+ # dataset = load_dataset('pratyush19/cyborg', use_auth_token=hf_token, split='train')
+
+ # dir_path = "train/"
+ # print (dataset)
+ # print (dataset[0].keys())
+
+ with open('valid_images_sample.pkl','rb') as f:
+     valid_images = pickle.load(f)
+
+ with open('image_encodings_sample.pkl','rb') as f:
+     image_encodings = pickle.load(f)
+ valid_images = np.array(valid_images)
+
+ with open('PIL_images.pkl','rb') as f:
+     PIL_images = pickle.load(f)
+
+
+ def softmax(x):
+     e_x = np.exp(x - np.max(x, axis=1, keepdims=True))
+     return e_x / e_x.sum(axis=1, keepdims=True)
+
+ def find_similar_images(caption, image_encodings):
+     inputs = processor(text=[caption], return_tensors="pt")
+     text_features = model.get_text_features(**inputs)
+     text_features = text_features.detach().numpy()
+     logits_per_image = softmax(np.dot(text_features, image_encodings.T))
+     return logits_per_image
+
+ def find_relevant_images(caption):
+     similarity_scores = find_similar_images(caption, image_encodings)[0]
+     top_indices = np.argsort(similarity_scores)[::-1][:16]
+     # top_path = valid_images[top_indices]
+     images = []
+     for idx in top_indices:
+         images.append(PIL_images[idx])
+     return images
+
+ def gradio_interface(input_text):
+     # with open("user_inputs.txt", "a") as file:
+     #     file.write(input_text + "\n")
+     images = find_relevant_images(input_text)
+     return images
+
+
+ def clear_inputs():
+     return [None, None, None, None, None, None, None, None, None,
+             None, None, None, None, None, None, None, None]
+
+ outputs = [None]*16
+
+ with gr.Blocks(title="MirrAI") as demo:
+     gr.Markdown("<h1 style='text-align: center;'>MirrAI: GenAI-based Fashion Search</h1>")
+     gr.Markdown("Enter a text to find the most relevant images from our dataset.")
+
+     text_input = gr.Textbox(lines=1, label="Input Text", placeholder="Enter your text here...")
+     with gr.Row():
+         cancel_button = gr.Button("Cancel")
+         submit_button = gr.Button("Submit")
+     examples = gr.Examples(["high-rise flare jean",
+                             "a-line dress with floral",
+                             "men colorful blazers",
+                             "jumpsuit with puffed sleeve",
+                             "sleeveless sweater",
+                             "floral shirt",
+                             "blue asymmetrical wedding dress with one sleeve",
+                             "women long coat",
+                             "cardigan sweater"], inputs=[text_input])
+
+     with gr.Row():
+         outputs[0] = gr.Image()
+         outputs[1] = gr.Image()
+         outputs[2] = gr.Image()
+         outputs[3] = gr.Image()
+     with gr.Row():
+         outputs[4] = gr.Image()
+         outputs[5] = gr.Image()
+         outputs[6] = gr.Image()
+         outputs[7] = gr.Image()
+     with gr.Row():
+         outputs[8] = gr.Image()
+         outputs[9] = gr.Image()
+         outputs[10] = gr.Image()
+         outputs[11] = gr.Image()
+     with gr.Row():
+         outputs[12] = gr.Image()
+         outputs[13] = gr.Image()
+         outputs[14] = gr.Image()
+         outputs[15] = gr.Image()
+
+     submit_button.click(
+         fn=gradio_interface,
+         inputs=text_input,
+         outputs=outputs
+     )
+     cancel_button.click(
+         fn=clear_inputs,
+         inputs=None,
+         outputs=[text_input] + outputs
+     )
+
+ demo.launch(share=True)
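app.py assumes that image_encodings_sample.pkl already holds one fashion-clip embedding per image in PIL_images.pkl; the commit does not include the script that produced that file. Below is a minimal sketch of how such a file could be precomputed with the same checkpoint. The file names, batch size, and absence of L2-normalization are assumptions chosen to mirror what app.py loads, not part of the commit.

```python
# Sketch only (not in this commit): precompute image embeddings with fashion-clip
# so that app.py's text-to-image dot products compare like with like.
import pickle

import numpy as np
import torch
from transformers import CLIPModel, CLIPProcessor

model = CLIPModel.from_pretrained("patrickjohncyh/fashion-clip")
processor = CLIPProcessor.from_pretrained("patrickjohncyh/fashion-clip")

with open("PIL_images.pkl", "rb") as f:
    pil_images = pickle.load(f)  # assumed: a list of PIL.Image objects

batch_size = 32  # assumed; any size that fits in memory works
chunks = []
with torch.no_grad():
    for start in range(0, len(pil_images), batch_size):
        batch = pil_images[start:start + batch_size]
        inputs = processor(images=batch, return_tensors="pt")
        feats = model.get_image_features(**inputs)  # (batch, projection_dim)
        chunks.append(feats.cpu().numpy())

image_encodings = np.concatenate(chunks, axis=0)

with open("image_encodings_sample.pkl", "wb") as f:
    pickle.dump(image_encodings, f)
```

One design note: find_similar_images ranks by a softmax over raw dot products, and softmax is monotonic, so the ordering is the same as ranking by dot product. If both the stored image embeddings and the text features were L2-normalized, that ranking would correspond to cosine similarity.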
image_encodings_sample.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c24e3a6233e299af2a3e6dd671919d4f6d35c1106893801dcf0d4eb29743869b
+ size 110754
requirements.txt ADDED
@@ -0,0 +1,2 @@
+ transformers==4.33.3
+ torch==2.0.1
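Besides transformers and torch, app.py also imports gradio, datasets, pandas, numpy, Pillow, and requests. If this runs as a Gradio Space, gradio itself is typically supplied by the Space SDK, but the remaining packages must still be installable. A sketch of a fuller pin list follows; versions are left unpinned because the commit does not specify them.

```
transformers==4.33.3
torch==2.0.1
datasets
pandas
numpy
Pillow
requests
```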
valid_images_sample.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:290351b793138815b5d85d6cc7c14a1a70caa97a3d84d641e99025a9dabcb87c
+ size 13223