aleo1 committed
Commit 6097744 (parent: e9cc8b9)

Upload 6 files

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+ LuojiaHOG(best)_.json filter=lfs diff=lfs merge=lfs -text
LuojiaHOG(best)_.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ded475ab0cc7bcd517a9a3845b09fcf8ae0ca6466c19faf9cfe146a1837fc09f
+ size 34819318
app.py ADDED
@@ -0,0 +1,533 @@
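+ # Gradio demo for cross-modal retrieval on the LuojiaHOG remote-sensing dataset:
+ # text and image queries are encoded with the pretrained CISEN model and matched
+ # against pre-computed corpus embeddings using exact FAISS (L2) search.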
+ import time
+ import os
+ os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
+ import zipfile
+ from io import BytesIO
+ from PIL import Image
+ import numpy as np
+ import argparse
+ import faiss
+ import gradio as gr
+ import pandas as pd
+ import pickle
+ import cisen.utils.config as config
+ from cisen.utils.dataset import tokenize
+ from torchvision import transforms
+ from get_data_by_image_id import read_json
+ from cisen.model.segmenter import CISEN_rsvit_hug
+ 
+ transform = transforms.Compose([
+     transforms.Resize(224),
+     transforms.CenterCrop(224),
+     transforms.ToTensor(),
+     transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
+ ])
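+ # The Normalize mean/std above are the standard CLIP image-preprocessing statistics.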
+ 
+ def get_parser():
+     parser = argparse.ArgumentParser(
+         description='Pytorch Referring Expression Segmentation')
+     parser.add_argument('--config',
+                         default='./cisen.yaml',
+                         type=str,
+                         help='config file')
+     parser.add_argument('--opts',
+                         default=None,
+                         nargs=argparse.REMAINDER,
+                         help='override some settings in the config.')
+     args = parser.parse_args()
+     assert args.config is not None
+     cfg = config.load_cfg_from_cfg_file(args.config)
+     if args.opts is not None:
+         cfg = config.merge_cfg_from_list(cfg, args.opts)
+     return cfg
+ 
+ args = get_parser()
+ data_dir = './LuojiaHOG(best)_.json'
+ imgs_folder = 'image/'
+ 
+ # image_id = 'sample44_1641.jpg'
+ # model_path = './rsvit.pth'
+ 
+ # Pre-computed corpus embeddings: pickled {image_id: 512-d feature} dicts
+ with open('image_features_best.pkl', 'rb') as f:
+     image_dict = pickle.load(f)
+     image_feat = np.array(list(image_dict.values()))
+ with open('text_features_best.pkl', 'rb') as f:
+     text_dict = pickle.load(f)
+     text_feat = np.array(list(text_dict.values()))
+ # with open('./LuojiaHOG(best)_.pkl', 'rb') as f:
+ #     data_info = pickle.load(f)
+ 
+ sample_info = np.array(list(image_dict))  # image ids, in the same order as the feature rows
+ data_info = read_json(data_dir)
+ config = {"embed_dim": 512, "image_resolution": 224, "vision_layers": 12, "vision_width": 768,
+           "vision_patch_size": 32, "context_length": 328, "txt_length": 328, "vocab_size": 49408,
+           "transformer_width": 512, "transformer_heads": 8, "transformer_layers": 12, "patch_size": 32,
+           "output_dim": 512, "ratio": 0.9, "emb_dim": 768, "fpn_in": [512, 768, 768], "fpn_out": [768, 768, 768, 512]}
+ model = CISEN_rsvit_hug(**config)
+ model = model.from_pretrained("aleo1/cisen")
+ 
+ # img, img_, caption, image_feature, label, label_en, lat, lon = read_by_image_id(data_dir, imgs_folder, feature_folder)
+ 
+ # Prepare the data
+ # data = np.random.rand(1000, 512).astype(np.float32)  # e.g. 1000 random vectors of dimension 512
+ 
+ # Build the indexes: flat (exact) FAISS indexes using the L2 distance
+ image_index = faiss.IndexFlatL2(512)
+ text_index = faiss.IndexFlatL2(512)
+ # Add the pre-computed embeddings to the indexes
+ image_index.add(image_feat)
+ text_index.add(text_feat)
+ 
+ # Example text queries
+ text1 = "A rectangular sports field with green artificial turf is visible. The field has white boundary lines and a bright blue surrounding track. Adjacent buildings with flat, gray roofs are visible. Roads with marked lanes run alongside the buildings. A red-roofed structure stands near the sports field. Vegetation includes small, scattered trees with green foliage. Cars are parked along the roads. Shadows cast by the buildings indicate different heights. Pedestrian pathways are present alongside the roads. The image contains a mix of recreational and residential zones. The layout suggests a planned urban environment."
+ text2 = "The picture shows a wetland full of diverse plants. The area has a network of waterways and thick vegetation, mainly tall reeds and cattails with slender, bamboo-like stalks. The scene is mostly green, with touches of blue and brown, creating a peaceful vibe. The land is mostly flat, allowing a wide view of the wetland. The image is taken from high above, giving a bird's-eye view of the waterways and aquatic plants in the wetland ecosystem."
+ text3 = "The residential area is depicted in the color remote sensing image with a bird's-eye view. The scene shows a heterogeneous mix of houses with varying shapes and sizes, spread across the area. The houses are painted in different colors, with some having white walls and red-tiled roofs, while others have blue or green exteriors. The residential area also contains several green spaces, including small front yards, larger parks, and gardens with various types of trees and shrubs. A broad road runs through the center of the residential area, connecting different parts of the community. The road is lined with trees on both sides and has a designated sidewalk for pedestrians. The image also captures various other elements of the urban landscape, including utility poles, streetlights, and a few commercial buildings on the outskirts of the residential area."
+ text4 = "The image depicts a nature reserve on an island, with a landscape dominated by sparse shrubs and meadows in the interior. The color of the image is predominantly green, with hints of brown and yellow, representing the different types of vegetation and soil. The reserve is characterized by rolling hills and gentle valleys, with some areas of flat terrain interspersed throughout. The landscape is dotted with trees, which are scattered randomly and have a relatively low density."
+ text5 = "The image shows a nature reserve on an island, featuring a landscape mainly covered with sparse shrubs and meadows. The dominant color is green, with touches of brown and yellow indicating various vegetation types and soil. The reserve has rolling hills and gentle valleys, along with some flat areas. Trees are scattered randomly across the landscape, with a relatively low density."
+ text6 = "The image describes a scene of a residential area with several houses situated next to a large, open stretch of land, which serves as a waste land. The houses are single-story structures with rectangular shapes and are evenly distributed across the scene. They have a pale blue color with a hint of white, which suggests that they are constructed using plastered walls. The waste land is covered in a mixture of brown and green colors, with patches of dry grass and scattered shrubs. The trees around the houses are slender and tall, with a lush green canopy that provides shade to the area. The scene is captured from a high altitude, offering a bird's-eye view of the area. The houses and trees are clearly distinguishable, and the waste land appears as a large, empty space in the center of the image."
+ text7 = "The neighborhood in the picture has streets arranged in a grid pattern with consistent, low-rise buildings. The buildings are mainly brown, red, and yellow, suggesting a blend of modern and traditional styles. The houses are surrounded by different kinds of greenery, such as small trees, bushes, and tall grasses."
+ text8 = "The color remote sensing image shows a residential area from a bird's-eye view, revealing a mix of houses of different shapes and sizes spread throughout the area. The houses are painted in various colors, with some having white walls and red-tiled roofs, while others feature blue or green exteriors. The neighborhood includes several green spaces, such as small front yards, larger parks, and gardens with different types of trees and shrubs. A wide road runs through the center, linking different parts of the community. This road is lined with trees and has sidewalks for pedestrians. The image also shows other urban features like utility poles, streetlights, and a few commercial buildings on the edges of the residential area."
+ text9 = "The image shows a barren and desolate landscape with little to no vegetation, dominated by a uniform color palette. The primary color is white, with some patches of gray and black. The terrain is mostly flat, with minimal changes in elevation. A white road is the only noticeable feature in the scene."
+ text10 = "The image shows a residential area with several single-story houses next to a large open stretch of wasteland. The houses are rectangular, evenly spaced, and have pale blue walls with hints of white, suggesting they are plastered. The wasteland is a mix of brown and green, featuring patches of dry grass and scattered shrubs. Surrounding the houses are tall, slender trees with lush green canopies providing shade. The scene is captured from a high altitude, giving a bird's-eye view where the houses, trees, and the central wasteland are clearly visible."
+ text11 = "The color remote sensing image shows an urban city street from a high altitude. The street is flanked by tall, sleek buildings featuring a mix of modern and traditional architecture, mostly in white and beige, with some more colorful facades. The street is busy with cars in white, black, silver, and gold, and pedestrians of diverse ethnicities wearing both modern and traditional clothing. Tall, lush trees with various shades of green line the street. The sky is bright blue with a few fluffy clouds. The image is high quality, with clear, visible details."
+ text12 = "The image displays a residential area with houses arranged in a grid-like pattern, each having a small yard. The houses are mostly uniform in size and shape, featuring pitched roofs and rectangular windows. They come in a variety of colors, from bright ones like yellow and pink to more neutral tones like white and gray. Trees of different sizes and shapes are scattered throughout the area. A parking lot next to the houses is mostly filled with cars, though some spots are empty. Sidewalks and streets connect the houses to other parts of the island."
+ 
+ # Example images shown in the clickable gallery of examples
+ image_folder = './example_image/'
+ image_files = [os.path.join(image_folder, filename) for filename in os.listdir(image_folder) if
+                filename.endswith('.jpg')]
+ image_list = []
+ for image_file in image_files:
+     image_list.append([Image.open(image_file)])
+ 
+ # Search function: routes the query to the text index or the image index
+ def search(text_query, image_query, top_k: int = 10):
+ 
+     # 1. Embed the query as float32
+     # Encode the query with the pretrained model into a float feature vector.
+     start_time = time.time()
+ 
+     # query_embedding = model.encode(query)
+     if image_query is None:
+         text = tokenize(text_query, 328)
+         query_vector = model.text_encode(text)
+         index = text_index
+     else:
+         print(text_query)
+         print(image_query)
+ 
+         image_query = transform(Image.fromarray(image_query))
+         query_vector = model.image_encode(image_query.unsqueeze(0))
+         index = image_index
+ 
+     embed_time = time.time() - start_time
+     query_vector = np.array(query_vector.detach().numpy())
+     # 2. Quantize the query to ubinary
+     # (disabled) Convert the float vector to a binary vector so it can be matched against a binary index.
+     # start_time = time.time()
+     # query_embedding_ubinary = quantize_embeddings(query_embedding.reshape(1, -1), "ubinary")
+     # quantize_time = time.time() - start_time
+ 
+     # 3. Search the index (exact FAISS search) for the nearest neighbours of the query vector
+     # index = binary_ivf if use_approx else binary_index
+     # index = binary_index
+     start_time = time.time()
+     # _scores, binary_ids = index.search(query_embedding_ubinary, top_k * rescore_multiplier)
+     _scores, binary_ids = index.search(query_vector, top_k)
+     binary_ids = binary_ids[0]
+     search_time = time.time() - start_time
+ 
+     # # 4. Load the corresponding int8 embeddings
+     # # (disabled) Load the int8 embeddings of the hits, stored during preprocessing.
+     # start_time = time.time()
+     # int8_embeddings = int8_view[binary_ids].astype(int)
+     # load_time = time.time() - start_time
+     #
+     # # 5. Rescore the top_k * rescore_multiplier hits using the float32 query and the int8 document embeddings
+     # start_time = time.time()
+     # scores = data @ int8_embeddings.T
+     # rescore_time = time.time() - start_time
+ 
+     # 6. Sort the scores and return the top_k results
+     start_time = time.time()
+     indices = _scores.argsort()[::-1][:top_k]
+     top_k_indices = binary_ids[indices]
+     # Image ids of the retrieved samples
+     info = list(sample_info[top_k_indices])[0]
+ 
+     top_k_scores = list(_scores)[0]
+     top_k_score = [np.round(value, 2) for value in top_k_scores]
+ 
+     top_k_labels, top_k_texts, lat, lon = zip(
+         *[(data_info[str(idx)]["label_name"], data_info[str(idx)]["description"], data_info[str(idx)]["lat"],
+            data_info[str(idx)]["lon"]) for idx in info]
+     )
+     # df = pd.DataFrame(
+     #     {"Score": [torch.round(torch.tensor(value)*100)/100 for value in top_k_scores], "Title": top_k_titles, "Text": top_k_texts}
+     # )
+     # Load the retrieved images from the zip archive
+     if text_query is not None:
+         # image_output = [Image.open(imgs_folder + img.replace('_','/')) for img in info]
+         image_output = []
+         with zipfile.ZipFile('./image.zip', 'r') as zip_ref:
+             for img in info:
+                 # Read the image file from the archive and decode the bytes into a PIL image
+                 with zip_ref.open(imgs_folder + img.replace('_', '/')) as image_file:
+                     image = Image.open(BytesIO(image_file.read()))
+                     image_output.append(image)
+     else:
+         image_output = []
+ 
+     df = pd.DataFrame(
+         {"Distance": top_k_score, 'Latitude': lat, 'Longitude': lon, "Description": top_k_texts}
+     )
+     df = df.round({"Distance": 2, 'Latitude': 4, 'Longitude': 4})
+     sort_time = time.time() - start_time
+ 
+     return df, image_output, {
+         "Embed Time": f"{embed_time:.4f} s",
+         # "Quantize Time": f"{quantize_time:.4f} s",
+         "Search Time": f"{search_time:.4f} s",
+         # "Load Time": f"{load_time:.4f} s",
+         # "Rescore Time": f"{rescore_time:.4f} s",
+         "Sort Time": f"{sort_time:.4f} s",
+         "Total Retrieval Time": f"{search_time + sort_time:.4f} s",
+     }
+ 
+ def img_search(image_query, top_k: int = 10):
+ 
+     # 1. Embed the query image as float32
+     # Encode the query image with the pretrained model into a float feature vector.
+     start_time = time.time()
+ 
+     # query_embedding = model.encode(query)
+     image_query = transform(Image.fromarray(image_query))
+     query_vector = model.image_encode(image_query.unsqueeze(0))
+     index = image_index
+ 
+     embed_time = time.time() - start_time
+     query_vector = np.array(query_vector.detach().numpy())
+     # 2. Quantize the query to ubinary
+     # (disabled) Convert the float vector to a binary vector so it can be matched against a binary index.
+     # start_time = time.time()
+     # query_embedding_ubinary = quantize_embeddings(query_embedding.reshape(1, -1), "ubinary")
+     # quantize_time = time.time() - start_time
+ 
+     # 3. Search the index (exact FAISS search) for the nearest neighbours of the query vector
+     # index = binary_ivf if use_approx else binary_index
+     # index = binary_index
+     start_time = time.time()
+     # _scores, binary_ids = index.search(query_embedding_ubinary, top_k * rescore_multiplier)
+     _scores, binary_ids = index.search(query_vector, top_k)
+     binary_ids = binary_ids[0]
+     search_time = time.time() - start_time
+ 
+     # # 4. Load the corresponding int8 embeddings
+     # # (disabled) Load the int8 embeddings of the hits, stored during preprocessing.
+     # start_time = time.time()
+     # int8_embeddings = int8_view[binary_ids].astype(int)
+     # load_time = time.time() - start_time
+     #
+     # # 5. Rescore the top_k * rescore_multiplier hits using the float32 query and the int8 document embeddings
+     # start_time = time.time()
+     # scores = data @ int8_embeddings.T
+     # rescore_time = time.time() - start_time
+ 
+     # 6. Sort the scores and return the top_k results
+     start_time = time.time()
+     indices = _scores.argsort()[::-1][:top_k]
+     top_k_indices = binary_ids[indices]
+     # Image ids of the retrieved samples
+     info = list(sample_info[top_k_indices])[0]
+ 
+     top_k_scores = list(_scores)[0]
+     top_k_score = [np.round(value, 2) for value in top_k_scores]
+ 
+     top_k_labels, top_k_texts, lat, lon = zip(
+         *[(data_info[str(idx)]["label_name"], data_info[str(idx)]["description"], data_info[str(idx)]["lat"],
+            data_info[str(idx)]["lon"]) for idx in info]
+     )
+     # df = pd.DataFrame(
+     #     {"Score": [torch.round(torch.tensor(value)*100)/100 for value in top_k_scores], "Title": top_k_titles, "Text": top_k_texts}
+     # )
+     # Load the retrieved images from the zip archive
+     image_output = []
+     with zipfile.ZipFile('./image.zip', 'r') as zip_ref:
+         for img in info:
+             with zip_ref.open(imgs_folder + img.replace('_', '/')) as image_file:
+                 image_output.append(Image.open(BytesIO(image_file.read())))
+ 
+     df = pd.DataFrame(
+         {"Distance": top_k_score, 'Latitude': lat, 'Longitude': lon, "Description": top_k_texts}
+     )
+     df = df.round({"Distance": 2, 'Latitude': 4, 'Longitude': 4})
+     sort_time = time.time() - start_time
+ 
+     return df, image_output, {
+         "Embed Time": f"{embed_time:.4f} s",
+         # "Quantize Time": f"{quantize_time:.4f} s",
+         "Search Time": f"{search_time:.4f} s",
+         # "Load Time": f"{load_time:.4f} s",
+         # "Rescore Time": f"{rescore_time:.4f} s",
+         "Sort Time": f"{sort_time:.4f} s",
+         "Total Retrieval Time": f"{search_time + sort_time:.4f} s",
+     }
+ 
+ def txt_search(text_query, top_k: int = 10):
+     # 1. Embed the query text as float32
+     # Encode the query string with the pretrained model into a float feature vector.
+     start_time = time.time()
+ 
+     # query_embedding = model.encode(query)
+     text = tokenize(text_query, 328)
+     query_vector = model.text_encode(text)
+     index = text_index
+ 
+     embed_time = time.time() - start_time
+     query_vector = np.array(query_vector.detach().numpy())
+     # 2. Quantize the query to ubinary
+     # (disabled) Convert the float vector to a binary vector so it can be matched against a binary index.
+     # start_time = time.time()
+     # query_embedding_ubinary = quantize_embeddings(query_embedding.reshape(1, -1), "ubinary")
+     # quantize_time = time.time() - start_time
+ 
+     # 3. Search the index (exact FAISS search) for the nearest neighbours of the query vector
+     # index = binary_ivf if use_approx else binary_index
+     # index = binary_index
+     start_time = time.time()
+     # _scores, binary_ids = index.search(query_embedding_ubinary, top_k * rescore_multiplier)
+     _scores, binary_ids = index.search(query_vector, top_k)
+     binary_ids = binary_ids[0]
+     search_time = time.time() - start_time
+ 
+     # # 4. Load the corresponding int8 embeddings
+     # # (disabled) Load the int8 embeddings of the hits, stored during preprocessing.
+     # start_time = time.time()
+     # int8_embeddings = int8_view[binary_ids].astype(int)
+     # load_time = time.time() - start_time
+     #
+     # # 5. Rescore the top_k * rescore_multiplier hits using the float32 query and the int8 document embeddings
+     # start_time = time.time()
+     # scores = data @ int8_embeddings.T
+     # rescore_time = time.time() - start_time
+ 
+     # 6. Sort the scores and return the top_k results
+     start_time = time.time()
+     indices = _scores.argsort()[::-1][:top_k]
+     top_k_indices = binary_ids[indices]
+     # Image ids of the retrieved samples
+     info = list(sample_info[top_k_indices])[0]
+ 
+     top_k_scores = list(_scores)[0]
+     top_k_score = [np.round(value, 2) for value in top_k_scores]
+ 
+     top_k_labels, top_k_texts, lat, lon = zip(
+         *[(data_info[str(idx)]["label_name"], data_info[str(idx)]["description"], data_info[str(idx)]["lat"],
+            data_info[str(idx)]["lon"]) for idx in info]
+     )
+     # df = pd.DataFrame(
+     #     {"Score": [torch.round(torch.tensor(value)*100)/100 for value in top_k_scores], "Title": top_k_titles, "Text": top_k_texts}
+     # )
+     # Load the retrieved images from the zip archive
+     if text_query is not None:
+         image_output = []
+         with zipfile.ZipFile('./image.zip', 'r') as zip_ref:
+             for img in info:
+                 with zip_ref.open(imgs_folder + img.replace('_', '/')) as image_file:
+                     image_output.append(Image.open(BytesIO(image_file.read())))
+     else:
+         image_output = []
+ 
+     df = pd.DataFrame(
+         {"Distance": top_k_score, 'Latitude': lat, 'Longitude': lon, "Description": top_k_texts}
+     )
+     df = df.round({"Distance": 2, 'Latitude': 4, 'Longitude': 4})
+     sort_time = time.time() - start_time
+ 
+     return df, image_output, {
+         "Embed Time": f"{embed_time:.4f} s",
+         # "Quantize Time": f"{quantize_time:.4f} s",
+         "Search Time": f"{search_time:.4f} s",
+         # "Load Time": f"{load_time:.4f} s",
+         # "Rescore Time": f"{rescore_time:.4f} s",
+         "Sort Time": f"{sort_time:.4f} s",
+         "Total Retrieval Time": f"{search_time + sort_time:.4f} s",
+     }
+ 
+ def update_visible(choice):
+     # Toggle the visibility of the text / image query inputs depending on the selected search mode
+     if choice is True:
+         return gr.Textbox(
+             label="Text query for remote sensing images",
+             placeholder="Enter a query to search for relevant images.",
+             visible=True,
+             interactive=True
+         ), gr.Image(
+             label="Upload an image",
+             visible=False
+         )
+     elif choice is False:
+         return gr.Textbox(
+             label="Text query for remote sensing images",
+             placeholder="Enter a query to search for relevant images.",
+             visible=False
+         ), gr.Image(
+             label="Upload an image",
+             visible=True,
+             interactive=True
+         )
+     else:
+         return gr.Textbox(
+             label="Text query for remote sensing images",
+             placeholder="Enter a query to search for relevant images.",
+             visible=True
+         ), gr.Image(
+             label="Upload an image",
+             visible=True,
+             interactive=True
+         )
+ 
+ with gr.Blocks(title="Image-Text Retrieval") as demo:
+     # gr.Markdown(
+     #     """
+     #     ## Quantized Retrieval - Binary Search with Scalar (int8) Rescoring
+     #     This demo showcases retrieval using [quantized embeddings](https://huggingface.co/blog/embedding-quantization) on a CPU. The corpus consists of 41 million texts from Wikipedia articles.
+     #
+     #     <details><summary>Click to learn about the retrieval process</summary>
+     #
+     #     Details:
+     #     1. The query is embedded using the [`mixedbread-ai/mxbai-embed-large-v1`](https://huggingface.co/mixedbread-ai/mxbai-embed-large-v1) SentenceTransformer model.
+     #     2. The query is quantized to binary using the `quantize_embeddings` function from the SentenceTransformers library.
+     #     3. A binary index (41M binary embeddings; 5.2GB of memory/disk space) is searched using the quantized query for the top 40 documents.
+     #     4. The top 40 documents are loaded on the fly from an int8 index on disk (41M int8 embeddings; 0 bytes of memory, 47.5GB of disk space).
+     #     5. The top 40 documents are rescored using the float32 query and the int8 embeddings to get the top 10 documents.
+     #     6. The top 10 documents are sorted by score and displayed.
+     #
+     #     This process is designed to be memory efficient and fast, with the binary index being small enough to fit in memory and the int8 index being loaded as a view to save memory.
+     #     In total, this process requires keeping 1) the model in memory, 2) the binary index in memory, and 3) the int8 index on disk. With a dimensionality of 1024,
+     #     we need `1024 / 8 * num_docs` bytes for the binary index and `1024 * num_docs` bytes for the int8 index.
+     #
+     #     This is notably cheaper than doing the same process with float32 embeddings, which would require `4 * 1024 * num_docs` bytes of memory/disk space for the float32 index, i.e. 32x as much memory and 4x as much disk space.
+     #     Additionally, the binary index is much faster (up to 32x) to search than the float32 index, while the rescoring is also extremely efficient. In conclusion, this process allows for fast, scalable, cheap, and memory-efficient retrieval.
+     #
+     #     Feel free to check out the [code for this demo](https://huggingface.co/spaces/sentence-transformers/quantized-retrieval/blob/main/app.py) to learn more about how to apply this in practice.
+     #
+     #     Notes:
+     #     - The approximate search index (a binary Inverted File Index (IVF)) is in beta and has not been trained with a lot of data. A better IVF index will be released soon.
+     #
+     #     </details>
+     #     """
+     # )
+ 
+     # Search mode selector: a radio group for choosing between the examples,
+     # image-to-text retrieval, and text-to-image retrieval
+     search_index = gr.Radio(
+         choices=[("Examples", None), ("Image-to-Text", False), ("Text-to-Image", True)],
+         value=None,
+         label="Search Index",
+     )
+ 
+     # Query input: a textbox where the user enters the query string
+     text_query = gr.Textbox(
+         label="Text query for remote sensing images",
+         placeholder="Enter a query to search for relevant images.",
+         visible=True,
+         interactive=True
+     )
+ 
+     # Image input: an image widget where the user uploads the query image
+     image_query = gr.Image(
+         label="Upload an image",
+         visible=True,
+         interactive=True
+     )
+     search_index.change(update_visible, search_index, [text_query, image_query])
+ 
+     # Retrieval parameters: a slider for how many results to retrieve, plus a panel showing retrieval times
+     with gr.Row():
+         with gr.Column(scale=2):
+             top_k = gr.Slider(
+                 minimum=10,
+                 maximum=100,
+                 step=5,
+                 value=10,
+                 interactive=True,
+                 label="Number of images/texts to retrieve",
+                 info="Number of images/texts to retrieve",
+             )
+         with gr.Column(scale=2):
+             json = gr.JSON(label='retrieval time')
+             # rescore_multiplier = gr.Slider(
+             #     minimum=1,
+             #     maximum=10,
+             #     step=1,
+             #     value=1,
+             #     interactive=True,
+             #     label="Rescore multiplier",
+             #     info="Search for `rescore_multiplier` as many documents to rescore",
+             # )
+ 
+     # Search button: clicking it triggers the retrieval
+     with gr.Row():
+         search_button = gr.Button(value="Search", variant='primary')
+         clear_button = gr.ClearButton(value='Clear Before Next Search')
+ 
+     # Results: a dataframe with the distance, coordinates, and description of each hit
+     with gr.Column():
+         output = gr.Dataframe(headers=["Distance", "Latitude", "Longitude", "Description"], label="Text outputs")
+ 
+     # Retrieved images
+     with gr.Row():
+         image_output = gr.Gallery(label="Image outputs")
+ 
+     # def update_layout():
+     #     if search_index.value:
+     #         return [search_index, text_query, top_k, rescore_multiplier]
+     #     else:
+     #         return [search_index, image_query, top_k, rescore_multiplier]
+ 
+     inputs = [search_index, text_query, image_query, top_k]
+     outputs = [output, json, image_output]
+ 
+     # Clickable examples: text queries routed to txt_search, image queries routed to img_search
+     # exp_txt = gr.Examples(examples=[[text1, None], [text2, None], [text3, None], [text4, None], [text5, None], [text6, None], [text7, None], [text8, None], [text9, None], [text10, None], [text11, None], [text12, None]],
+     #                       inputs=[text_query, image_query, top_k],
+     #                       outputs=[output, image_output, json], fn=search, run_on_click=False, examples_per_page=4, label="Text examples")
+     exp_txt = gr.Examples(examples=[[text1], [text2], [text3], [text4], [text5], [text6], [text7], [text8], [text9], [text10], [text11], [text12]],
+                           inputs=[text_query, top_k],
+                           outputs=[output, image_output, json], fn=txt_search, run_on_click=True, examples_per_page=4, label="Text examples", cache_examples='lazy')
+ 
+     exp_img = gr.Examples(examples=image_list, inputs=[image_query, top_k],
+                           outputs=[output, image_output, json], fn=img_search, run_on_click=True, examples_per_page=4, label="Image examples", cache_examples='lazy')
+ 
+     search_button.click(search, inputs=[text_query, image_query, top_k], outputs=[output, image_output, json])
+     clear_button.add(components=[text_query, image_query, output, image_output, json])
+ 
+ demo.queue()
+ demo.launch()
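For reference, a minimal sketch of the nearest-neighbour lookup app.py performs, outside of Gradio. It assumes image_features_best.pkl holds the same {image_id: 512-d float vector} dict loaded above; the variable names here are illustrative only:

    import pickle
    import numpy as np
    import faiss

    with open('image_features_best.pkl', 'rb') as f:
        feats = pickle.load(f)                                    # {image_id: 512-d embedding}
    ids = np.array(list(feats))                                   # image ids, row-aligned with the vectors
    index = faiss.IndexFlatL2(512)                                # exact L2 index, as in app.py
    index.add(np.array(list(feats.values()), dtype='float32'))

    query = np.array(list(feats.values())[:1], dtype='float32')  # stand-in for an encoded query vector
    dists, rows = index.search(query, 5)                          # L2 distances and row numbers of the top-5 hits
    print(list(zip(ids[rows[0]], dists[0])))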
cisen.yaml ADDED
@@ -0,0 +1,76 @@
+ DATA:
+   dataset: classification
+   dataset_json_file: /data02/xy/dataEngine/json_data/LuojiaHOG(test)_.json
+   # dataset_json_file: /data02/xy/dataEngine/json_data/merged_output_combined_9w_resplit.json
+   exp_name: classifi
+   ratio: 0.9
+   dataset_train_split: 0.6
+   dataset_query_split: 0.2
+   imgs_folder: /data02/xy/Clip-hash/datasets/image/
+   label_path: /data02/xy/Clip-hash/labels.txt
+   num_classes: 10
+   # num_classes: 131
+ TRAIN:
+   # Base Arch
+   # clip_pretrain: /data02/xy/Clip-hash/pretrain/RS5M_ViT-B-32.pt
+   clip_pretrain: ./cisen/pretrain/RS5M_ViT-B-32.pt
+   model_name: ViT-B-32
+   ckpt_path: /data02/xy/GeoRSCLIP/codebase/inference/pretrain/RS5M_ViT-B-32.pt
+   input_size: 224
+   word_len: 328
+   word_dim: 1024
+   vis_dim: 512
+   fpn_in: [ 512, 768, 768 ]
+   fpn_out: [ 768, 768, 768, 512 ]
+   sync_bn: True
+   # Decoder
+   num_layers: 3
+   num_head: 8
+   dim_ffn: 2048
+   dropout: 0.1
+   intermediate: False
+   # Training Setting
+   workers: 32  # data loader workers
+   workers_val: 16
+   epochs: 50
+   milestones: [50]
+   start_epoch: 0
+   batch_size: 256  # batch size for training
+   batch_size_val: 256  # batch size for validation during training; memory and speed tradeoff
+   base_lr: 0.0001
+   min_lr: 0.00000001
+   lr_decay: 0.5
+   lr_multi: 0.1
+   weight_decay: 0.
+   max_norm: 0.
+   manual_seed: 0
+   print_freq: 1
+   lamda1: 0.5
+   lamda2: 0.5
+   beta1: 0.5
+   beta2: 0.5
+   eta: 0.2
+   warmup_epochs: 0
+   contrastive: [0.4, 0.3, 0.3]
+   # Resume & Save
+   output_folder: /data02/xy/Clip-hash/exp/
+   save_freq: 1
+   weight:  # path to initial weight (default: none)
+   resume: False  # path to latest checkpoint (default: none)
+   evaluate: True  # evaluate on validation set; extra GPU memory needed, and a small batch_size_val is recommended
+ Distributed:
+   dist_url: tcp://localhost:3693
+   dist_backend: 'nccl'
+   multiprocessing_distributed: True
+   world_size: 1
+   rank: 0
+ TEST:
+   test_split: val-test
+   gpu: [0]
+   test_lmdb: /data02/xy/Clip-hash/datasets/lmdb/refcoco/val.lmdb
+   visualize: False
+   topk: 5
+   test_batch_size: 256
+   val_batch_size: 1
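A quick way to inspect this config outside the project is plain PyYAML; this bypasses the project's cisen.utils.config loader and only shows the raw values:

    import yaml

    with open('cisen.yaml') as f:
        cfg = yaml.safe_load(f)
    print(cfg['TRAIN']['input_size'], cfg['TRAIN']['word_len'])   # 224 328
    print(cfg['DATA']['ratio'], cfg['TRAIN']['fpn_out'])          # 0.9 [768, 768, 768, 512]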
get_data_by_image_id.py ADDED
@@ -0,0 +1,111 @@
+ import json
+ import numpy as np
+ from PIL import Image
+ import torch
+ from torchvision import transforms
+ 
+ 
+ def read_json(file_name, suppress_console_info=False):
+     with open(file_name, 'r') as f:
+         data = json.load(f)
+     if not suppress_console_info:
+         print("Read from:", file_name)
+     return data
+ 
+ 
+ def get_file_names(data, imgs_folder, feature_folder, suppress_console_info=False):
+     # Build {image_id: ...} lookup tables from the dataset JSON
+     image_file_names = {}
+     feature_paths = {}
+     captions = {}
+     labels = {}
+     lats = {}
+     lons = {}
+ 
+     for img in data['images']:
+         image_name = img["image_name"]
+         sample_id = img["sample_id"]
+         image_id = f'{sample_id}_{image_name}'
+         path_data = imgs_folder + f'{sample_id}/{image_name}'
+         feature_data = feature_folder + f'{sample_id}/{image_name}.npy'
+         # image_file_name.append(path_data)
+         # caption.append(img["description"])
+         # label.append(img["labels"])
+         # lat.append(img["lat"])
+         # lon.append(img["lon"])
+ 
+         image_file_names[image_id] = path_data
+         feature_paths[image_id] = feature_data
+         captions[image_id] = img["description"]
+         labels[image_id] = img["labels"]
+         lats[image_id] = img["lat"]
+         lons[image_id] = img["lon"]
+ 
+     return image_file_names, feature_paths, captions, labels, lats, lons
+ 
+ 
+ def get_data(image_file_names, captions, feature_paths, labels, lats, lons, image_id):
+     # Look up the entries for a single image id
+     image_file_name = image_file_names[image_id]
+     feature_path = feature_paths[image_id]
+     caption = captions[image_id]
+     label = labels[image_id]
+     lat = lats[image_id]
+     lon = lons[image_id]
+ 
+     return image_file_name, feature_path, caption, label, lat, lon
+ 
+ 
+ def read_by_image_id(data_dir, imgs_folder, feature_folder, image_id=None):
+     '''
+     return:
+         img
+         img_ -> transform(img)
+         caption
+         image_feature -> tensor
+         label
+         label_en -> text of labels
+         lat
+         lon
+     '''
+     data_info = read_json(data_dir)
+     image_file_names, image_features_path, captions, labels, lats, lons = get_file_names(data_info, imgs_folder, feature_folder)
+ 
+     image_file_name, image_feature_path, caption, label, lat, lon = get_data(image_file_names, captions, image_features_path, labels, lats, lons, image_id)
+ 
+     label_en = []
+     label131 = data_info['labels']
+ 
+     # Map the numeric label ids back to their English label names
+     for label_name in label131.keys():
+         label_id = label131[label_name]
+         for label_single in label:
+             if label_single == label_id:
+                 label_en.append(label_name)
+     image_feature = np.load(image_feature_path)
+ 
+     img = Image.open(image_file_name).convert('RGB')
+ 
+     transform = transforms.Compose([
+         transforms.Resize(224),
+         transforms.CenterCrop(224),
+         transforms.ToTensor(),
+         transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
+     ])
+ 
+     if transform is not None:
+         img_ = np.array(transform(img))
+     else:
+         img_ = np.array(img)
+     img_ = torch.from_numpy(img_.astype('float32'))
+ 
+     return img, img_, caption, image_feature, label, label_en, lat, lon
+ 
+ 
+ # test
+ data_dir = '/data02/xy/dataEngine/json_data/merged_output_combined_9w_resplit.json'
+ imgs_folder = '/data02/xy/Clip-hash//datasets/image/'
+ feature_folder = '/data02/xy/Clip-hash/image_feature/georsclip_21_r0.9_fpn/'
+ image_id = 'sample44_889.jpg'
+ 
+ # img, img_, caption, image_feature, label, label_en, lat, lon = read_by_image_id(data_dir, imgs_folder, feature_folder, image_id)
+ # print(img, img_, caption, image_feature, label, label_en, lat, lon)
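For orientation, the fields read by get_file_names and read_by_image_id imply a dataset JSON of roughly this shape. This is a hand-written illustration, not taken from LuojiaHOG(best)_.json itself; all values are invented:

    # Illustrative only: field names follow get_file_names / read_by_image_id, values are made up.
    example = {
        "labels": {"residential area": 0, "wetland": 1},   # label name -> numeric id
        "images": [
            {
                "sample_id": "sample44",
                "image_name": "889.jpg",
                "description": "A residential area with ...",
                "labels": [0],
                "lat": 30.5,
                "lon": 114.3,
            },
        ],
    }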
image_features_best.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f5984003a770cd6dace7f2cb01c3e41e0683fc54e3dabdce3d9c6e155995db21
+ size 86042598
text_features_best.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a13ca6ce41ce5aa386a8797c4d1b2416eaed1409e06f0972f799eb08567b709a
+ size 86042598