ankur310794 committed on
Commit
f4128ca
1 Parent(s): dbe46bc
Files changed (1)
  1. app.py +128 -0
app.py ADDED
@@ -0,0 +1,128 @@
+ import spacy
+ import wikipedia
+ from wikipedia.exceptions import DisambiguationError
+ from transformers import TFAutoModel, AutoTokenizer
+ import numpy as np
+ import pandas as pd
+ import faiss
+ 
+ # Load the spaCy English model, downloading it first if it is not already installed.
+ try:
+     nlp = spacy.load("en_core_web_sm")
+ except OSError:
+     spacy.cli.download("en_core_web_sm")
+     nlp = spacy.load("en_core_web_sm")
+ 
+ wh_words = ['what', 'who', 'how', 'when', 'which']
+ 
+ def get_concepts(text):
+     # Extract noun-chunk concepts from the question, skipping bare wh-words.
+     text = text.lower()
+     doc = nlp(text)
+     concepts = []
+     for chunk in doc.noun_chunks:
+         if chunk.text not in wh_words:
+             concepts.append(chunk.text)
+     return concepts
+ 
+ def get_passages(text, k=100):
+     # Split text into passages of roughly k tokens, respecting sentence boundaries.
+     doc = nlp(text)
+     passages = []
+     passage_len = 0
+     passage = ""
+     sents = list(doc.sents)
+     for i in range(len(sents)):
+         sen = sents[i]
+         passage_len += len(sen)
+         if passage_len >= k:
+             passages.append(passage)
+             passage = sen.text
+             passage_len = len(sen)
+             continue
+         elif i == (len(sents) - 1):
+             passage += " " + sen.text
+             passages.append(passage)
+             passage = ""
+             passage_len = 0
+             continue
+         passage += " " + sen.text
+     return passages
+ 
+ def get_dicts_for_dpr(concepts, n_results=20, k=100):
+     # Search Wikipedia for each concept and collect its page passages as DPR context dicts.
+     dicts = []
+     for concept in concepts:
+         wikis = wikipedia.search(concept, results=n_results)
+         print(concept, "No. of Wikis:", len(wikis))
+         for wiki in wikis:
+             try:
+                 html_page = wikipedia.page(title=wiki, auto_suggest=False)
+             except DisambiguationError:
+                 continue
+             passages = get_passages(html_page.content, k=k)
+             for passage in passages:
+                 i_dicts = {}
+                 i_dicts['text'] = passage
+                 i_dicts['title'] = wiki
+                 dicts.append(i_dicts)
+     return dicts
+ 
+ # Lightweight DPR context/question encoders and their matching tokenizers.
+ passage_encoder = TFAutoModel.from_pretrained("nlpconnect/dpr-ctx_encoder_bert_uncased_L-2_H-128_A-2")
+ query_encoder = TFAutoModel.from_pretrained("nlpconnect/dpr-question_encoder_bert_uncased_L-2_H-128_A-2")
+ 
+ p_tokenizer = AutoTokenizer.from_pretrained("nlpconnect/dpr-ctx_encoder_bert_uncased_L-2_H-128_A-2")
+ q_tokenizer = AutoTokenizer.from_pretrained("nlpconnect/dpr-question_encoder_bert_uncased_L-2_H-128_A-2")
+ 
+ def get_title_text_combined(passage_dicts):
+     # Pair each passage with its title, the (title, text) format the context tokenizer expects.
+     res = []
+     for p in passage_dicts:
+         res.append(tuple((p['title'], p['text'])))
+     return res
+ 
+ def extracted_passage_embeddings(processed_passages, max_length=156):
+     # Tokenize (title, text) pairs and encode them with the DPR context encoder.
+     passage_inputs = p_tokenizer.batch_encode_plus(
+         processed_passages,
+         add_special_tokens=True,
+         truncation=True,
+         padding="max_length",
+         max_length=max_length,
+         return_token_type_ids=True
+     )
+     passage_embeddings = passage_encoder.predict([np.array(passage_inputs['input_ids']),
+                                                   np.array(passage_inputs['attention_mask']),
+                                                   np.array(passage_inputs['token_type_ids'])],
+                                                  batch_size=64,
+                                                  verbose=1)
+     return passage_embeddings
+ 
+ def extracted_query_embeddings(queries, max_length=64):
+     # Tokenize the question and encode it with the DPR question encoder.
+     query_inputs = q_tokenizer.batch_encode_plus(
+         queries,
+         add_special_tokens=True,
+         truncation=True,
+         padding="max_length",
+         max_length=max_length,
+         return_token_type_ids=True
+     )
+     query_embeddings = query_encoder.predict([np.array(query_inputs['input_ids']),
+                                               np.array(query_inputs['attention_mask']),
+                                               np.array(query_inputs['token_type_ids'])],
+                                              batch_size=1,
+                                              verbose=1)
+     return query_embeddings
+ 
+ def search(question):
+     # Retrieve Wikipedia passages for the question and rank them with an L2 FAISS index.
+     concepts = get_concepts(question)
+     print("concepts: ", concepts)
+     dicts = get_dicts_for_dpr(concepts, n_results=1)
+     print("dicts len: ", len(dicts))
+     processed_passages = get_title_text_combined(dicts)
+     passage_embeddings = extracted_passage_embeddings(processed_passages)
+     query_embeddings = extracted_query_embeddings([question])
+     faiss_index = faiss.IndexFlatL2(128)
+     faiss_index.add(passage_embeddings.pooler_output)
+     prob, index = faiss_index.search(query_embeddings.pooler_output, k=10)
+     return pd.DataFrame([dicts[i] for i in index[0]])
+ 
+ import gradio as gr
+ inp = gr.inputs.Textbox(lines=2, label="Question")
+ out = gr.outputs.Dataframe(label="Answers")  # gr.outputs.Textbox(label="Answers")
+ gr.Interface(fn=search, inputs=inp, outputs=out).launch()
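A minimal usage sketch (not part of this commit): with the dependencies above installed (spacy, wikipedia, transformers, faiss, pandas, gradio), the retrieval pipeline can be exercised directly in the same session as app.py, without launching the Gradio interface; the question string below is only an illustration.

    # Hypothetical quick check of the search() pipeline defined in app.py above.
    results = search("who invented the telephone")  # returns a pandas DataFrame of ranked passages
    print(results[['title', 'text']].head())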