import asyncio
import os
from io import BytesIO

# Enable hf_transfer for faster model downloads; huggingface_hub reads this
# environment variable at import time, so it must be set before the
# Hugging Face libraries below are imported.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

import aiohttp
import dotenv
import gradio as gr
import torch
from colpali_engine.models import ColQwen2, ColQwen2Processor
from PIL import Image
from qdrant_client import QdrantClient

dotenv.load_dotenv()

if torch.cuda.is_available():
    device = "cuda:0"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"


os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
# Initialize ColPali model and processor
model_name = "vidore/colqwen2-v0.1"
colpali_model = ColQwen2.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map=device,
)
colpali_processor = ColQwen2Processor.from_pretrained(
    model_name,
)

# Initialize Qdrant client
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
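# port=None keeps the client on the URL's implicit HTTPS port (443) rather than
# appending Qdrant's default 6333, which is how the hosted Space is exposed.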
qdrant_client = QdrantClient(
    url="https://davanstrien-qdrant-test.hf.space",
    port=None,
    api_key=QDRANT_API_KEY,
    timeout=10,
)

collection_name = "song_sheets"  # Replace with your actual collection name


def search_images_by_text(query_text, top_k=5):
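    """Embed the text query with ColQwen2 and run a multivector (late-interaction) search in Qdrant."""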
    # Process and encode the text query
    with torch.no_grad():
        batch_query = colpali_processor.process_queries([query_text]).to(
            colpali_model.device
        )
        query_embedding = colpali_model(**batch_query)

    # Convert the query embedding to a list of per-token vectors (a multivector)
    multivector_query = query_embedding[0].cpu().float().numpy().tolist()

    # Search in Qdrant
    search_result = qdrant_client.query_points(
        collection_name=collection_name,
        query=multivector_query,
        limit=top_k,
        timeout=800,
    )

    return search_result


def modify_iiif_url(url, size_percent):
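    """Swap the size segment of a IIIF Image API URL for percentage scaling.

    IIIF image URLs end in {region}/{size}/{rotation}/{quality}.{format},
    so the size segment sits three slashes from the end.
    """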
    # Modify the IIIF URL to use percentage scaling
    parts = url.split("/")
    size_index = -3
    parts[size_index] = f"pct:{size_percent}"
    return "/".join(parts)


async def fetch_image(session, url):
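    """Download a single image and decode it into an RGB PIL image."""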
    async with session.get(url) as response:
        content = await response.read()
        return Image.open(BytesIO(content)).convert("RGB")


async def fetch_all_images(urls):
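    """Fetch all result images concurrently over a single aiohttp session."""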
    async with aiohttp.ClientSession() as session:
        tasks = [fetch_image(session, url) for url in urls]
        return await asyncio.gather(*tasks)


async def search_and_display(query, top_k, size_percent):
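    """Search Qdrant, rescale the IIIF URLs, download the images, and render an HTML gallery."""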
    results = search_images_by_text(query, top_k)
    modified_urls = [
        modify_iiif_url(result.payload["image_url"], size_percent)
        for result in results.points
    ]

    images = await fetch_all_images(modified_urls)
    html_output = (
        "<div style='display: flex; flex-wrap: wrap; justify-content: space-around;'>"
    )
    for i, (image, result) in enumerate(zip(images, results.points)):
        image_url = modified_urls[i]
        item_url = result.payload["item_url"]
        score = result.score
        html_output += f"""
        <div style='margin: 10px; text-align: center; width: 300px;'>
            <img src='{image_url}' style='max-width: 100%; height: auto;'>
            <p>Score: {score:.2f}</p>
            <a href='{item_url}' target='_blank'>View Item</a>
        </div>
        """
    html_output += "</div>"
    return html_output


# Wrapper function for synchronous Gradio interface
def search_and_display_wrapper(query, top_k, size_percent):
    return asyncio.run(search_and_display(query, top_k, size_percent))


with gr.Blocks() as demo:
    gr.HTML(
        """
        <h1 style='text-align: center; color: #2a4b7c;'>America Singing: Nineteenth-Century Song Sheets ColPali Search</h1>
        <div style="display: flex; align-items: stretch; margin-bottom: 20px;">
            <div style="flex: 2; padding-right: 20px;">
                <p>This app allows you to search through the Library of Congress's <a href="https://www.loc.gov/collections/nineteenth-century-song-sheets/about-this-collection/" target="_blank">"America Singing: Nineteenth-Century Song Sheets"</a> collection using natural language queries. The collection contains 4,291 song sheets from the 19th century, offering a unique window into American history, culture, and music.</p>

                <p>This search functionality is powered by <a href="https://huggingface.co/blog/manu/colpali" target="_blank">ColPali</a>, an efficient document retrieval system that uses Vision Language Models. ColPali allows for searching through documents (including images and complex layouts) without the need for traditional text extraction or OCR. It works by directly embedding page images and using a <a href="https://jina.ai/news/what-is-colbert-and-late-interaction-and-why-they-matter-in-search/" target="_blank">late interaction mechanism</a> to match queries with relevant document patches.</p>

                <p>ColPali's approach:</p>
                <ul>
                    <li>Uses a Vision Language Model to encode document page images directly</li>
                    <li>Splits images into patches and creates contextualized patch embeddings</li>
                    <li>Employs a late interaction mechanism to efficiently match query tokens to document patches</li>
                    <li>Eliminates the need for complex OCR and document parsing pipelines</li>
                    <li>Captures both textual and visual information from documents</li>
                </ul>
            </div>
            <div style="flex: 1; display: flex; flex-direction: column;">
                <div style="flex-grow: 1; display: flex; flex-direction: column; justify-content: center;">
                    <img src="https://tile.loc.gov/image-services/iiif/service:rbc:amss:hc:00:00:3b:hc00003b:001a/full/pct:50/0/default.jpg" alt="Example Song Sheet" style="width: 100%; height: auto; max-height: 100%; object-fit: contain; border-radius: 8px; box-shadow: 0 4px 8px rgba(0,0,0,0.1);">
                </div>
                <p style="text-align: center; margin-top: 10px;"><em>Example of a song sheet from the collection</em></p>
            </div>
        </div>
        """
    )
    with gr.Row():
        with gr.Column(scale=4):
            search_box = gr.Textbox(
                label="Search Query", placeholder="i.e. Irish migrant experience"
            )
            submit_button = gr.Button("Search", variant="primary")
    num_results = gr.Slider(
        minimum=1, maximum=20, step=1, label="Number of Results", value=5
    )
    results_html = gr.HTML(label="Search Results")

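    # size_percent is fixed at 50, so the gallery requests half-resolution IIIF images.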
    submit_button.click(
        fn=lambda query, top_k: search_and_display_wrapper(query, top_k, 50),
        inputs=[search_box, num_results],
        outputs=results_html,
    )

demo.launch()