# Gradio web demo: BM25 search over a Pyserini prebuilt Lucene index.
import html
import json
import os

import gradio as gr
from pyserini.search.lucene import LuceneSearcher

# Searchers already opened by this process, keyed by prebuilt index name.
_SEARCHER_CACHE = {}


def initialize_searcher(index_name):
    """Return a BM25-configured ``LuceneSearcher`` for a prebuilt index.

    The searcher is created with ``LuceneSearcher.from_prebuilt_index`` and
    configured with ``k1=0.9, b=0.4``. Searchers are cached per index name so
    repeated queries against the same index do not rebuild the searcher.

    Args:
        index_name: Name of a Pyserini prebuilt index. Surrounding
            whitespace is ignored.

    Returns:
        The configured searcher for ``index_name``.

    Raises:
        ValueError: If ``index_name`` is empty or no searcher could be
            created for it.
    """
    name = index_name.strip() if isinstance(index_name, str) else ""
    if not name:
        raise ValueError("A prebuilt index name is required.")

    cached = _SEARCHER_CACHE.get(name)
    if cached is not None:
        # NOTE(review): the cached searcher is shared by every request
        # handled in this process -- confirm that suits the deployment.
        return cached

    # The name comes from a user-editable text box, so it must never be
    # interpolated into a shell command. from_prebuilt_index is expected to
    # fetch the index itself when it is not available locally (see the
    # Pyserini docs), so no separate pre-download step is needed.
    searcher = LuceneSearcher.from_prebuilt_index(name)
    if searcher is None:
        # NOTE(review): some Pyserini releases appear to return None for an
        # unrecognized index name instead of raising -- fail clearly here.
        raise ValueError(f"Could not load prebuilt index: {name!r}")

    searcher.set_bm25(k1=0.9, b=0.4)
    _SEARCHER_CACHE[name] = searcher
    return searcher

def _hit_contents(doc):
    """Return the displayable text of a fetched document, or "" if none."""
    if doc is None:
        return ""
    raw = doc.raw()
    try:
        return json.loads(raw)["contents"]
    except (TypeError, ValueError, KeyError):
        # The raw payload is not JSON, or is JSON without a "contents"
        # field: show the stored raw text rather than failing the search.
        return raw if raw is not None else ""


def search_pyserini(query, top_k, index_name):
    """Run a search and return the results as an HTML fragment.

    Args:
        query: Search query text.
        top_k: Number of hits to request; converted to ``int`` before use.
        index_name: Prebuilt index name, passed to ``initialize_searcher``.

    Returns:
        The HTML produced by ``format_results`` on success, or a
        ``<div class='error'>`` fragment describing the problem. This
        function does not raise: it is the UI callback, so failures are
        reported in the page instead.
    """
    if not query or not query.strip():
        return "<div class='error'>Please enter a search query.</div>"

    try:
        searcher = initialize_searcher(index_name)
        # The slider may deliver its value as a float (e.g. 10.0); the hit
        # count must be an integer.
        hits = searcher.search(query, k=int(top_k))
        results = [
            {
                "rank": rank,
                "doc_id": hit.docid,
                "score": hit.score,
                "content": _hit_contents(searcher.doc(hit.docid)),
            }
            for rank, hit in enumerate(hits, start=1)
        ]
        return format_results(results)
    except Exception as e:
        # Escape the message: it can echo user input (such as the index
        # name) and is rendered as HTML.
        return f"<div class='error'>An error occurred: {html.escape(str(e))}</div>"

def format_results(results):
    """Render search results as an HTML fragment.

    Args:
        results: Iterable of dicts with the keys ``"rank"``, ``"doc_id"``,
            ``"score"`` and ``"content"``.

    Returns:
        A ``results-container`` div holding one ``result-item`` div per
        result, in input order. Scores are shown with four decimals.
    """
    items = []
    for result in results:
        # Doc IDs and passage text come from the index, not from this app;
        # escape them so markup inside a document is displayed as text
        # instead of being interpreted by the browser.
        doc_id = html.escape(str(result['doc_id']))
        content = html.escape(str(result['content']))
        items.append(
            "<div class='result-item'>"
            f"<h3>Rank {result['rank']} (Score: {result['score']:.4f})</h3>"
            f"<p class='doc-id'>Doc ID: {doc_id}</p>"
            f"<p class='content'>{content}</p>"
            "</div>"
        )
    return "<div class='results-container'>" + "".join(items) + "</div>"

# Stylesheet passed to gr.Blocks(css=...). The class names below
# (results-container, result-item, doc-id, content, error) match the class
# attributes in the HTML fragments built by this file.
css = """
.gradio-container {
    font-family: 'Arial', sans-serif;
}
.results-container {
    display: flex;
    flex-direction: column;
    gap: 20px;
}
.result-item {
    border: 1px solid #ddd;
    border-radius: 8px;
    padding: 15px;
    width: 100%;
    box-shadow: 0 2px 4px rgba(0,0,0,0.1);
}
.result-item h3 {
    margin-top: 0;
    color: #333;
}
.doc-id {
    font-size: 0.9em;
    color: #666;
    margin-bottom: 10px;
}
.content {
    font-size: 0.95em;
    line-height: 1.4;
}
.error {
    color: red;
    font-weight: bold;
}
"""

# Assemble the page top to bottom; every widget sits in its own row.
with gr.Blocks(css=css) as iface:
    gr.Markdown("# Pyserini Search Interface")
    gr.Markdown("Enter a query to search using Pyserini with BM25 scoring (k1=0.9, b=0.4).")

    # Which prebuilt index to search.
    with gr.Row():
        index_input = gr.Textbox(
            label="Prebuilt Index Name",
            placeholder="Enter the name of the prebuilt index",
            value="msmarco-passage",
            lines=1,
        )

    # How many hits to show (1-100, default 10).
    with gr.Row():
        top_k_slider = gr.Slider(
            label="Number of top results to return",
            minimum=1,
            maximum=100,
            step=1,
            value=10,
        )

    # The query text itself.
    with gr.Row():
        query_input = gr.Textbox(
            label="Search Query",
            placeholder="Enter your search query here...",
            lines=1,
        )

    with gr.Row():
        search_button = gr.Button("Search", variant="primary")

    # Rendered HTML returned by the search callback.
    with gr.Row():
        output = gr.HTML(label="Search Results")

    # Clicking the button calls search_pyserini(query, top_k, index_name)
    # and shows its return value in the HTML output.
    search_button.click(
        fn=search_pyserini,
        inputs=[query_input, top_k_slider, index_input],
        outputs=output,
    )

# Launch the interface.
iface.launch()