shibing624 committed on
Commit
7084f70
1 Parent(s): c948f23

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +202 -4
app.py CHANGED
@@ -1,7 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
 
6
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- iface.launch()
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ @author:XuMing(xuming624@qq.com)
4
+ @description:
5
+ """
6
+ import base64
7
+ import glob
8
+ import json
9
+ import os
10
+ import pprint
11
+ import sys
12
+ import zipfile
13
+ from io import BytesIO
14
+ from pathlib import Path
15
+
16
+ import faiss
17
  import gradio as gr
18
+ import numpy as np
19
+ import pandas as pd
20
+ import requests
21
+ from PIL import Image
22
+ from loguru import logger
23
+ from tqdm import tqdm
24
+
25
+ sys.path.append('..')
26
+ from similarities.utils.get_file import http_get
27
+ from similarities.clip_module import ClipModule
28
+
29
+
30
def batch_search_index(
        queries,
        model,
        faiss_index,
        df,
        num_results,
        threshold,
        debug=False,
):
    """
    Search index with image inputs or image paths (batch search)
    :param queries: list of image paths or list of image inputs or texts or embeddings (np.ndarray)
    :param model: CLIP model, must provide `encode(queries, normalize_embeddings=True)`
    :param faiss_index: faiss index
    :param df: corpus dataframe; `df.iloc[corpus_id]` is the metadata for one indexed item
    :param num_results: int, number of results to return (used only when threshold is None)
    :param threshold: float, similarity threshold to return results; if set, range_search is used
    :param debug: bool, whether to print debug info, default False
    :return: list with one entry per query; each entry is [(item_json, score, corpus_id), ...]
        sorted by score descending
    """
    assert queries is not None, "queries should not be None"
    result = []
    # Embeddings may be passed in directly as an ndarray; otherwise encode with the model.
    if isinstance(queries, np.ndarray):
        query_features = queries
    else:
        query_features = model.encode(queries, normalize_embeddings=True)

    for query, query_feature in zip(queries, query_features):
        query_feature = query_feature.reshape(1, -1)
        if threshold is not None:
            # range_search returns flat (lims, distances, ids) arrays for the single query.
            _, d, i = faiss_index.range_search(query_feature, threshold)
            if debug:
                logger.debug(f"Found {i.shape} items with query '{query}' and threshold {threshold}")
        else:
            # search returns (1, num_results)-shaped arrays; take the first (only) row.
            d, i = faiss_index.search(query_feature, num_results)
            i = i[0]
            d = d[0]
        # Sorted faiss search result with distance
        text_scores = []
        for ed, ei in zip(d, i):
            # Convert to json, avoid float values error
            item = df.iloc[ei].to_json(force_ascii=False)
            if debug:
                logger.debug(f"Found: {item}, similarity: {ed}, id: {ei}")
            text_scores.append((item, float(ed), int(ei)))
        # Sort by score desc
        query_result = sorted(text_scores, key=lambda x: x[1], reverse=True)
        result.append(query_result)
    return result
79
+
80
+
81
def preprocess_image(image_input) -> Image.Image:
    """
    Normalize any supported query format (URL or file-path string, numpy pixel
    array, or base64-encoded bytes) into a ``PIL.Image.Image`` object.
    """
    if isinstance(image_input, np.ndarray):
        # Raw pixel array, e.g. from an image widget.
        return Image.fromarray(image_input)
    if isinstance(image_input, bytes):
        # Base64-encoded image payload.
        decoded = base64.b64decode(image_input)
        return Image.open(BytesIO(decoded))
    if isinstance(image_input, str):
        if image_input.startswith('http'):
            # Remote image: stream it over HTTP and open directly from the socket.
            response = requests.get(image_input, stream=True)
            return Image.open(response.raw)
        if image_input.endswith((".png", ".jpg", ".jpeg", ".bmp")) and os.path.isfile(image_input):
            # Local image file with a recognized extension.
            return Image.open(image_input)
        raise ValueError(f"Unsupported image input type, image path: {image_input}")
    raise ValueError(f"Unsupported image input type, image input: {image_input}")
99
+
100
+
101
def main():
    """Build the image corpus if needed, load the CLIP model and Faiss index,
    then launch the Gradio image-search demo."""
    # we get about 25k images from Unsplash
    img_folder = 'photos/'
    clip_folder = 'photos/csv/'
    # Only (re)build the corpus CSV when the csv folder is missing or empty.
    if not os.path.exists(clip_folder) or len(os.listdir(clip_folder)) == 0:
        os.makedirs(img_folder, exist_ok=True)

        photo_filename = 'unsplash-25k-photos.zip'
        if not os.path.exists(photo_filename):  # Download dataset if not exist
            http_get('http://sbert.net/datasets/' + photo_filename, photo_filename)

        # Extract all images
        with zipfile.ZipFile(photo_filename, 'r') as zf:
            for member in tqdm(zf.infolist(), desc='Extracting'):
                zf.extract(member, img_folder)
        # Corpus metadata: one row per extracted image (path + basename).
        df = pd.DataFrame({'image_path': glob.glob(img_folder + '/*'),
                           'image_name': [os.path.basename(x) for x in glob.glob(img_folder + '/*')]})
        os.makedirs(clip_folder, exist_ok=True)
        df.to_csv(f'{clip_folder}/unsplash-25k-photos.csv', index=False)

    # Pre-built search engine artifacts: Faiss index + parquet corpus shards.
    index_dir = 'clip_engine_25k/image_index/'
    index_name = "faiss.index"
    corpus_dir = 'clip_engine_25k/corpus/'
    model_name = "OFA-Sys/chinese-clip-vit-base-patch16"

    logger.info("starting boot of clip server")
    index_file = os.path.join(index_dir, index_name)
    assert os.path.exists(index_file), f"index file {index_file} not exist"
    faiss_index = faiss.read_index(index_file)
    model = ClipModule(model_name_or_path=model_name)
    # Concatenate all parquet shards in filename order; row order must match the
    # Faiss index ids, since batch_search_index looks rows up with df.iloc[id].
    df = pd.concat(pd.read_parquet(parquet_file) for parquet_file in sorted(Path(corpus_dir).glob("*.parquet")))
    logger.info(f'Load model success. model: {model_name}, index: {faiss_index}, corpus size: {len(df)}')

    def image_path_to_base64(image_path: str) -> str:
        # Encode a local image file as base64 so it can be inlined as a data: URI.
        with open(image_path, "rb") as image_file:
            img_str = base64.b64encode(image_file.read()).decode("utf-8")
        return img_str

    def search_image(text="", image=None):
        """Search the index by text OR by an uploaded image (file path) and
        render the top-5 hits as an HTML gallery with similarity scores."""
        html_output = ""

        if not text and not image:
            return "<p>Please provide either text or image input.</p>"

        if text and image is not None:
            return "<p>Please provide either text or image input, not both.</p>"

        if image is not None:
            # `image` is a file path (gr.Image type="filepath" below).
            q = [preprocess_image(image)]
            results = batch_search_index(q, model, faiss_index, df, 5, None, debug=False)[0]
            image_src = "data:image/jpeg;base64," + image_path_to_base64(image)
            html_output += f'Query: <img src="{image_src}" width="200" height="200"><br>'
        else:
            q = [text]
            results = batch_search_index(q, model, faiss_index, df, 5, None, debug=False)[0]
            html_output += f'Query: {text}<br>'

        html_output += f'Result Size: {len(results)}<br>'
        for result in results:
            item, similarity_score, _ = result
            item_dict = json.loads(item)
            image_path = item_dict.get("image_path", "")
            # Full metadata record shown as a hover tooltip on the thumbnail.
            tip = pprint.pformat(item_dict)
            if not image_path:
                continue
            if image_path.startswith("http"):
                image_src = image_path
            else:
                image_src = "data:image/jpeg;base64," + image_path_to_base64(image_path)
            # Thumbnail with the similarity score overlaid in the bottom-right corner.
            html_output += f'<div style="display: inline-block; position: relative; margin: 10px;">'
            html_output += f'<img src="{image_src}" width="200" height="200" title="{tip}">'
            html_output += f'<div style="position: absolute; bottom: 0; right: 0; background-color: rgba(255, 255, 255, 0.7); padding: 2px 5px;">'
            html_output += f'Score: {similarity_score:.4f}'
            html_output += f'</div></div>'

        return html_output

    def reset_user_input():
        # Clear both the text box and the image upload.
        return '', None

    with gr.Blocks() as demo:
        gr.HTML("""<h1 align="center">CLIP Image Search</h1>""")
        gr.Markdown(
            "> Search for similar images using Faiss and Chinese-CLIP. Link to Github: [similarities](https://github.com/shibing624/similarities)")
        with gr.Tab("Text"):
            with gr.Row():
                with gr.Column():
                    input_text = gr.Textbox(lines=2, placeholder="Enter text here...")

        with gr.Tab("Image"):
            with gr.Row():
                with gr.Column():
                    input_image = gr.Image(type="filepath", label="Upload an image")

        btn_submit = gr.Button(label="Submit")
        output = gr.outputs.HTML(label="Search results")
        btn_submit.click(search_image, inputs=[input_text, input_image],
                         outputs=output, show_progress=True)
        # NOTE(review): a second handler on the same button clears both inputs on
        # every submit — confirm this is intended to run alongside the search.
        btn_submit.click(reset_user_input, outputs=[input_text, input_image])

    demo.queue().launch()
202
 
 
 
203
 
204
if __name__ == '__main__':
    # Script entry point: build/load the search engine and launch the Gradio demo.
    main()