File size: 3,408 Bytes
be785b0
 
181ee96
 
 
 
be785b0
 
52b042a
 
 
dd10fbb
52b042a
dd10fbb
 
 
 
 
 
52b042a
dd10fbb
52b042a
dd10fbb
 
52b042a
 
 
dd10fbb
da8206b
 
 
 
dd10fbb
da8206b
dd10fbb
52b042a
dd10fbb
da8206b
 
 
 
 
 
 
 
 
 
 
 
 
52b042a
 
 
 
 
da8206b
 
dd10fbb
da8206b
 
 
 
 
 
dd10fbb
da8206b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52b042a
 
 
 
 
 
 
dd10fbb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52b042a
 
 
 
 
 
 
 
 
 
 
da8206b
52b042a
 
 
dd10fbb
 
52b042a
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import os
# Set JAVA_HOME before pyserini (whose Java bridge needs it) is imported below.
# Shell equivalent: JAVA_HOME=$(readlink -f /usr/bin/javac | sed "s:bin/javac::")
import shutil


def find_java_home(default="/usr/lib/jvm"):
    """Return the Java installation root that owns ``javac`` (or ``java``) on PATH.

    Mirrors the shell recipe above: resolve every symlink of the executable
    found on PATH, then strip the trailing ``bin/<tool>``. ``/usr/lib/jvm``
    is typically the directory that *contains* installed JDKs rather than a
    JDK itself, so it is only used as a last-resort fallback.

    Args:
        default: Value returned when neither ``javac`` nor ``java`` is on PATH.

    Returns:
        The resolved installation directory, or *default*.
    """
    for tool in ("javac", "java"):
        executable = shutil.which(tool)
        if executable is not None:
            # <java_home>/bin/<tool>  ->  <java_home>
            return os.path.dirname(os.path.dirname(os.path.realpath(executable)))
    return default


# Diagnostic listings of the usual JVM install locations; a missing
# directory is skipped instead of crashing the app at import time.
if os.path.isdir("/usr/lib"):
    print(os.listdir("/usr/lib"))
if os.path.isdir("/usr/lib/jvm"):
    print(os.listdir("/usr/lib/jvm"))
os.environ["JAVA_HOME"] = find_java_home()
print(os.environ["JAVA_HOME"])

import functools
import html
import json
import logging
import os

import gradio as gr
from pyserini.search.lucene import LuceneSearcher

@functools.lru_cache(maxsize=4)
def initialize_searcher(index_name):
    """Open a pyserini prebuilt Lucene index configured for BM25 (k1=0.9, b=0.4).

    ``LuceneSearcher.from_prebuilt_index`` downloads the index itself when it
    is not cached locally, so no separate download step is needed. (The
    previous version shelled out with ``index_name`` pasted into a command
    string, which let UI input inject shell/Python code.)

    Opened searchers are memoised per index name, up to 4 at a time. Repeated
    queries therefore reuse one searcher instead of reopening the index on
    every request. Failures raise and so are never cached.

    Args:
        index_name: Name of a pyserini prebuilt index, e.g. ``"msmarco-passage"``.

    Returns:
        A ``LuceneSearcher`` with BM25 parameters set.

    Raises:
        ValueError: If pyserini does not recognise *index_name*.
    """
    searcher = LuceneSearcher.from_prebuilt_index(index_name)
    if searcher is None:
        # pyserini signals an unknown prebuilt index by returning None rather
        # than raising; fail clearly instead of with an AttributeError below.
        raise ValueError(f"Unknown prebuilt index: {index_name!r}")
    searcher.set_bm25(k1=0.9, b=0.4)
    return searcher

def search_pyserini(query, top_k, index_name):
    """Run a BM25 query against a prebuilt index and render the hits as HTML.

    This is the Gradio callback, i.e. the UI boundary. Any failure is logged
    and turned into an HTML-escaped error ``<div>`` instead of propagating.

    Args:
        query: Free-text query typed by the user.
        top_k: Number of hits to return; coerced to ``int`` because the UI
            slider may deliver a float.
        index_name: Name of the pyserini prebuilt index to search.

    Returns:
        An HTML fragment: the formatted result list, or an error ``<div>``.
    """
    if not query or not query.strip():
        return "<div class='error'>Please enter a search query.</div>"
    try:
        searcher = initialize_searcher(index_name.strip())
        hits = searcher.search(query, k=int(top_k))
        results = [
            {
                "rank": rank,
                "doc_id": hit.docid,
                "score": hit.score,
                "content": _document_text(searcher.doc(hit.docid)),
            }
            for rank, hit in enumerate(hits, start=1)
        ]
        return format_results(results)
    except Exception as e:
        # Top-level UI boundary: keep the traceback in the logs and show an
        # escaped message, since str(e) may echo user input back into HTML.
        logging.getLogger(__name__).exception("Search failed")
        return f"<div class='error'>An error occurred: {html.escape(str(e))}</div>"


