File size: 3,835 Bytes
1809285
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cce4c65
1809285
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4a66d2e
1809285
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import gradio as gr
import torch
from visual_bge.modeling import Visualized_BGE
from qdrant_client import QdrantClient
from qdrant_client.http.models import Filter, FieldCondition, MatchValue
from PIL import Image
from io import BytesIO
import requests
import matplotlib.pyplot as plt
from matplotlib import font_manager
import textwrap
import os
import tempfile
from huggingface_hub import hf_hub_download

# Fetch the Visualized-BGE m3 checkpoint file from the Hugging Face Hub;
# hf_hub_download returns a local path to it.
model_weight = hf_hub_download(
    repo_id="BAAI/bge-visualized",
    filename="Visualized_m3.pth",
)

# Font used for result titles so Thai product names render; the path is
# relative to the current working directory.
thai_font = font_manager.FontProperties(fname="./Sarabun-Regular.ttf")

# Embedding model: `model_name_bge` selects the BGE base model and
# `model_weight` is the checkpoint path obtained above.
model = Visualized_BGE(model_name_bge="BAAI/bge-m3", model_weight=model_weight)

# Qdrant connection. URL and API key are read from the environment; either is
# passed as None when the variable is unset.
qdrant_client = QdrantClient(
    url=os.getenv("QDRANT_URL"),
    api_key=os.getenv("QDRANT_API_KEY"),
)

# Visual helper function
def visualize_results(results, cols=4):
    """Render search hits as a thumbnail grid and save it to a temporary PNG.

    For each hit, ``payload['image_url']`` is opened from the local filesystem
    when such a path exists, otherwise it is downloaded over HTTP. Each
    thumbnail is titled with ``payload['name']`` (truncated to 30 characters and
    wrapped at 15) and the hit's score. A hit that fails to load or draw is
    shown as an error message in its cell instead of aborting the whole grid.

    Args:
        results: Sequence of scored points; each item must expose ``.payload``
            (a mapping with ``'image_url'`` and ``'name'``) and ``.score``.
            May be empty, in which case a blank grid is rendered.
        cols: Number of grid columns (default 4).

    Returns:
        Path of the written PNG file. It is created with ``delete=False`` and
        is left on disk for the caller to serve.

    Raises:
        ValueError: If ``cols`` is less than 1.
    """
    if cols < 1:
        raise ValueError("cols must be a positive integer")
    # plt.subplots rejects 0 rows, so keep at least one (blank) row when
    # there are no hits.
    rows = max(1, (len(results) + cols - 1) // cols)
    # squeeze=False always yields a 2-D array of Axes, so ravel() is uniform.
    fig, axs = plt.subplots(rows, cols, figsize=(cols * 3, rows * 3), squeeze=False)
    axs = axs.ravel()

    try:
        for ax, res in zip(axs, results):
            try:
                image_url = res.payload['image_url']
                if os.path.exists(image_url):
                    img = Image.open(image_url)
                else:
                    # A timeout keeps one unreachable host from hanging the
                    # request; raise_for_status reports HTTP errors clearly
                    # instead of as an image-decoding failure.
                    response = requests.get(image_url, timeout=10)
                    response.raise_for_status()
                    img = Image.open(BytesIO(response.content))
                name = res.payload['name']
                if len(name) > 30:
                    name = name[:27] + "..."
                wrapped_name = textwrap.fill(name, width=15)
                ax.imshow(img)
                ax.set_title(f"{wrapped_name}\nScore: {res.score:.2f}", fontproperties=thai_font, fontsize=10)
            except Exception as e:
                # Best effort: show the failure in this cell and keep going.
                ax.text(0.5, 0.5, f'Error: {str(e)}', ha='center', va='center', fontsize=8)
            ax.axis('off')

        # Hide the unused cells that pad out the last row.
        for ax in axs[len(results):]:
            ax.axis('off')

        # Operate on this figure explicitly rather than pyplot's "current" one.
        fig.tight_layout(pad=3.0)
        fig.subplots_adjust(hspace=0.5)
        # Write through the already-open handle; reopening by name fails on
        # Windows while the temporary file is still open.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmpfile:
            fig.savefig(tmpfile, format="png")
        return tmpfile.name
    finally:
        # pyplot keeps every figure alive until closed; without this each
        # search would leak a figure in the long-running server process.
        plt.close(fig)

# Text Query Handler
def search_by_text(text_input):
    """Embed a text query and search the Qdrant collection for matching items.

    Args:
        text_input: Query string from the Gradio textbox. ``None``, empty, or
            whitespace-only input is treated as "no query".

    Returns:
        ``(message, image_path)``: a status string for the "Query Info" box,
        and the path of the rendered results grid (``None`` when there is no
        query).
    """
    if not text_input or not text_input.strip():
        return "Please provide a text input.", None
    # Inference only: disable autograd so no gradient graph is built.
    with torch.no_grad():
        query_vector = model.encode(text=text_input)[0].tolist()
    # NOTE(review): search_by_image queries "bge_visualized_m3", while this
    # handler uses "bge_visualized_m3_demo" -- confirm which collection is
    # intended for both query types.
    results = qdrant_client.query_points(
        collection_name="bge_visualized_m3_demo",
        query=query_vector,
        with_payload=True,
    ).points
    image_path = visualize_results(results)
    return f"Results for: {text_input}", image_path

# Image Query Handler
def search_by_image(image_input):
    """Embed an uploaded image and search the Qdrant collection for similar items.

    Args:
        image_input: The uploaded image as delivered by the Gradio widget
            (the UI configures it with ``type="pil"``), or ``None`` when
            nothing was uploaded.

    Returns:
        ``(message, image_path)``: a status string for the "Query Info" box,
        and the path of the rendered results grid (``None`` when no image was
        given).
    """
    if image_input is None:
        return "Please upload an image.", None
    # Inference only: disable autograd so no gradient graph is built.
    # NOTE(review): confirm this Visualized_BGE build accepts a PIL.Image for
    # `image=`; some versions of encode() pass it to Image.open and expect a
    # file path or file object.
    with torch.no_grad():
        query_vector = model.encode(image=image_input)[0].tolist()
    # NOTE(review): search_by_text queries "bge_visualized_m3_demo", while this
    # handler uses "bge_visualized_m3" -- confirm which collection is intended.
    results = qdrant_client.query_points(
        collection_name="bge_visualized_m3",
        query=query_vector,
        with_payload=True,
    ).points
    image_path = visualize_results(results)
    return "Results for image query", image_path

# Gradio UI: one tab for text queries and one for image queries. Each tab shows
# a status message plus the rendered results grid (passed as a file path).
# The tab/heading emoji were previously mis-encoded (UTF-8 read as cp1253);
# they are restored to the intended characters here.
with gr.Blocks() as demo:
    gr.Markdown("# 🔍 Visualized BGE: Multimodal Search with Qdrant")
    with gr.Tab("📝 Text Query"):
        text_input = gr.Textbox(label="Enter text to search")
        text_output = gr.Textbox(label="Query Info")
        text_image = gr.Image(label="Results", type="filepath")
        text_btn = gr.Button("Search")
        text_btn.click(fn=search_by_text, inputs=text_input, outputs=[text_output, text_image])

    with gr.Tab("🖼️ Image Query"):
        image_input = gr.Image(label="Upload image to search", type="pil")
        image_output = gr.Textbox(label="Query Info")
        image_result = gr.Image(label="Results", type="filepath")
        image_btn = gr.Button("Search")
        image_btn.click(fn=search_by_image, inputs=image_input, outputs=[image_output, image_result])

demo.launch()