Achyut Tiwari committed on
Commit
83ede0c
1 Parent(s): 468d439

Add files via upload

Files changed (3)
  1. pages/ask.py +376 -0
  2. pages/info.py +92 -0
  3. pages/settings.py +95 -0
pages/ask.py ADDED
@@ -0,0 +1,376 @@
+ import colorsys
+ import json
+ import re
+ import time
+
+ import nltk
+ import numpy as np
+ from nltk import tokenize
+
+ nltk.download('punkt')
+ from google.oauth2 import service_account
+ from google.cloud import texttospeech
+
+ from typing import Dict, Optional, List
+
+ import jwt
+ import requests
+ import streamlit as st
+ from sentence_transformers import SentenceTransformer, util, CrossEncoder
+
+ JWT_SECRET = st.secrets["api_secret"]
+ JWT_ALGORITHM = st.secrets["api_algorithm"]
+ INFERENCE_TOKEN = st.secrets["api_inference"]
+ CONTEXT_API_URL = st.secrets["api_context"]
+ LFQA_API_URL = st.secrets["api_lfqa"]
+
+ headers = {"Authorization": f"Bearer {INFERENCE_TOKEN}"}
+ API_URL = "https://api-inference.huggingface.co/models/vblagoje/bart_lfqa"
+ API_URL_TTS = "https://api-inference.huggingface.co/models/espnet/kan-bayashi_ljspeech_joint_finetune_conformer_fastspeech2_hifigan"
+
+
+ def api_inference_lfqa(model_input: str):
+     # Generate an answer via the hosted HuggingFace inference API
+     payload = {
+         "inputs": model_input,
+         "parameters": {
+             "truncation": "longest_first",
+             "min_length": st.session_state["min_length"],
+             "max_length": st.session_state["max_length"],
+             "do_sample": st.session_state["do_sample"],
+             "early_stopping": st.session_state["early_stopping"],
+             "num_beams": st.session_state["num_beams"],
+             "temperature": st.session_state["temperature"],
+             "top_k": None,
+             "top_p": None,
+             "no_repeat_ngram_size": 3,
+             "num_return_sequences": 1
+         },
+         "options": {
+             "wait_for_model": True
+         }
+     }
+     data = json.dumps(payload)
+     response = requests.request("POST", API_URL, headers=headers, data=data)
+     return json.loads(response.content.decode("utf-8"))
+
+
+ def inference_lfqa(model_input: str, header: dict):
+     # Generate an answer via the custom LFQA backend service
+     payload = {
+         "model_input": model_input,
+         "parameters": {
+             "min_length": st.session_state["min_length"],
+             "max_length": st.session_state["max_length"],
+             "do_sample": st.session_state["do_sample"],
+             "early_stopping": st.session_state["early_stopping"],
+             "num_beams": st.session_state["num_beams"],
+             "temperature": st.session_state["temperature"],
+             "top_k": None,
+             "top_p": None,
+             "no_repeat_ngram_size": 3,
+             "num_return_sequences": 1
+         }
+     }
+     data = json.dumps(payload)
+     try:
+         response = requests.request("POST", LFQA_API_URL, headers=header, data=data)
+         if response.status_code == 200:
+             json_response = response.content.decode("utf-8")
+             result = json.loads(json_response)
+         else:
+             result = {"error": f"LFQA service unavailable, status code={response.status_code}"}
+     except requests.exceptions.RequestException as e:
+         result = {"error": e}
+     return result
+
+
+ def invoke_lfqa(service_backend: str, model_input: str, header: Optional[dict]):
+     if service_backend == "HuggingFace":
+         inference_response = api_inference_lfqa(model_input)
+     else:
+         inference_response = inference_lfqa(model_input, header)
+     return inference_response
+
+
+ @st.cache(allow_output_mutation=True, show_spinner=False)
+ def hf_tts(text: str):
+     # Convert the answer to speech with the ESPnet model hosted on the HF hub
+     payload = {
+         "inputs": text,
+         "parameters": {
+             "vocoder_tag": "str_or_none(none)",
+             "threshold": 0.5,
+             "minlenratio": 0.0,
+             "maxlenratio": 10.0,
+             "use_att_constraint": False,
+             "backward_window": 1,
+             "forward_window": 3,
+             "speed_control_alpha": 1.0,
+             "noise_scale": 0.333,
+             "noise_scale_dur": 0.333
+         },
+         "options": {
+             "wait_for_model": True
+         }
+     }
+     data = json.dumps(payload)
+     response = requests.request("POST", API_URL_TTS, headers=headers, data=data)
+     return response.content
+
+
+ @st.cache(allow_output_mutation=True, show_spinner=False)
+ def google_tts(text: str, private_key_id: str, private_key: str, client_email: str):
+     config = {
+         "private_key_id": private_key_id,
+         "private_key": f"-----BEGIN PRIVATE KEY-----\n{private_key}\n-----END PRIVATE KEY-----\n",
+         "client_email": client_email,
+         "token_uri": "https://oauth2.googleapis.com/token",
+     }
+     credentials = service_account.Credentials.from_service_account_info(config)
+     client = texttospeech.TextToSpeechClient(credentials=credentials)
+
+     synthesis_input = texttospeech.SynthesisInput(text=text)
+
+     # Build the voice request; select the language code ("en-US") and the
+     # SSML voice gender ("neutral")
+     voice = texttospeech.VoiceSelectionParams(language_code="en-US",
+                                               ssml_gender=texttospeech.SsmlVoiceGender.NEUTRAL)
+
+     # Select the type of audio file to be returned
+     audio_config = texttospeech.AudioConfig(audio_encoding=texttospeech.AudioEncoding.MP3)
+
+     # Perform the text-to-speech request on the text input with the selected
+     # voice parameters and audio file type
+     response = client.synthesize_speech(input=synthesis_input, voice=voice, audio_config=audio_config)
+     return response
+
+
+ def request_context_passages(question, header):
+     # Retrieve the Wikipedia passages most relevant to the question
+     try:
+         response = requests.request("GET", CONTEXT_API_URL + question, headers=header)
+         if response.status_code == 200:
+             json_response = response.content.decode("utf-8")
+             result = json.loads(json_response)
+         else:
+             result = {"error": f"Context passage service unavailable, status code={response.status_code}"}
+     except requests.exceptions.RequestException as e:
+         result = {"error": e}
+
+     return result
+
+
+ @st.cache(allow_output_mutation=True, show_spinner=False)
+ def get_sentence_transformer():
+     return SentenceTransformer('all-MiniLM-L6-v2')
+
+
+ @st.cache(allow_output_mutation=True, show_spinner=False)
+ def get_sentence_transformer_encoding(sentences):
+     model = get_sentence_transformer()
+     return model.encode(sentences, convert_to_tensor=True)
+
+
+ def sign_jwt() -> str:
+     # Create a short-lived bearer token for the backend services
+     payload = {
+         "expires": time.time() + 6000
+     }
+     token = jwt.encode(payload, JWT_SECRET, algorithm=JWT_ALGORITHM)
+     return token
+
+
+ def extract_sentences_from_passages(passages):
+     sentences = []
+     for node in passages:
+         sentences.extend(tokenize.sent_tokenize(node["text"]))
+     return sentences
+
+
+ def similarity_color_picker(similarity: float):
+     # Map a similarity score to an RGB colour on a red-to-green hue scale
+     value = int(similarity * 75)
+     rgb = colorsys.hsv_to_rgb(value / 300., 1.0, 1.0)
+     return [round(255 * x) for x in rgb]
+
+
+ def rgb_to_hex(rgb):
+     return '%02x%02x%02x' % tuple(rgb)
+
+
+ def similarity_to_hex(similarity: float):
+     return rgb_to_hex(similarity_color_picker(similarity))
+
+
+ def rerank(question: str, passages: List[dict], include_rank: int = 4) -> List[dict]:
+     # Re-rank retrieved passages with a cross-encoder; keep the top include_rank
+     ce = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
+     question_passage_combinations = [[question, p["text"]] for p in passages]
+
+     # Compute the similarity scores for these combinations
+     similarity_scores = ce.predict(question_passage_combinations)
+
+     # Sort the scores in decreasing order
+     sim_ranking_idx = np.flip(np.argsort(similarity_scores))
+     return [passages[rank_idx] for rank_idx in sim_ranking_idx[:include_rank]]
+
+
+ def answer_to_context_similarity(generated_answer, context_passages, topk=3):
+     # For each answer sentence, find the topk most similar context sentences
+     context_sentences = extract_sentences_from_passages(context_passages)
+     context_sentences_e = get_sentence_transformer_encoding(context_sentences)
+     answer_sentences = tokenize.sent_tokenize(generated_answer)
+     answer_sentences_e = get_sentence_transformer_encoding(answer_sentences)
+     search_result = util.semantic_search(answer_sentences_e, context_sentences_e, top_k=topk)
+     result = []
+     for idx, r in enumerate(search_result):
+         context = []
+         for idx_c in range(topk):
+             context.append({"source": context_sentences[r[idx_c]["corpus_id"]], "score": r[idx_c]["score"]})
+         result.append({"answer": answer_sentences[idx], "context": context})
+     return result
+
+
+ def post_process_answer(generated_answer):
+     result = generated_answer
+     # regex pattern to detect complete sentence boundaries
+     regex = r"([A-Z][a-z].*?[.:!?](?=$| [A-Z]))"
+     answer_sentences = tokenize.sent_tokenize(generated_answer)
+     # drop the last sentence if generation truncated it mid-sentence
+     if len(answer_sentences) > len(re.findall(regex, generated_answer)):
+         result = " ".join(answer_sentences[:-1])
+     return result.strip()
+
+
+ def format_score(value: float, precision=2):
+     return f"{value:.{precision}f}"
+
+
+ @st.cache(allow_output_mutation=True, show_spinner=False)
+ def get_answer(question: str):
+     if not question:
+         return {}
+
+     resp: Dict[str, str] = {}
+     if len(question.split()) > 3:
+         header = {"Authorization": f"Bearer {sign_jwt()}"}
+         context_passages = request_context_passages(question, header)
+         if "error" in context_passages:
+             resp = context_passages
+         else:
+             context_passages = rerank(question, context_passages)
+             conditioned_context = "<P> " + " <P> ".join([d["text"] for d in context_passages])
+             model_input = f'question: {question} context: {conditioned_context}'
+
+             inference_response = invoke_lfqa(st.session_state["api_lfqa_selector"], model_input, header)
+             if "error" in inference_response:
+                 resp = inference_response
+             else:
+                 resp["context_passages"] = context_passages
+                 resp["answer"] = post_process_answer(inference_response[0]["generated_text"])
+     else:
+         resp = {"error": f"A longer, more descriptive question will receive a better answer. '{question}' is too short."}
+     return resp
+
+
+ def app():
+     with open('style.css') as f:
+         st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
+     footer = """
+         <div class="footer-custom">
+             Streamlit app - <a href="https://www.linkedin.com/in/danijel-petkovic-573309144/" target="_blank">Danijel Petkovic</a> |
+             LFQA/DPR models - <a href="https://www.linkedin.com/in/blagojevicvladimir/" target="_blank">Vladimir Blagojevic</a> |
+             Guidance & Feedback - <a href="https://yjernite.github.io/" target="_blank">Yacine Jernite</a> |
+             <a href="https://towardsdatascience.com/long-form-qa-beyond-eli5-an-updated-dataset-and-approach-319cb841aabb" target="_blank">Blog</a>
+         </div>
+     """
+     st.markdown(footer, unsafe_allow_html=True)
+
+     st.title('Wikipedia Assistant')
+     st.header('We are migrating to new backend infrastructure. ETA - 15.6.2022')
+
+     # question = st.text_input(
+     #     label='Ask Wikipedia an open-ended question below; for example, "Why do airplanes leave contrails in the sky?"')
+     question = ""
+     spinner = st.empty()
+     if question != "":
+         spinner.markdown(
+             f"""
+             <div class="loader-wrapper">
+                 <div class="loader">
+                 </div>
+                 <p>Generating answer for: <b>{question}</b></p>
+             </div>
+             <label class="loader-note">Answer generation may take up to 20 sec. Please stand by.</label>
+             """,
+             unsafe_allow_html=True,
+         )
+
+     question_response = get_answer(question)
+     if question_response:
+         if "error" in question_response:
+             st.warning(question_response["error"])
+         else:
+             spinner.markdown("")
+             generated_answer = question_response["answer"]
+             context_passages = question_response["context_passages"]
+             sentence_similarity = answer_to_context_similarity(generated_answer, context_passages, topk=3)
+             sentences = "<div class='sentence-wrapper'>"
+             for item in sentence_similarity:
+                 sentences += '<span>'
+                 score = item["context"][0]["score"]
+                 support_sentence = item["context"][0]["source"]
+                 sentences += "".join([
+                     f' {item["answer"]}',
+                     f'<span style="background-color: #{similarity_to_hex(score)}" class="tooltip">',
+                     f'{format_score(score, precision=1)}',
+                     f'<span class="tooltiptext"><b>Wikipedia source</b><br><br> {support_sentence} <br><br>Similarity: {format_score(score)}</span>'
+                 ])
+                 sentences += '</span>'  # close the tooltip span
+                 sentences += '</span>'  # close the answer-sentence span
+             sentences += '</div>'
+             st.markdown(sentences, unsafe_allow_html=True)
+
+             with st.spinner("Generating audio..."):
+                 if st.session_state["tts"] == "HuggingFace":
+                     audio_file = hf_tts(generated_answer)
+                     with open("out.flac", "wb") as f:
+                         f.write(audio_file)
+                 else:
+                     audio_file = google_tts(generated_answer, st.secrets["private_key_id"],
+                                             st.secrets["private_key"], st.secrets["client_email"])
+                     with open("out.mp3", "wb") as f:
+                         f.write(audio_file.audio_content)
+
+             audio_file = "out.flac" if st.session_state["tts"] == "HuggingFace" else "out.mp3"
+             st.audio(audio_file)
+
+             st.markdown("""<hr></hr>""", unsafe_allow_html=True)
+
+             model = get_sentence_transformer()
+
+             col1, col2 = st.columns(2)
+
+             with col1:
+                 st.subheader("Context")
+             with col2:
+                 selection = st.selectbox(
+                     label="",
+                     options=('Paragraphs', 'Sentences', 'Answer Similarity'),
+                     help="Context represents Wikipedia passages used to generate the answer")
+             question_e = model.encode(question, convert_to_tensor=True)
+             if selection == "Paragraphs":
+                 # score each full passage text against the question
+                 passage_texts = [p["text"] for p in context_passages]
+                 context_e = get_sentence_transformer_encoding(passage_texts)
+                 scores = util.cos_sim(question_e, context_e)
+                 similarity_scores = scores[0].squeeze().tolist()
+                 for idx, node in enumerate(context_passages):
+                     node["answer_similarity"] = "{0:.2f}".format(similarity_scores[idx])
+                 context_passages = sorted(context_passages, key=lambda x: float(x["answer_similarity"]), reverse=True)
+                 st.json(context_passages)
+             elif selection == "Sentences":
+                 sentences = extract_sentences_from_passages(context_passages)
+                 sentences_e = get_sentence_transformer_encoding(sentences)
+                 scores = util.cos_sim(question_e, sentences_e)
+                 sentence_similarity_scores = scores[0].squeeze().tolist()
+                 result = []
+                 for idx, sentence in enumerate(sentences):
+                     result.append(
+                         {"text": sentence, "answer_similarity": "{0:.2f}".format(sentence_similarity_scores[idx])})
+                 context_sentences = json.dumps(sorted(result, key=lambda x: float(x["answer_similarity"]), reverse=True))
+                 st.json(context_sentences)
+             else:
+                 st.json(sentence_similarity)
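
`sign_jwt()` above authenticates the page against the context and LFQA backends with a short-lived token that carries a custom `expires` claim. The backend side is not part of this commit; below is a minimal sketch, assuming the services share the same `JWT_SECRET`/`JWT_ALGORITHM` pair, of how such a service might validate the bearer token. The `is_token_valid` helper and the placeholder values are hypothetical, not code from this repository.

```python
import time

import jwt

JWT_SECRET = "..."       # assumed: same shared secret as st.secrets["api_secret"]
JWT_ALGORITHM = "HS256"  # assumed value of st.secrets["api_algorithm"]


def is_token_valid(token: str) -> bool:
    # Decode with the shared secret; reject tampered or malformed tokens.
    try:
        payload = jwt.decode(token, JWT_SECRET, algorithms=[JWT_ALGORITHM])
    except jwt.InvalidTokenError:
        return False
    # sign_jwt() stores a custom "expires" claim (unix time), not the
    # standard "exp" claim, so expiry must be checked manually.
    return payload.get("expires", 0) >= time.time()
```
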
pages/info.py ADDED
@@ -0,0 +1,92 @@
+ import streamlit as st
+
+
+ def app():
+     with open('style.css') as f:
+         st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
+     footer = """
+         <div class="footer-custom">
+             Streamlit app - <a href="https://www.linkedin.com/in/danijel-petkovic-573309144/" target="_blank">Danijel Petkovic</a> |
+             LFQA/DPR models - <a href="https://www.linkedin.com/in/blagojevicvladimir/" target="_blank">Vladimir Blagojevic</a> |
+             Guidance & Feedback - <a href="https://yjernite.github.io/" target="_blank">Yacine Jernite</a>
+         </div>
+     """
+     st.markdown(footer, unsafe_allow_html=True)
+
+     st.subheader("Intro")
+     intro = """
+         <div class="text">
+             Wikipedia Assistant is an example of a task usually referred to as Long-Form Question Answering (LFQA).
+             These systems function by querying large document stores for relevant information and subsequently using
+             the retrieved documents to generate accurate, multi-sentence answers. The documents related to a given
+             query, colloquially called context passages, are not used merely as source tokens for extracted answers,
+             but instead provide a larger context for the synthesis of original, abstractive long-form answers.
+             LFQA systems usually consist of three components:
+             <ul>
+                 <li>A document store including content passages for a variety of topics</li>
+                 <li>Encoder models to encode documents/questions such that it is possible to query the document store</li>
+                 <li>A Seq2Seq language model capable of generating paragraph-long answers when given a question and
+                 context passages retrieved from the document store</li>
+             </ul>
+         </div>
+         <br>
+     """
+     st.markdown(intro, unsafe_allow_html=True)
+     st.image("lfqa.png", caption="LFQA Architecture")
+     st.subheader("UI/UX")
+     st.write("Each sentence in the generated answer ends with a coloured tooltip; the colour ranges from red to green. "
+              "The tooltip contains a value representing the answer sentence's similarity to a specific sentence in the "
+              "retrieved Wikipedia context passages. Hovering over the tooltip shows the sentence from the "
+              "Wikipedia context passage. If a sentence similarity is 1.0, the seq2seq model extracted and "
+              "copied the sentence verbatim from the Wikipedia context passages. Lower sentence similarity values "
+              "indicate the seq2seq model is struggling to generate a relevant sentence for the question "
+              "asked.")
+     st.image("wikipedia_answer.png", caption="Answer with similarity tooltips")
+     st.write("Below the generated answer are question-related Wikipedia context paragraphs (passages). One can view "
+              "these passages in the raw retrieved format using the 'Paragraphs' select menu option. The 'Sentences' menu "
+              "option shows the same paragraphs at sentence level. Finally, the 'Answer Similarity' menu option "
+              "shows the three sentences from the context paragraphs most similar to each sentence in the generated answer.")
+     st.image("wikipedia_context.png", caption="Context paragraphs (passages)")
+
+     tts = """
+         <div class="text">
+             Wikipedia Assistant converts the text-based answer to speech via either the Google text-to-speech engine or the
+             <a href="https://github.com/espnet" target="_blank">ESPnet model</a> hosted on the
+             <a href="https://huggingface.co/espnet/kan-bayashi_ljspeech_joint_finetune_conformer_fastspeech2_hifigan" target="_blank">
+             HuggingFace hub</a>.
+         </div>
+         <br>
+     """
+     st.markdown(tts, unsafe_allow_html=True)
+
+     st.subheader("Tips")
+     tips = """
+         <div class="text">
+             The LFQA task is far from solved. Wikipedia Assistant will sometimes generate an answer unrelated to the question
+             asked, or even downright wrong. However, if the question is elaborate and specific, there is a decent chance of
+             getting a legible answer. LFQA systems target ELI5-style, non-factoid questions. A general guideline
+             is that questions starting with why, what, and how are better suited than where and who questions. Be elaborate.
+             <br><br>
+             For example, for a history-based question, Wikipedia Assistant is better suited to answer
+             "What was the objective of the German commando raid on Drvar in Bosnia during the Second World War?" than
+             "Why did Germans raid Drvar?". A precise science question like "Why do airplane jet engines leave contrails
+             in the sky?" has a good chance of getting a decent answer. Detailed and precise questions are more likely to
+             match the right half-dozen relevant passages in a 20+ GB Wikipedia dump and thus produce a good answer.
+         </div>
+         <br>
+     """
+     st.markdown(tips, unsafe_allow_html=True)
+     st.subheader("Technical details")
+     technical_intro = """
+         <div class="text technical-details-info">
+             A question asked will be encoded with a question <a href="https://huggingface.co/vblagoje/dpr-question_encoder-single-lfqa-wiki" target="_blank">encoder</a>
+             and sent to a server to find the most relevant Wikipedia passages. The Wikipedia <a href="https://huggingface.co/datasets/kilt_wikipedia" target="_blank">passages</a>
+             were previously encoded using a passage <a href="https://huggingface.co/vblagoje/dpr-ctx_encoder-single-lfqa-wiki" target="_blank">encoder</a> and
+             stored in a <a href="https://github.com/facebookresearch/faiss" target="_blank">Faiss</a> index. The passages matching the question (a.k.a. context passages) are retrieved from the Faiss
+             index and passed to a BART-based seq2seq <a href="https://huggingface.co/vblagoje/bart_lfqa" target="_blank">model</a> to
+             synthesize an original answer to the question.
+         </div>
+     """
+     st.markdown(technical_intro, unsafe_allow_html=True)
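
The "Technical details" text above describes the retrieval half of the pipeline: DPR question/context encoders plus a Faiss index over the KILT Wikipedia passages. Below is a minimal, self-contained sketch of that step, assuming the named encoder checkpoints load with the standard transformers DPR classes; the toy in-memory `passages` list and flat index stand in for the prebuilt production index and are illustrative only.

```python
import faiss  # pip install faiss-cpu
from transformers import (DPRContextEncoder, DPRContextEncoderTokenizer,
                          DPRQuestionEncoder, DPRQuestionEncoderTokenizer)

ctx_name = "vblagoje/dpr-ctx_encoder-single-lfqa-wiki"
q_name = "vblagoje/dpr-question_encoder-single-lfqa-wiki"
ctx_encoder = DPRContextEncoder.from_pretrained(ctx_name)
ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained(ctx_name)
q_encoder = DPRQuestionEncoder.from_pretrained(q_name)
q_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(q_name)

# Toy stand-in for the KILT Wikipedia passages; the real index is prebuilt.
passages = [
    "Contrails are line-shaped clouds produced by aircraft engine exhaust.",
    "A jet engine is a type of reaction engine discharging a fast-moving jet.",
]

# Encode the passages once and store them in an inner-product Faiss index.
ctx_inputs = ctx_tokenizer(passages, padding=True, truncation=True, return_tensors="pt")
ctx_emb = ctx_encoder(**ctx_inputs).pooler_output.detach().numpy()
index = faiss.IndexFlatIP(ctx_emb.shape[1])
index.add(ctx_emb)

# Encode the question and retrieve the closest passages for the seq2seq model.
q_inputs = q_tokenizer("Why do airplanes leave contrails in the sky?", return_tensors="pt")
q_emb = q_encoder(**q_inputs).pooler_output.detach().numpy()
scores, ids = index.search(q_emb, 2)
print([passages[i] for i in ids[0]])
```

The retrieved passages would then be concatenated into the `question: ... context: <P> ...` prompt that `pages/ask.py` feeds to the BART model.
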
pages/settings.py ADDED
@@ -0,0 +1,95 @@
+ import streamlit as st
+
+ settings = {}
+
+
+ def app():
+     st.markdown("""
+         <style>
+             div[data-testid="stForm"] {
+                 border: 0;
+             }
+             .footer-custom {
+                 position: fixed;
+                 bottom: 0;
+                 width: 100%;
+                 color: var(--text-color);
+                 max-width: 698px;
+                 font-size: 14px;
+                 height: 50px;
+                 padding: 10px 0;
+                 z-index: 50;
+             }
+             footer {
+                 display: none !important;
+             }
+             .footer-custom a {
+                 color: var(--text-color);
+             }
+             button[kind="formSubmit"] {
+                 margin-top: 40px;
+                 border-radius: 20px;
+                 padding: 5px 20px;
+                 font-size: 18px;
+                 background-color: var(--primary-color);
+             }
+             #lfqa-model-parameters {
+                 margin-bottom: 50px;
+                 font-size: 36px;
+             }
+             #tts-model-parameters {
+                 font-size: 36px;
+                 margin-top: 50px;
+             }
+             .stAlert {
+                 width: 250px;
+                 margin-top: 32px;
+             }
+         </style>
+     """, unsafe_allow_html=True)
+
+     with st.form("settings"):
+         footer = """
+             <div class="footer-custom">
+                 Streamlit app - <a href="https://www.linkedin.com/in/danijel-petkovic-573309144/" target="_blank">Danijel Petkovic</a> |
+                 LFQA/DPR models - <a href="https://www.linkedin.com/in/blagojevicvladimir/" target="_blank">Vladimir Blagojevic</a> |
+                 Guidance & Feedback - <a href="https://yjernite.github.io/" target="_blank">Yacine Jernite</a>
+             </div>
+         """
+         st.markdown(footer, unsafe_allow_html=True)
+
+         st.title("LFQA model parameters")
+
+         settings["min_length"] = st.slider("Min length", 20, 80, st.session_state["min_length"],
+                                            help="Min response length (words)")
+         st.markdown("""<hr></hr>""", unsafe_allow_html=True)
+         settings["max_length"] = st.slider("Max length", 128, 320, st.session_state["max_length"],
+                                            help="Max response length (words)")
+         st.markdown("""<hr></hr>""", unsafe_allow_html=True)
+         col1, col2 = st.columns(2)
+         with col1:
+             settings["do_sample"] = st.checkbox("Use sampling", st.session_state["do_sample"],
+                                                 help="Whether or not to use sampling; uses greedy decoding otherwise.")
+         with col2:
+             settings["early_stopping"] = st.checkbox("Early stopping", st.session_state["early_stopping"],
+                                                      help="Whether to stop the beam search when at least num_beams sentences are finished per batch or not.")
+         st.markdown("""<hr></hr>""", unsafe_allow_html=True)
+         settings["num_beams"] = st.slider("Num beams", 1, 16, st.session_state["num_beams"],
+                                           help="Number of beams for beam search. 1 means no beam search.")
+         st.markdown("""<hr></hr>""", unsafe_allow_html=True)
+         settings["temperature"] = st.slider("Temperature", 0.0, 1.0, st.session_state["temperature"], step=0.1,
+                                             help="The value used to modulate the next token probabilities")
+
+         st.title("TTS model parameters")
+         settings["tts"] = st.selectbox(label="Engine", options=("Google", "HuggingFace"),
+                                        index=["Google", "HuggingFace"].index(st.session_state["tts"]),
+                                        help="Answer text-to-speech engine")
+
+         # Every form must have a submit button.
+         col3, col4, col5, col6 = st.columns(4)
+         with col3:
+             submitted = st.form_submit_button("Save")
+         with col4:
+             if submitted:
+                 for k, v in settings.items():
+                     st.session_state[k] = v
+                 st.success('App settings saved successfully.')
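
Both `pages/ask.py` and `pages/settings.py` read keys such as `min_length` and `tts` from `st.session_state` before this form ever writes them, so the app's entry script must seed defaults and route between the three `app()` pages. That script is not included in this commit; the following is a hypothetical sketch of what it might look like. Every default value and the sidebar-routing pattern here are assumptions, not code from the upload.

```python
import streamlit as st

from pages import ask, info, settings

# Hypothetical defaults the settings sliders assume are already present;
# the real values live in the entry script, which is not part of this commit.
DEFAULTS = {
    "min_length": 64,
    "max_length": 256,
    "do_sample": False,
    "early_stopping": True,
    "num_beams": 8,
    "temperature": 1.0,
    "tts": "Google",
    "api_lfqa_selector": "HuggingFace",
}

# Seed session state once per browser session.
for key, value in DEFAULTS.items():
    if key not in st.session_state:
        st.session_state[key] = value

# Simple sidebar router over the three page modules, each exposing app().
PAGES = {"Ask": ask, "Info": info, "Settings": settings}
choice = st.sidebar.selectbox("Navigation", list(PAGES.keys()))
PAGES[choice].app()
```
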