# LuojiaHOG-demo / app.py
import time
import os
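# Work around "duplicate OpenMP runtime" crashes that can occur when faiss and
# torch are loaded in the same process.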
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
import zipfile
from io import BytesIO
from PIL import Image
import numpy as np
import argparse
import faiss
import gradio as gr
import pandas as pd
import pickle
import cisen.utils.config as config
from cisen.utils.dataset import tokenize
from torchvision import transforms
from get_data_by_image_id import read_json
from cisen.model.segmenter import CISEN_rsvit_hug
transform = transforms.Compose([
transforms.Resize(224),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
])
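# (The mean/std above are the standard CLIP image-preprocessing statistics.)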
def get_parser():
parser = argparse.ArgumentParser(
        description='LuojiaHOG image-text retrieval demo')
parser.add_argument('--config',
default='./cisen.yaml',
type=str,
help='config file')
parser.add_argument('--opts',
default=None,
nargs=argparse.REMAINDER,
help='override some settings in the config.')
args = parser.parse_args()
assert args.config is not None
cfg = config.load_cfg_from_cfg_file(args.config)
if args.opts is not None:
cfg = config.merge_cfg_from_list(cfg, args.opts)
return cfg
args = get_parser()  # loads cisen.yaml; note the model hyperparameters below are set explicitly
data_dir = './LuojiaHOG(best)_.json'
imgs_folder = 'image/'
# Pre-computed CISEN embeddings for every gallery image and its description.
with open('image_features_best.pkl', 'rb') as f:
    image_dict = pickle.load(f)
image_feat = np.array(list(image_dict.values()))
with open('text_features_best.pkl', 'rb') as f:
    text_dict = pickle.load(f)
text_feat = np.array(list(text_dict.values()))
sample_info = np.array(list(image_dict))  # image file names, row-aligned with image_feat
data_info = read_json(data_dir)
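# data_info maps image names to metadata; the retrieval code below relies on
# entries of this shape (schema inferred from usage, not from a spec):
#   "<image_name>": {"label_name": ..., "description": ..., "lat": ..., "lon": ...}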
# CISEN (RS-ViT backbone) hyperparameters. Named model_cfg so it does not shadow
# the imported cisen.utils.config module.
model_cfg = {"embed_dim": 512, "image_resolution": 224, "vision_layers": 12, "vision_width": 768,
             "vision_patch_size": 32, "context_length": 328, "txt_length": 328, "vocab_size": 49408,
             "transformer_width": 512, "transformer_heads": 8, "transformer_layers": 12, "patch_size": 32,
             "output_dim": 512, "ratio": 0.9, "emb_dim": 768, "fpn_in": [512, 768, 768], "fpn_out": [768, 768, 768, 512]}
model = CISEN_rsvit_hug(**model_cfg)
model = model.from_pretrained("aleo1/cisen")
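# Inference only: disable dropout/batch-norm updates (assumes CISEN_rsvit_hug is
# a torch nn.Module, as its encode methods suggest).
model.eval()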
# Build exact (flat) FAISS indexes over the pre-computed embeddings, using L2 distance.
image_index = faiss.IndexFlatL2(512)
text_index = faiss.IndexFlatL2(512)
image_index.add(image_feat)
text_index.add(text_feat)
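# The flat indexes scan every stored vector per query, which is fine at this
# corpus size. For much larger corpora an approximate IVF index could be swapped
# in; a minimal sketch (hypothetical, not used by this app):
#   quantizer = faiss.IndexFlatL2(512)
#   ivf = faiss.IndexIVFFlat(quantizer, 512, 100)  # 100 coarse clusters
#   ivf.train(image_feat)
#   ivf.add(image_feat)
#   ivf.nprobe = 8  # clusters probed per query: recall vs. speed trade-off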
# Example text queries shown in the Gradio Examples panel.
text1 = "The image shows an airfield situated within a military base, offering a comprehensive bird's-eye view of the entire area. The airfield is enclosed by a secure perimeter fence. The main runway, made of concrete and marked with distinct white stripes, is clearly visible and is designed for aircraft takeoffs and landings. Connected to the runway are several asphalt taxiways, which are slightly darker in color and provide pathways for aircraft to move between the runway and other parts of the airfield."
text2 = "The image shows a residential area with houses, parking lots, and significant vegetation. The houses are arranged in rows with roofs in various shades of brown and gray. Between the houses, there are lawns and gardens with green grass. The parking lots are paved with asphalt and have white lines indicating parking spaces. There are clusters of trees and shrubs around the residential buildings, forming green areas that separate different parts of the neighborhood. The vegetation is dense in some sections, especially along the edges of the residential zone. The layout is straightforward, with clear divisions between the houses, parking lots, and green spaces, indicating a well-organized residential area."
text3 = "The image shows a set of railway tracks running alongside several human-made structures. The tracks are made of metal and run parallel to each other, with wooden or concrete ties supporting them. Adjacent to the tracks, there are buildings that appear to be factories or office buildings. These structures are large and rectangular, with flat roofs and multiple windows. The ground near the tracks is clear of vegetation, indicating regular maintenance. There are access roads connecting the buildings to the railway tracks, facilitating transportation and logistics. The overall layout is organized, with clear separations between the railway, buildings, and paved areas. This setup indicates an industrial or commercial zone where the railway tracks are integral for transportation and operational efficiency."
text4 = "The image shows a school complex with several buildings and an adjacent playground. The school buildings are rectangular, with flat roofs and multiple windows. They are arranged in a compact formation, connected by pathways. The roofs are mostly light-colored, likely made of materials like concrete or metal. Next to the buildings is a large playground. The playground includes a grassy field, a running track, and several paved areas for various sports and activities. There are also small structures like sheds or pavilions within the playground area."
text5 = "The image shows a cemetery adjacent to a residential area. The cemetery is organized into neat rows of graves with headstones. The headstones are mostly rectangular and vary in size, creating a grid-like pattern. The ground within the cemetery is primarily grass, with some pathways made of gravel or pavement for access. Next to the cemetery is a residential area with houses arranged in rows. The houses have varied roof colors, including shades of brown and gray. Each house has a small yard with green lawns and some trees or shrubs, providing a touch of greenery."
text6 = "The image displays a soccer field, featuring a rectangular grassy surface with clearly marked boundaries and white lines defining the playing area. The field is surrounded by a track, possibly for running or warm-up exercises, and is enclosed by a fence or barrier. Near the field, there are seating areas, likely for spectators, arranged in rows or tiers. Floodlights are positioned around the perimeter, indicating the capability for night games or events. Adjacent to the field, there may be additional facilities such as changing rooms, restrooms, or concession stands. "
text7 = "The image presents a sports complex featuring both a soccer field and a basketball court. The soccer field occupies a rectangular area with neatly trimmed grass and clearly marked boundaries. White lines delineate the playing area, while surrounding tracks suggest space for warm-up or jogging. Adjacent to the field are seating arrangements for spectators, likely indicating the field's use for organized matches or events. Floodlights are positioned around the perimeter, enabling night-time activities. Adjacent to the soccer field is a basketball court, identifiable by its rectangular shape and hoop installations at each end. The court surface may be made of asphalt or concrete and is marked with lines for gameplay. The area around the court is open, with limited vegetation or structures nearby."
text8 = "The image showcases a sports complex featuring both a tennis court and a basketball court. The tennis court is distinguished by its rectangular shape, with a smooth playing surface marked by white lines indicating the boundaries and service areas. Surrounding the court is a fence or barrier, likely to prevent stray balls from interfering with gameplay. Adjacent to the tennis court, seating arrangements for spectators are visible, suggesting its use for organized matches or tournaments. "
text9 = "The aerial image shows a residential area with a centrally located swimming pool. The pool is a large, rectangular structure with clear blue water. Surrounding the pool, residential buildings are arranged in a neat grid pattern. These buildings vary in size and style, featuring different roof colors and garden layouts. Each house has a yard, some with green lawns, while others have patios or small gardens. Trees and shrubs are interspersed throughout the neighborhood, adding greenery and shade."
text10 = "The image shows a residential area with several single-story houses next to a large open stretch of wasteland. The houses are rectangular, evenly spaced, and have pale blue walls with hints of white, suggesting they are plastered. The wasteland is a mix of brown and green, featuring patches of dry grass and scattered shrubs. Surrounding the houses are tall, slender trees with lush green canopies providing shade. The scene is captured from a high altitude, giving a bird's-eye view where the houses, trees, and the central wasteland are clearly visible."
text11 = "The color remote sensing image shows an urban city street from a high altitude. The street is flanked by tall, sleek buildings featuring a mix of modern and traditional architecture, mostly in white and beige, with some more colorful facades. The street is busy with cars in white, black, silver, and gold, and pedestrians of diverse ethnicities wearing both modern and traditional clothing. Tall, lush trees with various shades of green line the street. The sky is bright blue with a few fluffy clouds. The image is high quality, with clear, visible details."
text12 = "The image displays a residential area with houses arranged in a grid-like pattern, each having a small yard. The houses are mostly uniform in size and shape, featuring pitched roofs and rectangular windows. They come in a variety of colors, from bright ones like yellow and pink to more neutral tones like white and gray. Trees of different sizes and shapes are scattered throughout the area. A parking lot next to the houses is mostly filled with cars, though some spots are empty. Sidewalks and streets connect the houses to other parts of the island."
# Example image queries: every .jpg in the example folder.
image_folder = './example_image/'
image_files = [os.path.join(image_folder, filename)
               for filename in os.listdir(image_folder)
               if filename.endswith('.jpg')]
image_list = [[Image.open(image_file)] for image_file in image_files]
# Search functions
def _retrieve(query_vector, index, top_k):
    """Search a FAISS index and assemble the result table, gallery images, and timings.

    IndexFlatL2.search already returns hits sorted by ascending L2 distance,
    so no extra re-ranking is needed here.
    """
    query_vector = query_vector.detach().numpy().astype(np.float32)  # faiss expects float32
    start_time = time.time()
    _scores, ids = index.search(query_vector, top_k)
    search_time = time.time() - start_time

    start_time = time.time()
    # Map FAISS row ids back to image file names and look up their metadata.
    info = list(sample_info[ids[0]])
    top_k_scores = [np.round(value, 2) for value in _scores[0]]
    top_k_labels, top_k_texts, lat, lon = zip(
        *[(data_info[str(idx)]["label_name"], data_info[str(idx)]["description"],
           data_info[str(idx)]["lat"], data_info[str(idx)]["lon"]) for idx in info]
    )  # top_k_labels is currently unused by the UI

    # Images live in per-sample zip archives: the file name 'sampleX_YYY.jpg'
    # maps to entry 'sampleX/YYY.jpg' inside 'sampleX.zip'.
    image_output = []
    for img in info:
        sample_name = img.split('_')[0]
        image_path = imgs_folder + sample_name + '.zip'
        with zipfile.ZipFile(image_path, 'r') as zip_ref:
            with zip_ref.open(img.replace('_', '/')) as image_file:
                image_output.append(Image.open(BytesIO(image_file.read())))

    df = pd.DataFrame(
        {"Distance": top_k_scores, "Latitude": lat, "Longitude": lon, "Description": top_k_texts}
    )
    # DataFrame.round returns a new frame, so assign the result back.
    df = df.round({"Distance": 2, "Latitude": 4, "Longitude": 4})
    sort_time = time.time() - start_time
    return df, image_output, search_time, sort_time


def _timings(embed_time, search_time, sort_time):
    """Format the per-stage timings shown in the JSON panel."""
    return {
        "Embed Time": f"{embed_time:.4f} s",
        "Search Time": f"{search_time:.4f} s",
        "Sort Time": f"{sort_time:.4f} s",
        "Total Retrieval Time": f"{search_time + sort_time:.4f} s",
    }


def search(text_query, image_query, top_k: int = 10):
    """Embed whichever query is provided and search the matching index."""
    start_time = time.time()
    if image_query is None:
        text = tokenize(text_query, 328)
        query_vector = model.text_encode(text)
        index = text_index
    else:
        image_query = transform(Image.fromarray(image_query))
        query_vector = model.image_encode(image_query.unsqueeze(0))
        index = image_index
    embed_time = time.time() - start_time
    df, image_output, search_time, sort_time = _retrieve(query_vector, index, top_k)
    return df, image_output, _timings(embed_time, search_time, sort_time)


def img_search(image_query, top_k: int = 10):
    """Image query -> image index; used by the image examples."""
    start_time = time.time()
    image_query = transform(Image.fromarray(image_query))
    query_vector = model.image_encode(image_query.unsqueeze(0))
    embed_time = time.time() - start_time
    df, image_output, search_time, sort_time = _retrieve(query_vector, image_index, top_k)
    return df, image_output, _timings(embed_time, search_time, sort_time)


def txt_search(text_query, top_k: int = 10):
    """Text query -> text index; used by the text examples."""
    start_time = time.time()
    text = tokenize(text_query, 328)
    query_vector = model.text_encode(text)
    embed_time = time.time() - start_time
    df, image_output, search_time, sort_time = _retrieve(query_vector, text_index, top_k)
    return df, image_output, _timings(embed_time, search_time, sort_time)
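# For reference, the search functions can also be exercised outside the UI, e.g.
# (hypothetical query; requires the data files above to be present):
#   df, images, times = txt_search("an airfield with a concrete runway", top_k=5)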
def update_visible(choice):
    """Toggle which query widget is shown based on the Search Index radio value.

    choice is True for text-to-image, False for image-to-text, and None for the
    'Examples' mode, where both widgets stay visible.
    """
    show_text = choice is not False
    show_image = choice is not True
    return gr.Textbox(
        label="Text query for remote sensing images",
        placeholder="Enter a query to search for relevant images.",
        visible=show_text,
        interactive=True
    ), gr.Image(
        label="Upload an image",
        visible=show_image,
        interactive=True
    )
with gr.Blocks(title="Image-Text Retrieval") as demo:
    # Mode selector: examples mode, image-to-text search, or text-to-image search.
search_index = gr.Radio(
choices=[("Examples", None), ("Image-to-Text", False), ("Text-to-Image", True)],
value=None,
label="Search Index",
)
    # Text query input.
text_query = gr.Textbox(
label="Text query for remote sensing images",
placeholder="Enter a query to search for relevant images.",
visible=True,
interactive=True
)
    # Image query input.
image_query = gr.Image(
label="Upload an image",
visible=True,
interactive=True
)
search_index.change(update_visible, search_index, [text_query, image_query])
    # Retrieval settings: how many results to return, plus a timing readout.
with gr.Row():
with gr.Column(scale=2):
top_k = gr.Slider(
minimum=10,
maximum=100,
step=5,
value=10,
interactive=True,
label="Number of images/texts to retrieve",
                info="How many nearest neighbours to return",
)
with gr.Column(scale=2):
            timing_json = gr.JSON(label='Retrieval time')  # avoids shadowing the stdlib json name
    # Search and clear buttons.
with gr.Row():
search_button = gr.Button(value="Search", variant='primary')
clear_button = gr.ClearButton(value='Clear Before Next Search')
    # Results table: distance, coordinates, and description for each hit.
with gr.Column():
output = gr.Dataframe(headers=["Distance", "Latitude", "Longitude", "Description"], label="Text outputs")
    # Retrieved images.
with gr.Row():
image_output = gr.Gallery(label="Image outputs")
    exp_txt = gr.Examples(
        examples=[[text1], [text2], [text3], [text4], [text5], [text6], [text7], [text8],
                  [text9], [text10], [text11], [text12]],
        inputs=[text_query, top_k],
        outputs=[output, image_output, timing_json],
        fn=txt_search, run_on_click=True, examples_per_page=4,
        label="Text examples", cache_examples='lazy')
    exp_img = gr.Examples(
        examples=image_list,
        inputs=[image_query, top_k],
        outputs=[output, image_output, timing_json],
        fn=img_search, run_on_click=True, examples_per_page=4,
        label="Image examples", cache_examples='lazy')
    search_button.click(search, inputs=[text_query, image_query, top_k],
                        outputs=[output, image_output, timing_json])
    clear_button.add(components=[text_query, image_query, output, image_output, timing_json])
demo.queue()
demo.launch()