def _document_text(doc):
    """Return the displayable text of a pyserini document (``""`` if missing).

    Uses the ``contents`` field when the raw record is a JSON object that has
    one. Otherwise it falls back to the raw text, and to ``doc.contents()``
    when no raw text is stored.
    """
    if doc is None:
        return ""
    raw = doc.raw()
    if raw is None:
        return doc.contents() or ""
    try:
        record = json.loads(raw)
    except ValueError:  # includes json.JSONDecodeError: raw is plain text
        return raw
    if isinstance(record, dict) and "contents" in record:
        return record["contents"]
    return raw

def format_results(results):
    """Render search hits as an HTML fragment for the results panel.

    Args:
        results: Sequence of dicts with keys ``rank``, ``doc_id``, ``score``
            (a number, shown to 4 decimal places) and ``content``.

    Returns:
        A single ``results-container`` ``<div>`` holding one ``result-item``
        card per hit, in the given order. Doc ids and contents are
        HTML-escaped, so corpus text containing ``<`` or ``&`` is shown
        literally instead of being parsed as markup.
    """
    items = [
        f"""
        <div class='result-item'>
            <h3>Rank {result['rank']} (Score: {result['score']:.4f})</h3>
            <p class='doc-id'>Doc ID: {html.escape(str(result['doc_id']))}</p>
            <p class='content'>{html.escape(str(result['content']))}</p>
        </div>
        """
        for result in results
    ]
    # join() instead of repeated += keeps construction linear in output size.
    return "<div class='results-container'>" + "".join(items) + "</div>"

# Stylesheet handed to gr.Blocks(css=...) below. The class names match the
# markup built by format_results() (results-container, result-item, doc-id,
# content) and the error <div>s returned by search_pyserini() (error).
css = """
.gradio-container {
    font-family: 'Arial', sans-serif;
}
.results-container {
    display: flex;
    flex-direction: column;
    gap: 20px;
}
.result-item {
    border: 1px solid #ddd;
    border-radius: 8px;
    padding: 15px;
    width: 100%;
    box-shadow: 0 2px 4px rgba(0,0,0,0.1);
}
.result-item h3 {
    margin-top: 0;
    color: #333;
}
.doc-id {
    font-size: 0.9em;
    color: #666;
    margin-bottom: 10px;
}
.content {
    font-size: 0.95em;
    line-height: 1.4;
}
.error {
    color: red;
    font-weight: bold;
}
"""

# UI layout, one component per row: index name, result count, query box,
# search button, and the HTML results panel.
with gr.Blocks(css=css) as iface:
    gr.Markdown("# Pyserini Search Interface")
    gr.Markdown("Enter a query to search using Pyserini with BM25 scoring (k1=0.9, b=0.4).")

    with gr.Row():
        index_input = gr.Textbox(
            value="msmarco-passage",
            lines=1,
            label="Prebuilt Index Name",
            placeholder="Enter the name of the prebuilt index"
        )

    with gr.Row():
        top_k_slider = gr.Slider(
            minimum=1,
            maximum=100,
            value=10,
            step=1,
            label="Number of top results to return"
        )

    with gr.Row():
        query_input = gr.Textbox(
            lines=1,
            placeholder="Enter your search query here...",
            label="Search Query"
        )

    with gr.Row():
        search_button = gr.Button("Search", variant="primary")

    with gr.Row():
        output = gr.HTML(label="Search Results")

    # Argument order must match search_pyserini(query, top_k, index_name).
    search_inputs = [query_input, top_k_slider, index_input]

    # A search runs on a button click or on pressing Enter in the query box.
    search_button.click(
        fn=search_pyserini,
        inputs=search_inputs,
        outputs=output
    )
    query_input.submit(
        fn=search_pyserini,
        inputs=search_inputs,
        outputs=output
    )

# Start the server only when run as a script (e.g. `python app.py`), so the
# module can be imported (tests, tooling) without launching the UI.
if __name__ == "__main__":
    iface.launch()