secilozksen committed on
Commit 98b83d0
1 Parent(s): 557bdd3

Delete demov2.py

Files changed (1):
demov2.py +0 -304
demov2.py DELETED
@@ -1,304 +0,0 @@
- import copy
- import json
- import pickle
-
- import pandas as pd
- import regex
- import spacy
- import streamlit as st
- import tokenizers
- import torch
- from sentence_transformers import SentenceTransformer, CrossEncoder, util
- from st_aggrid import GridOptionsBuilder, AgGrid
- from transformers import pipeline, RobertaTokenizer, RobertaForSequenceClassification
-
- st.set_page_config(layout="wide")
-
- DATAFRAME_FILE_ORIGINAL = 'policyQA_original.csv'
- DATAFRAME_FILE_BSBS = 'policyQA_bsbs_sentence.csv'
-
-
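- # The loaders below are cached as Streamlit singletons, so the encoders,
- # tokenizers, and classification models are constructed once per process
- # and shared across script reruns.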
- @st.experimental_singleton(suppress_st_warning=True, show_spinner=False)
- def cross_encoder_init():
-     cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
-     return cross_encoder
-
-
- @st.experimental_singleton(suppress_st_warning=True, show_spinner=False)
- def bi_encoder_init():
-     bi_encoder = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')
-     bi_encoder.max_seq_length = 500  # Truncate long passages to 500 tokens
-     return bi_encoder
-
-
- @st.experimental_singleton(suppress_st_warning=True, show_spinner=False)
- def nlp_init(auth_token, private_model_name):
-     return pipeline('question-answering', model=private_model_name, tokenizer=private_model_name,
-                     use_auth_token=auth_token, revision="main")
-
-
- @st.experimental_singleton(suppress_st_warning=True, show_spinner=False)
- def nlp_pipeline_hf():
-     model_name = "deepset/roberta-base-squad2"
-     return pipeline('question-answering', model=model_name, tokenizer=model_name)
-
-
- @st.experimental_singleton(suppress_st_warning=True, show_spinner=False)
- def nlp_pipeline_sentence_based(auth_token, private_model_name):
-     tokenizer = RobertaTokenizer.from_pretrained(private_model_name, use_auth_token=auth_token)
-     model = RobertaForSequenceClassification.from_pretrained(private_model_name, use_auth_token=auth_token)
-     return tokenizer, model
-
-
- @st.cache(hash_funcs={tokenizers.Tokenizer: lambda _: None, tokenizers.AddedToken: lambda _: None,
-                       regex.Pattern: lambda _: None}, show_spinner=False)
- def load_models_sentence_based(auth_token, private_model_name, private_model_name_base):
-     bi_encoder = bi_encoder_init()
-     cross_encoder = cross_encoder_init()
-     # OLD MODEL
-     # nlp = nlp_init(auth_token, private_model_name)
-     # nlp_hf = nlp_pipeline_hf()
-     policy_qa_tokenizer, policy_qa_model = nlp_pipeline_sentence_based(auth_token, private_model_name)
-     asnq_tokenizer, asnq_model = nlp_pipeline_sentence_based(auth_token, private_model_name_base)
-
-     return bi_encoder, cross_encoder, policy_qa_tokenizer, policy_qa_model, asnq_tokenizer, asnq_model
-
-
- # Legacy loader for the span-extraction pipelines; unused in this version of the demo.
- @st.cache(hash_funcs={tokenizers.Tokenizer: lambda _: None, tokenizers.AddedToken: lambda _: None}, show_spinner=False)
- def load_models(auth_token, private_model_name):
-     bi_encoder = bi_encoder_init()
-     cross_encoder = cross_encoder_init()
-     nlp = nlp_init(auth_token, private_model_name)
-     nlp_hf = nlp_pipeline_hf()
-
-     return bi_encoder, cross_encoder, nlp, nlp_hf
-
-
- # One-off helper: embed every context paragraph with the bi-encoder and
- # pickle the paragraphs together with their embeddings for later retrieval.
- def context():
-     bi_encoder = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1', device='cpu')
-     with open("/home/secilsen/PycharmProjects/SquadOperations/contexes.json", 'r', encoding='utf-8') as f:
-         paragraphs = json.load(f)
-     paragraphs = paragraphs['contexes']
-     with open('context-embeddings.pkl', "wb") as fIn:
-         context_embeddings = bi_encoder.encode(paragraphs, convert_to_tensor=True, show_progress_bar=True)
-         pickle.dump({'contexes': paragraphs, 'embeddings': context_embeddings}, fIn)
-
-
- @st.cache(show_spinner=False)
- def load_paragraphs():
-     with open('context-embeddings.pkl', "rb") as fIn:
-         cache_data = pickle.load(fIn)
-         corpus_sentences = cache_data['contexes']
-         corpus_embeddings = cache_data['embeddings']
-
-     return corpus_embeddings, corpus_sentences
-
-
- @st.cache(show_spinner=False)
- def load_dataframes():
-     data_original = pd.read_csv(DATAFRAME_FILE_ORIGINAL, index_col=0, sep='|')
-     data_bsbs = pd.read_csv(DATAFRAME_FILE_BSBS, index_col=0, sep='|')
-     data_original = data_original.sample(frac=1).reset_index(drop=True)  # shuffle rows
-     data_bsbs = data_bsbs.sample(frac=1).reset_index(drop=True)
-     return data_original, data_bsbs
-
-
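- # Retrieve & re-rank: the bi-encoder first fetches the 100 passages closest
- # to the question embedding, then the cross-encoder scores every
- # (question, passage) pair and the passages are re-sorted by that score.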
- def search(question, corpus_embeddings, contexes, bi_encoder, cross_encoder):
-     # Semantic search (retrieve)
-     question_embedding = bi_encoder.encode(question, convert_to_tensor=True)
-     hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=100)
-     if len(hits) == 0:
-         return [], []
-     hits = hits[0]
-     # Re-rank: score all retrieved passages with the cross-encoder
-     cross_inp = [[question, contexes[hit['corpus_id']]] for hit in hits]
-     cross_scores = cross_encoder.predict(cross_inp)
-
-     # Attach the cross-encoder score to each hit
-     for idx in range(len(cross_scores)):
-         hits[idx]['cross-score'] = cross_scores[idx]
-
-     # Keep the 20 highest-scoring passages from the re-ranker
-     # (the "top_5" names are historical; 20 contexts are returned)
-     hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
-     top_5_contexes = []
-     top_5_scores = []
-     for hit in hits[0:20]:
-         top_5_contexes.append(contexes[hit['corpus_id']])
-         top_5_scores.append(hit['cross-score'])
-     return top_5_contexes, top_5_scores
-
-
- def paragraph_embeddings():
-     # load_paragraphs() returns (embeddings, paragraphs); re-embed the
-     # paragraphs with the module-level bi_encoder.
-     _, paragraphs = load_paragraphs()
-     context_embeddings = bi_encoder.encode(paragraphs, convert_to_tensor=True, show_progress_bar=True)
-     return context_embeddings, paragraphs
-
-
- def retrieve_rerank_pipeline(question, context_embeddings, paragraphs, bi_encoder, cross_encoder):
-     top_5_contexes, top_5_scores = search(question, context_embeddings, paragraphs, bi_encoder, cross_encoder)
-     return top_5_contexes, top_5_scores
-
-
- def qa_pipeline(question, context, nlp):
-     return nlp({'question': question.strip(), 'context': context})
-
-
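- # Sentence-level answer selection: split the passage into sentences with
- # spaCy, score each "<s> question </s> sentence </s>" pair with the RoBERTa
- # sequence classifier, and keep the sentences whose positive-class
- # probability clears the 0.2 threshold, sorted by probability.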
- def qa_pipeline_sentence(question, context, model, tokenizer):
-     sentences_doc = spacy_nlp(context)
-     candidate_sentences = []
-     for sentence in sentences_doc.sents:
-         tokenized = tokenizer(f"<s> {question} </s> {sentence.text} </s>", padding=True, truncation=True,
-                               return_tensors='pt')
-         with torch.no_grad():
-             output = model(**tokenized)
-         soft_outputs = torch.sigmoid(output[0])
-         t = torch.tensor(0.2)  # decision threshold on the positive class
-         out = (soft_outputs[0] > t) * 1
-         out = out.flatten().cpu().numpy()
-         if out[1] == 1:
-             prob = soft_outputs[:, 1].flatten().cpu().numpy()
-             candidate_sentences.append(dict(sentence=sentence, prob=prob[0]))
-     candidate_sentences = sorted(candidate_sentences, key=lambda x: x['prob'], reverse=True)
-     return candidate_sentences
-
-
- def candidate_sentence_controller(sentences):
-     if sentences is None or len(sentences) == 0:
-         return ""
-     if len(sentences) == 1:
-         return sentences[0]
-     return sentences
-
-
- def interactive_table(dataframe):
-     gb = GridOptionsBuilder.from_dataframe(dataframe)
-     gb.configure_pagination(paginationAutoPageSize=True)
-     gb.configure_side_bar()
-     gb.configure_selection('single', rowMultiSelectWithClick=True,
-                            groupSelectsChildren="Group checkbox select children")
-     gridOptions = gb.build()
-     grid_response = AgGrid(
-         dataframe,
-         gridOptions=gridOptions,
-         data_return_mode='AS_INPUT',
-         update_mode='SELECTION_CHANGED',
-         enable_enterprise_modules=False,
-         fit_columns_on_grid_load=False,
-         theme='streamlit',
-         height=350,
-         width='100%',
-         reload_data=False
-     )
-     return grid_response
-
-
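- # Main page: a wide question/answer column plus two tables of example
- # questions (the original PolicyQA questions and our own); selecting a row
- # shows its context, question, and gold answer.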
- def qa_main_widgetsv2():
-     st.title("Question Answering Demo")
-     col1, col2, col3 = st.columns([2, 1, 1])
-     with col1:
-         form = st.form(key='first_form')
-         question = form.text_area("What is your question?:", height=200)
-         submit = form.form_submit_button('Submit')
-         if "form_submit" not in st.session_state:
-             st.session_state.form_submit = False
-         if submit:
-             st.session_state.form_submit = True
-         if st.session_state.form_submit and question != '':
-             with st.spinner(text='Related context search in progress...'):
-                 top_5_contexes, top_5_scores = retrieve_rerank_pipeline(question.strip(), context_embeddings,
-                                                                         paragraphs, bi_encoder,
-                                                                         cross_encoder)
-             if len(top_5_contexes) == 0:
-                 st.error("Related context not found!")
-                 st.session_state.form_submit = False
-             else:
-                 with st.spinner(text='Now answering your question...'):
-                     for i, context in enumerate(top_5_contexes):
-                         # answer_trained = qa_pipeline(question, context, nlp)
-                         # answer_base = qa_pipeline(question, context, nlp_hf)
-                         answer_trained = qa_pipeline_sentence(question, context, policy_qa_model, policy_qa_tokenizer)
-                         answer_base = qa_pipeline_sentence(question, context, asnq_model, asnq_tokenizer)
-                         st.markdown(f"## Related Context - {i + 1} (score: {top_5_scores[i]:.2f})")
-                         st.markdown(context)
-                         st.markdown("## Answer (trained):")
-                         if answer_trained is None:
-                             st.markdown("")
-                         elif isinstance(answer_trained, list):
-                             for j, answer in enumerate(answer_trained):
-                                 st.markdown(f"### Answer Option {j + 1} with prob. {answer['prob']:.4f}")
-                                 st.markdown(answer['sentence'])
-                         else:
-                             st.markdown(answer_trained)
-                         st.markdown("## Answer (roberta-base-asnq):")
-                         if answer_base is None:
-                             st.markdown("")
-                         elif isinstance(answer_base, list):
-                             for j, answer in enumerate(answer_base):
-                                 st.markdown(f"### Answer Option {j + 1} with prob. {answer['prob']:.4f}")
-                                 st.markdown(answer['sentence'])
-                         else:
-                             st.markdown(answer_base)
-                         st.markdown("""---""")
-
-     with col2:
-         st.markdown("## Original Questions")
-         grid_response = interactive_table(dataframe_original)
-         data1 = grid_response['selected_rows']
-         if "grid_click_1" not in st.session_state:
-             st.session_state.grid_click_1 = False
-         if len(data1) > 0:
-             st.session_state.grid_click_1 = True
-         if st.session_state.grid_click_1:
-             selection = data1[0]
-             st.markdown("### Context:")
-             st.write(selection['context'])
-             st.markdown("### Question:")
-             st.write(selection['question'])
-             st.markdown("### Answer:")
-             st.write(selection['answer'])
-             st.session_state.grid_click_1 = False
-     with col3:
-         st.markdown("## Our Questions")
-         grid_response = interactive_table(dataframe_bsbs)
-         data2 = grid_response['selected_rows']
-         if "grid_click_2" not in st.session_state:
-             st.session_state.grid_click_2 = False
-         if len(data2) > 0:
-             st.session_state.grid_click_2 = True
-         if st.session_state.grid_click_2:
-             selection = data2[0]
-             st.markdown("### Context:")
-             st.write(selection['context'])
-             st.markdown("### Question:")
-             st.write(selection['question'])
-             st.markdown("### Answer:")
-             st.write(selection['answer'])
-             st.session_state.grid_click_2 = False
-
-
- def load():
-     context_embeddings, paragraphs = load_paragraphs()
-     dataframe_original, dataframe_bsbs = load_dataframes()
-     spacy_nlp = spacy.load('en_core_web_sm')
-     # bi_encoder, cross_encoder, nlp, nlp_hf = copy.deepcopy(load_models(st.secrets["AUTH_TOKEN"], st.secrets["MODEL_NAME"]))
-     bi_encoder, cross_encoder, policy_qa_tokenizer, policy_qa_model, asnq_tokenizer, asnq_model = copy.deepcopy(
-         load_models_sentence_based(st.secrets["AUTH_TOKEN"], st.secrets["MODEL_NAME"], st.secrets["MODEL_NAME_BASE"]))
-     return context_embeddings, paragraphs, dataframe_original, dataframe_bsbs, bi_encoder, cross_encoder, policy_qa_tokenizer, policy_qa_model, asnq_tokenizer, asnq_model, spacy_nlp
-
-
- # save_dataframe()
- # context_embeddings, paragraphs, dataframe_original, dataframe_bsbs, bi_encoder, cross_encoder, nlp, nlp_hf = load()
- context_embeddings, paragraphs, dataframe_original, dataframe_bsbs, bi_encoder, cross_encoder, policy_qa_tokenizer, policy_qa_model, asnq_tokenizer, asnq_model, spacy_nlp = load()
- qa_main_widgetsv2()
-
- # if __name__ == '__main__':
- #     context()