Rajiv Shah committed on
Commit
ad9fcac
1 Parent(s): 1f33ec6
Files changed (2) hide show
  1. app.py +57 -0
  2. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ from sentence_transformers import SentenceTransformer, CrossEncoder, util
4
+ from transformers import pipeline
5
+ import torch
6
+ import pickle
7
+ import pandas as pd
8
+ import gradio as gr
9
+
10
+
11
# Speech recognition: a transformers pipeline around the wav2vec2 base
# checkpoint, built once at import time and shared by every request.
asr = pipeline(
    task="automatic-speech-recognition",
    model="facebook/wav2vec2-base-960h",
)
13
def speech_to_text(speech):
    """Transcribe a recorded clip with the module-level ``asr`` pipeline.

    Args:
        speech: Audio input forwarded unchanged to ``asr``. The audio
            component wired to this callback is created with
            ``type="filepath"``. ``None`` or an empty value means nothing
            was recorded.

    Returns:
        The ``"text"`` field of the pipeline result, or ``""`` when there
        is no audio to transcribe.
    """
    # Clicking the button before recording hands us no audio; calling the
    # pipeline with that would raise instead of leaving the textbox empty.
    if not speech:
        return ""
    return asr(speech)["text"]
16
+
17
# Retrieval models: the bi-encoder embeds queries for candidate lookup,
# the cross-encoder re-scores (query, passage) pairs.
bi_encoder = SentenceTransformer("multi-qa-MiniLM-L6-cos-v1")
cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

# Precomputed passage embeddings and the passages they index, loaded in
# that order. NOTE(review): unpickling runs arbitrary code, so these two
# files must only ever come from a trusted source.
corpus_embeddings, corpus = (
    pd.read_pickle(pickle_path)
    for pickle_path in ("corpus_embeddings_cpu.pkl", "corpus.pkl")
)
21
+
22
def search(query, top_k=100):
    """Return the three best-matching corpus passages for ``query``.

    Retrieval is two-stage. The bi-encoder embeds the query and
    ``util.semantic_search`` pulls up to ``top_k`` candidates from
    ``corpus_embeddings``; the cross-encoder then scores every
    (query, passage) pair and the candidates are re-sorted by that score.

    Args:
        query: Question text. ``None`` or a blank string short-circuits to
            three empty strings without calling either model.
        top_k: Number of candidates to retrieve before re-ranking.

    Returns:
        A 3-tuple of passages, best first. When fewer than three
        candidates are available the tuple is padded with ``""`` so the
        three output slots are always filled.
    """
    print("Top 3 Answer by the NSE:")
    print()

    blanks = ("", "", "")
    # Nothing to retrieve for an empty question; skip the models entirely.
    if query is None or not str(query).strip():
        return blanks

    ##### Semantic Search #####
    # Encode the query using the bi-encoder and find potentially relevant passages
    question_embedding = bi_encoder.encode(query, convert_to_tensor=True)
    hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=top_k)
    hits = hits[0]  # Get the hits for the first (and only) query
    if not hits:
        return blanks

    ##### Re-Ranking #####
    # Now, score all retrieved passages with the cross_encoder
    cross_inp = [[query, corpus[hit["corpus_id"]]] for hit in hits]
    cross_scores = cross_encoder.predict(cross_inp)

    # Sort results by the cross-encoder scores
    for hit, score in zip(hits, cross_scores):
        hit["cross-score"] = score
    hits = sorted(hits, key=lambda hit: hit["cross-score"], reverse=True)

    ans = [corpus[hit["corpus_id"]] for hit in hits[:3]]
    # Pad instead of indexing past the end when fewer than three hits exist.
    ans.extend([""] * (3 - len(ans)))
    return ans[0], ans[1], ans[2]
43
+
44
+
45
# Two-step UI: record speech -> transcribe into a textbox -> use that text
# as the query and show the top three passages.
demo = gr.Blocks()
with demo:
    # Use the component class directly; the legacy ``gr.inputs`` namespace
    # is deprecated. NOTE(review): newer Gradio releases spell the keyword
    # ``sources=["microphone"]`` - confirm against the installed version.
    audio_file = gr.Audio(source="microphone", type="filepath")
    b1 = gr.Button("Recognize Speech")
    text = gr.Textbox()
    b1.click(speech_to_text, inputs=audio_file, outputs=text)

    b2 = gr.Button("Ask Wiki")
    # NOTE(review): this runs once while the layout is built and prints the
    # Textbox component object, not the transcribed text.
    print(text)
    out1 = gr.Textbox()
    out2 = gr.Textbox()
    out3 = gr.Textbox()
    b2.click(search, inputs=text, outputs=[out1, out2, out3])

demo.launch(debug=True)
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ transformers
2
+ torch
3
+ spacy
4
+ gradio
5
+ sentence-transformers
6
+ # pickle is part of the Python standard library; it is not a pip-installable package
7
+ pandas