Nuno Machado committed on
Commit
9e0bc77
1 Parent(s): 3f9c44b

Add gradio application

Browse files
Files changed (1) hide show
  1. app.py +60 -0
app.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ from utils.dataset_loader import DatasetLoader
4
+ from embeddings.huggingface import HuggingFaceEncoder
5
+ from search.faiss import FaissSearchEngine
6
+
7
# Preload, once at import time, the dataset with embeddings for
# chunk_size = 25 and the encoder used for queries; retrieve_chunks reads
# both module-level objects on every call.
ds_test_embeddings = DatasetLoader.load_from_file_with_embeddings(
    "./data/df_chunked_25_with_embeddings.csv"
)
hf_encoder = HuggingFaceEncoder(
    "sentence-transformers/multi-qa-mpnet-base-dot-v1"
)
10
+
11
+
12
def retrieve_chunks(query, chunk_size, embeddings_generator, retriever_method, num_chunks_to_retrieve):
    """Search the preloaded dataset for the chunks that best match ``query``.

    Args:
        query: Text of the user's query.
        chunk_size: Accepted for interface compatibility; currently ignored.
        embeddings_generator: Accepted for interface compatibility;
            currently ignored.
        retriever_method: Accepted for interface compatibility; currently
            ignored.
        num_chunks_to_retrieve: Number of results requested from the search.

    Returns:
        Whatever ``FaissSearchEngine.search`` returns for the query
        (presumably a table of matching chunks -- confirm against the
        search engine implementation).
    """
    # Only a single configuration (chunk size 25, one encoder, FAISS) is
    # supported for now, so the three selector arguments are not consulted.
    # NOTE(review): the engine is constructed on every call; if construction
    # is expensive, consider building it once at module level -- confirm.
    engine = FaissSearchEngine(ds_test_embeddings, hf_encoder)
    results = engine.search(query, num_chunks_to_retrieve)
    return results
18
+
19
+
20
# Create the Gradio application
with gr.Blocks() as demo:
    # Components are created through the top-level classes (gr.Textbox,
    # gr.Slider) with `value=` for the initial value. The legacy
    # `gr.inputs.*` namespace and its `default=` keyword are deprecated and
    # no longer exist in current Gradio releases; this also matches the
    # gr.Radio components below.
    query = gr.Textbox(
        label='Query',
        placeholder="Enter your query here. Example: 'What is a transformer?'",
    )
    # Chunk size, embeddings generator and retriever method each expose a
    # single choice: retrieve_chunks ignores them because only one
    # configuration is currently supported.
    chunk_size = gr.Slider(
        minimum=25,
        maximum=25,
        step=25,
        value=25,
        label='Chunk Size',
    )
    embeddings_generator = gr.Radio(
        ['sentence-transformers/multi-qa-mpnet-base-dot-v1'],
        label='Embeddings Generator',
        value='sentence-transformers/multi-qa-mpnet-base-dot-v1',
    )
    retriever_method = gr.Radio(
        ['FAISS'],
        value="FAISS",
        label="Retriever Method",
    )
    num_chunks_to_retrieve = gr.Slider(
        minimum=3,
        maximum=5,
        step=1,
        value=3,
        label='Number of Chunks to Retrieve',
    )
    # Order must match the positional parameters of retrieve_chunks.
    inputs = [query, chunk_size, embeddings_generator, retriever_method, num_chunks_to_retrieve]

    submit_btn = gr.Button("Submit")

    # Table that displays the value returned by retrieve_chunks.
    outputs = gr.Dataframe(
        headers=['id', 'guest', 'title', 'text', 'start', 'end', 'scores'],
        type="pandas",
        wrap=True,
    )

    submit_btn.click(retrieve_chunks, inputs=inputs, outputs=outputs)

# Run the Gradio application
demo.launch()