File size: 1,996 Bytes
9e0bc77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import gradio as gr

from utils.dataset_loader import DatasetLoader
from embeddings.huggingface import HuggingFaceEncoder
from search.faiss import FaissSearchEngine

# Preload dataset with embeddings for chunk_size = 25
ds_test_embeddings = DatasetLoader.load_from_file_with_embeddings("./data/df_chunked_25_with_embeddings.csv")
hf_encoder = HuggingFaceEncoder("sentence-transformers/multi-qa-mpnet-base-dot-v1")


def retrieve_chunks(query, chunk_size, embeddings_generator, retriever_method, num_chunks_to_retrieve):
    # Ignore chunk_size, embeddings_generator, and retriever_method,
    # as we currently support only a single configuration
    faiss_search = FaissSearchEngine(ds_test_embeddings, hf_encoder)

    return faiss_search.search(query, num_chunks_to_retrieve)


# Create the Gradio application
with gr.Blocks() as demo:
    query = gr.inputs.Textbox(label='Query', placeholder="Enter your query here. Example: 'What is a transformer?'")
    chunk_size = gr.inputs.Slider(
        minimum=25,
        maximum=25,
        step=25,
        default=25,
        label='Chunk Size'
    )
    embeddings_generator = gr.Radio(
        ['sentence-transformers/multi-qa-mpnet-base-dot-v1'],
        label='Embeddings Generator',
        value='sentence-transformers/multi-qa-mpnet-base-dot-v1'
    )
    retriever_method = gr.Radio(
        ['FAISS'],
        value="FAISS",
        label="Retriever Method"
    )
    num_chunks_to_retrieve = gr.inputs.Slider(
        minimum=3,
        maximum=5,
        step=1,
        default=3,
        label='Number of Chunks to Retrieve'
    )
    inputs = [query, chunk_size, embeddings_generator, retriever_method, num_chunks_to_retrieve]

    submit_btn = gr.Button("Submit")

    outputs = gr.Dataframe(
        headers=['id', 'guest', 'title', 'text', 'start', 'end', 'scores'],
        type="pandas",
        wrap=True
    )

    submit_btn.click(retrieve_chunks, inputs=inputs, outputs=outputs)

# Run the Gradio application
demo.launch()