RobotJelly committed on
Commit
087fe06
1 Parent(s): 05170c1
Files changed (1)
  1. app.py +89 -0
app.py ADDED
@@ -0,0 +1,89 @@
+ # Import libraries
+ import pandas as pd
+ import numpy as np
+ import torch
+ import clip
+ from PIL import Image
+ from io import BytesIO
+ import requests
+ import gradio as gr
+
+ # Check if CUDA is available, then load OpenAI's CLIP model onto that device
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
+
+ # Download the matched photos from Unsplash and return them as PIL images
+ def show_output_image(matched_images):
+     images = []
+     for photo_id in matched_images:
+         photo_image_url = f"https://unsplash.com/photos/{photo_id}/download?w=280"
+         response = requests.get(photo_image_url)
+         images.append(Image.open(BytesIO(response.content)))
+     return images
+
+ # Encode and normalize the search query using CLIP
+ def encode_search_query(search_query, model, device):
+     with torch.no_grad():
+         text_encoded = model.encode_text(clip.tokenize(search_query).to(device))
+         text_encoded /= text_encoded.norm(dim=-1, keepdim=True)
+     # Retrieve the feature vector from the GPU and convert it to a numpy array
+     return text_encoded.cpu().numpy()
+
+ # Find the photos that best match the query features
+ def find_matches(text_features, photo_features, photo_ids, results_count=4):
+     # Both feature sets are normalized, so the dot product is the cosine similarity
+     similarities = (photo_features @ text_features.T).squeeze(1)
+     # Sort the photos by their similarity score
+     best_photo_idx = (-similarities).argsort()
+     # Return the photo IDs of the best matches
+     return [photo_ids[i] for i in best_photo_idx[:results_count]]
+
+ def image_search(search_text, search_image, option):
+     # Load the precomputed photo IDs
+     photo_ids = pd.read_csv("./photo_ids.csv")
+     photo_ids = list(photo_ids['photo_id'])
+
+     # Load the precomputed feature vectors
+     photo_features = np.load("./features.npy")
+
+     if option == "Text-To-Image":
+         # Extract the text features
+         text_features = encode_search_query(search_text, model, device)
+
+         # Find the matched images
+         matched_images = find_matches(text_features, photo_features, photo_ids, 4)
+
+         # ---- debug purpose ------#
+         photo_id = matched_images[0]
+         print(photo_id)
+         print(f"https://unsplash.com/photos/{photo_id}/download?w=280")
+         # -------------------------#
+
+         return show_output_image(matched_images)
+     elif option == "Image-To-Image":
+         # Encode and normalize the input image for search
+         with torch.no_grad():
+             image_feature = model.encode_image(preprocess(search_image).unsqueeze(0).to(device))
+             image_feature = (image_feature / image_feature.norm(dim=-1, keepdim=True)).cpu().numpy()
+         # Find the matched images
+         matched_images = find_matches(image_feature, photo_features, photo_ids, 4)
+         return show_output_image(matched_images)
+
+ gr.Interface(fn=image_search,
+              inputs=[gr.inputs.Textbox(lines=7, label="Input Text"),
+                      gr.inputs.Image(type="pil", optional=True),
+                      gr.inputs.Dropdown(["Text-To-Image", "Image-To-Image"])],
+              outputs=gr.outputs.Carousel([gr.outputs.Image(type="pil"),
+                                           gr.outputs.Image(type="pil"),
+                                           gr.outputs.Image(type="pil"),
+                                           gr.outputs.Image(type="pil")]),
+              enable_queue=True
+              ).launch(debug=True)
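
Note: the app expects two precomputed artifacts next to app.py: photo_ids.csv (a photo_id column of Unsplash IDs) and features.npy (the corresponding CLIP image features, one normalized vector per photo). Because both the photo features and the encoded query are unit-normalized, the matrix product in find_matches is exactly cosine similarity. A minimal, self-contained sketch of that ranking step, using random toy vectors as stand-ins for the real features.npy (assuming 512-dim ViT-B/32 embeddings):

import numpy as np

# Toy stand-ins for the precomputed photo features and an encoded query
rng = np.random.default_rng(0)
photo_features = rng.normal(size=(10, 512)).astype(np.float32)
photo_features /= np.linalg.norm(photo_features, axis=1, keepdims=True)

text_features = rng.normal(size=(1, 512)).astype(np.float32)
text_features /= np.linalg.norm(text_features, axis=1, keepdims=True)

# Dot product of unit vectors == cosine similarity; result has shape (10,)
similarities = (photo_features @ text_features.T).squeeze(1)
best_photo_idx = (-similarities).argsort()[:4]  # indices of the top-4 matches
print(best_photo_idx, similarities[best_photo_idx])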