Spaces: Build error
Achyut Tiwari committed · Commit 83ede0c · 1 Parent(s): 468d439
Add files via upload

- pages/ask.py +376 -0
- pages/info.py +92 -0
- pages/settings.py +95 -0
pages/ask.py
ADDED
@@ -0,0 +1,376 @@
import colorsys
import json
import re
import time

import nltk
import numpy as np
from nltk import tokenize

nltk.download('punkt')
from google.oauth2 import service_account
from google.cloud import texttospeech

from typing import Dict, Optional, List

import jwt
import requests
import streamlit as st
from sentence_transformers import SentenceTransformer, util, CrossEncoder

JWT_SECRET = st.secrets["api_secret"]
JWT_ALGORITHM = st.secrets["api_algorithm"]
INFERENCE_TOKEN = st.secrets["api_inference"]
CONTEXT_API_URL = st.secrets["api_context"]
LFQA_API_URL = st.secrets["api_lfqa"]

headers = {"Authorization": f"Bearer {INFERENCE_TOKEN}"}
API_URL = "https://api-inference.huggingface.co/models/vblagoje/bart_lfqa"
API_URL_TTS = "https://api-inference.huggingface.co/models/espnet/kan-bayashi_ljspeech_joint_finetune_conformer_fastspeech2_hifigan"

def api_inference_lfqa(model_input: str):
    payload = {
        "inputs": model_input,
        "parameters": {
            "truncation": "longest_first",
            "min_length": st.session_state["min_length"],
            "max_length": st.session_state["max_length"],
            "do_sample": st.session_state["do_sample"],
            "early_stopping": st.session_state["early_stopping"],
            "num_beams": st.session_state["num_beams"],
            "temperature": st.session_state["temperature"],
            "top_k": None,
            "top_p": None,
            "no_repeat_ngram_size": 3,
            "num_return_sequences": 1
        },
        "options": {
            "wait_for_model": True
        }
    }
    data = json.dumps(payload)
    response = requests.request("POST", API_URL, headers=headers, data=data)
    return json.loads(response.content.decode("utf-8"))

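# A minimal usage sketch (assuming the Hugging Face Inference API returns the
# usual list-of-dicts shape for text2text-generation models, which is also
# what get_answer() below relies on):
#
#   result = api_inference_lfqa("question: Why is the sky blue? context: <P> ...")
#   answer = result[0]["generated_text"]
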
def inference_lfqa(model_input: str, header: dict):
    payload = {
        "model_input": model_input,
        "parameters": {
            "min_length": st.session_state["min_length"],
            "max_length": st.session_state["max_length"],
            "do_sample": st.session_state["do_sample"],
            "early_stopping": st.session_state["early_stopping"],
            "num_beams": st.session_state["num_beams"],
            "temperature": st.session_state["temperature"],
            "top_k": None,
            "top_p": None,
            "no_repeat_ngram_size": 3,
            "num_return_sequences": 1
        }
    }
    data = json.dumps(payload)
    try:
        response = requests.request("POST", LFQA_API_URL, headers=header, data=data)
        if response.status_code == 200:
            json_response = response.content.decode("utf-8")
            result = json.loads(json_response)
        else:
            result = {"error": f"LFQA service unavailable, status code={response.status_code}"}
    except requests.exceptions.RequestException as e:
        result = {"error": str(e)}
    return result


def invoke_lfqa(service_backend: str, model_input: str, header: Optional[dict]):
    if service_backend == "HuggingFace":
        inference_response = api_inference_lfqa(model_input)
    else:
        inference_response = inference_lfqa(model_input, header)
    return inference_response

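# Both backends are expected to return the same shape on success - a list with
# one generated sequence - so callers can treat them uniformly:
#
#   response = invoke_lfqa("HuggingFace", model_input, header=None)
#   answer = response[0]["generated_text"]
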
@st.cache(allow_output_mutation=True, show_spinner=False)
def hf_tts(text: str):
    payload = {
        "inputs": text,
        "parameters": {
            "vocoder_tag": "str_or_none(none)",
            "threshold": 0.5,
            "minlenratio": 0.0,
            "maxlenratio": 10.0,
            "use_att_constraint": False,
            "backward_window": 1,
            "forward_window": 3,
            "speed_control_alpha": 1.0,
            "noise_scale": 0.333,
            "noise_scale_dur": 0.333
        },
        "options": {
            "wait_for_model": True
        }
    }
    data = json.dumps(payload)
    response = requests.request("POST", API_URL_TTS, headers=headers, data=data)
    return response.content


@st.cache(allow_output_mutation=True, show_spinner=False)
def google_tts(text: str, private_key_id: str, private_key: str, client_email: str):
    config = {
        "private_key_id": private_key_id,
        "private_key": f"-----BEGIN PRIVATE KEY-----\n{private_key}\n-----END PRIVATE KEY-----\n",
        "client_email": client_email,
        "token_uri": "https://oauth2.googleapis.com/token",
    }
    credentials = service_account.Credentials.from_service_account_info(config)
    client = texttospeech.TextToSpeechClient(credentials=credentials)

    synthesis_input = texttospeech.SynthesisInput(text=text)

    # Build the voice request, select the language code ("en-US") and the ssml
    # voice gender ("neutral")
    voice = texttospeech.VoiceSelectionParams(language_code="en-US",
                                              ssml_gender=texttospeech.SsmlVoiceGender.NEUTRAL)

    # Select the type of audio file you want returned
    audio_config = texttospeech.AudioConfig(audio_encoding=texttospeech.AudioEncoding.MP3)

    # Perform the text-to-speech request on the text input with the selected
    # voice parameters and audio file type
    response = client.synthesize_speech(input=synthesis_input, voice=voice, audio_config=audio_config)
    return response

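# hf_tts returns raw FLAC bytes from the ESPnet model, while google_tts
# returns a SynthesizeSpeechResponse whose .audio_content holds MP3 bytes;
# app() below writes each to disk before handing the file to st.audio.
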
def request_context_passages(question, header):
    try:
        # URL-encode the free-form question before appending it to the endpoint
        response = requests.request("GET", CONTEXT_API_URL + requests.utils.quote(question), headers=header)
        if response.status_code == 200:
            json_response = response.content.decode("utf-8")
            result = json.loads(json_response)
        else:
            result = {"error": f"Context passage service unavailable, status code={response.status_code}"}
    except requests.exceptions.RequestException as e:
        result = {"error": str(e)}

    return result


@st.cache(allow_output_mutation=True, show_spinner=False)
def get_sentence_transformer():
    return SentenceTransformer('all-MiniLM-L6-v2')


@st.cache(allow_output_mutation=True, show_spinner=False)
def get_sentence_transformer_encoding(sentences):
    model = get_sentence_transformer()
    return model.encode(sentences, convert_to_tensor=True)


def sign_jwt() -> str:
    payload = {
        "expires": time.time() + 6000
    }
    token = jwt.encode(payload, JWT_SECRET, algorithm=JWT_ALGORITHM)
    return token

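# The backend presumably validates this token with the shared secret - a
# sketch of the decoding side (an assumption; the server code is not part of
# this commit):
#
#   claims = jwt.decode(token, JWT_SECRET, algorithms=[JWT_ALGORITHM])
#   assert claims["expires"] > time.time()
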
def extract_sentences_from_passages(passages):
    sentences = []
    for node in passages:
        sentences.extend(tokenize.sent_tokenize(node["text"]))
    return sentences


def similarity_color_picker(similarity: float):
    # map similarity [0, 1] onto a red-to-green slice of the HSV hue wheel
    value = int(similarity * 75)
    rgb = colorsys.hsv_to_rgb(value / 300., 1.0, 1.0)
    return [round(255 * x) for x in rgb]


def rgb_to_hex(rgb):
    return '%02x%02x%02x' % tuple(rgb)


def similarity_to_hex(similarity: float):
    return rgb_to_hex(similarity_color_picker(similarity))

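# For example, similarity 0.0 yields 'ff0000' (red) and 1.0 yields '80ff00'
# (green); these hex strings back the tooltip colours rendered in app().
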
def rerank(question: str, passages: List[dict], include_rank: int = 4) -> List[dict]:
    # note: the cross-encoder is re-instantiated on every call; it could be
    # cached the same way the sentence transformer is above
    ce = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
    question_passage_combinations = [[question, p["text"]] for p in passages]

    # Compute the similarity scores for these combinations
    similarity_scores = ce.predict(question_passage_combinations)

    # Sort the scores in decreasing order and keep the top include_rank passages
    sim_ranking_idx = np.flip(np.argsort(similarity_scores))
    return [passages[rank_idx] for rank_idx in sim_ranking_idx[:include_rank]]

def answer_to_context_similarity(generated_answer, context_passages, topk=3):
    context_sentences = extract_sentences_from_passages(context_passages)
    context_sentences_e = get_sentence_transformer_encoding(context_sentences)
    answer_sentences = tokenize.sent_tokenize(generated_answer)
    answer_sentences_e = get_sentence_transformer_encoding(answer_sentences)
    search_result = util.semantic_search(answer_sentences_e, context_sentences_e, top_k=topk)
    result = []
    for idx, r in enumerate(search_result):
        context = []
        for idx_c in range(topk):
            context.append({"source": context_sentences[r[idx_c]["corpus_id"]], "score": r[idx_c]["score"]})
        result.append({"answer": answer_sentences[idx], "context": context})
    return result

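# Illustrative output shape (the values here are invented for the example):
#
#   [{"answer": "Contrails form when hot engine exhaust meets cold air.",
#     "context": [{"source": "A contrail forms when ...", "score": 0.91}, ...]}]
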
def post_process_answer(generated_answer):
    result = generated_answer
    # regex pattern that detects complete sentence boundaries
    regex = r"([A-Z][a-z].*?[.:!?](?=$| [A-Z]))"
    answer_sentences = tokenize.sent_tokenize(generated_answer)
    # if the tokenizer found more sentences than the regex, the last sentence
    # was likely truncated mid-generation, so drop it
    if len(answer_sentences) > len(re.findall(regex, generated_answer)):
        result = " ".join(answer_sentences[:-1])
    return result.strip()


def format_score(value: float, precision=2):
    return f"{value:.{precision}f}"

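# For example (assuming NLTK's tokenizer splits the text as shown):
#
#   post_process_answer("Contrails are ice clouds. They form when")
#   -> "Contrails are ice clouds."
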
@st.cache(allow_output_mutation=True, show_spinner=False)
def get_answer(question: str):
    if not question:
        return {}

    resp: Dict[str, str] = {}
    if len(question.split()) > 3:
        header = {"Authorization": f"Bearer {sign_jwt()}"}
        context_passages = request_context_passages(question, header)
        if "error" in context_passages:
            resp = context_passages
        else:
            context_passages = rerank(question, context_passages)
            conditioned_context = "<P> " + " <P> ".join([d["text"] for d in context_passages])
            model_input = f'question: {question} context: {conditioned_context}'

            inference_response = invoke_lfqa(st.session_state["api_lfqa_selector"], model_input, header)
            if "error" in inference_response:
                resp = inference_response
            else:
                resp["context_passages"] = context_passages
                resp["answer"] = post_process_answer(inference_response[0]["generated_text"])
    else:
        resp = {"error": f"A longer, more descriptive question will receive a better answer. '{question}' is too short."}
    return resp

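# On success, resp carries "answer" and "context_passages"; on any failure it
# carries a single "error" key, which app() below surfaces via st.warning.
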
def app():
    with open('style.css') as f:
        st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
    footer = """
        <div class="footer-custom">
            Streamlit app - <a href="https://www.linkedin.com/in/danijel-petkovic-573309144/" target="_blank">Danijel Petkovic</a> |
            LFQA/DPR models - <a href="https://www.linkedin.com/in/blagojevicvladimir/" target="_blank">Vladimir Blagojevic</a> |
            Guidance & Feedback - <a href="https://yjernite.github.io/" target="_blank">Yacine Jernite</a> |
            <a href="https://towardsdatascience.com/long-form-qa-beyond-eli5-an-updated-dataset-and-approach-319cb841aabb" target="_blank">Blog</a>
        </div>
    """
    st.markdown(footer, unsafe_allow_html=True)

    st.title('Wikipedia Assistant')
    st.header('We are migrating to new backend infrastructure. ETA - 15.6.2022')

    # the question input is disabled while the backend migration is in progress
    # question = st.text_input(
    #     label='Ask Wikipedia an open-ended question below; for example, "Why do airplanes leave contrails in the sky?"')
    question = ""
    spinner = st.empty()
    if question != "":
        spinner.markdown(
            f"""
            <div class="loader-wrapper">
                <div class="loader">
                </div>
                <p>Generating answer for: <b>{question}</b></p>
            </div>
            <label class="loader-note">Answer generation may take up to 20 sec. Please stand by.</label>
            """,
            unsafe_allow_html=True,
        )

    question_response = get_answer(question)
    if question_response:
        if "error" in question_response:
            st.warning(question_response["error"])
        else:
            spinner.markdown("")
            generated_answer = question_response["answer"]
            context_passages = question_response["context_passages"]
            sentence_similarity = answer_to_context_similarity(generated_answer, context_passages, topk=3)
            sentences = "<div class='sentence-wrapper'>"
            for item in sentence_similarity:
                sentences += '<span>'
                score = item["context"][0]["score"]
                support_sentence = item["context"][0]["source"]
                sentences += "".join([
                    f' {item["answer"]}',
                    f'<span style="background-color: #{similarity_to_hex(score)}" class="tooltip">',
                    f'{format_score(score, precision=1)}',
                    f'<span class="tooltiptext"><b>Wikipedia source</b><br><br> {support_sentence} <br><br>Similarity: {format_score(score)}</span>'
                ])
                sentences += '</span>'
                sentences += '</span>'
            st.markdown(sentences, unsafe_allow_html=True)

            with st.spinner("Generating audio..."):
                if st.session_state["tts"] == "HuggingFace":
                    audio_file = hf_tts(generated_answer)
                    with open("out.flac", "wb") as f:
                        f.write(audio_file)
                else:
                    audio_file = google_tts(generated_answer, st.secrets["private_key_id"],
                                            st.secrets["private_key"], st.secrets["client_email"])
                    with open("out.mp3", "wb") as f:
                        f.write(audio_file.audio_content)

                audio_file = "out.flac" if st.session_state["tts"] == "HuggingFace" else "out.mp3"
                st.audio(audio_file)

            st.markdown("""<hr></hr>""", unsafe_allow_html=True)

            model = get_sentence_transformer()

            col1, col2 = st.columns(2)

            with col1:
                st.subheader("Context")
            with col2:
                selection = st.selectbox(
                    label="",
                    options=('Paragraphs', 'Sentences', 'Answer Similarity'),
                    help="Context represents Wikipedia passages used to generate the answer")
            question_e = model.encode(question, convert_to_tensor=True)
            if selection == "Paragraphs":
                sentences = extract_sentences_from_passages(context_passages)
                context_e = get_sentence_transformer_encoding(sentences)
                scores = util.cos_sim(question_e.repeat(context_e.shape[0], 1), context_e)
                similarity_scores = scores[0].squeeze().tolist()
                for idx, node in enumerate(context_passages):
                    node["answer_similarity"] = "{0:.2f}".format(similarity_scores[idx])
                # sort numerically, not lexicographically on the formatted string
                context_passages = sorted(context_passages, key=lambda x: float(x["answer_similarity"]), reverse=True)
                st.json(context_passages)
            elif selection == "Sentences":
                sentences = extract_sentences_from_passages(context_passages)
                sentences_e = get_sentence_transformer_encoding(sentences)
                scores = util.cos_sim(question_e.repeat(sentences_e.shape[0], 1), sentences_e)
                sentence_similarity_scores = scores[0].squeeze().tolist()
                result = []
                for idx, sentence in enumerate(sentences):
                    result.append(
                        {"text": sentence, "answer_similarity": "{0:.2f}".format(sentence_similarity_scores[idx])})
                context_sentences = json.dumps(sorted(result, key=lambda x: float(x["answer_similarity"]), reverse=True))
                st.json(context_sentences)
            else:
                st.json(sentence_similarity)
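The three pages in this commit each expose an app() entry point and read shared st.session_state keys, so they are presumably dispatched from a top-level router that is not part of this commit. A minimal sketch of such a router (the file name, labels, and module layout are assumptions):

import streamlit as st

from pages import ask, info, settings  # assumes pages/ is an importable package

# hypothetical entry point, e.g. app.py at the repository root
PAGES = {"Ask": ask, "Info": info, "Settings": settings}
choice = st.sidebar.radio("Navigation", list(PAGES.keys()))
PAGES[choice].app()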
pages/info.py
ADDED
@@ -0,0 +1,92 @@
import streamlit as st


def app():
    with open('style.css') as f:
        st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
    footer = """
        <div class="footer-custom">
            Streamlit app - <a href="https://www.linkedin.com/in/danijel-petkovic-573309144/" target="_blank">Danijel Petkovic</a> |
            LFQA/DPR models - <a href="https://www.linkedin.com/in/blagojevicvladimir/" target="_blank">Vladimir Blagojevic</a> |
            Guidance & Feedback - <a href="https://yjernite.github.io/" target="_blank">Yacine Jernite</a>
        </div>
    """
    st.markdown(footer, unsafe_allow_html=True)

    st.subheader("Intro")
    intro = """
        <div class="text">
            Wikipedia Assistant is an example of a task usually referred to as Long-Form Question Answering (LFQA).
            These systems function by querying large document stores for relevant information and subsequently using
            the retrieved documents to generate accurate, multi-sentence answers. The documents related to a given
            query, colloquially called context passages, are not used merely as source tokens for extracted answers,
            but instead provide a larger context for the synthesis of original, abstractive long-form answers.
            LFQA systems usually consist of three components:
            <ul>
                <li>A document store holding content passages for a variety of topics</li>
                <li>Encoder models that encode documents and questions so that the document store can be queried</li>
                <li>A Seq2Seq language model capable of generating paragraph-long answers when given a question and
                context passages retrieved from the document store</li>
            </ul>
        </div>
        <br>
    """
    st.markdown(intro, unsafe_allow_html=True)
    st.image("lfqa.png", caption="LFQA Architecture")
    st.subheader("UI/UX")
    st.write("Each sentence in the generated answer ends with a coloured tooltip; the colour ranges from red to green. "
             "The tooltip contains a value representing answer sentence similarity to a specific sentence in the "
             "retrieved Wikipedia context passages. Mousing over the tooltip will show the sentence from the "
             "Wikipedia context passage. If a sentence similarity is 1.0, the seq2seq model extracted and "
             "copied the sentence verbatim from the Wikipedia context passages. Lower values of sentence "
             "similarity indicate the seq2seq model is struggling to generate a relevant sentence for the question "
             "asked.")
    st.image("wikipedia_answer.png", caption="Answer with similarity tooltips")
    st.write("Below the generated answer are the question-related Wikipedia context paragraphs (passages). The "
             "'Paragraphs' select menu option shows these passages in the raw format in which they were retrieved. "
             "The 'Sentences' menu option shows the same paragraphs at sentence level. Finally, the 'Answer "
             "Similarity' menu option shows the three sentences from the context paragraphs most similar to each "
             "sentence in the generated answer.")
    st.image("wikipedia_context.png", caption="Context paragraphs (passages)")

    tts = """
        <div class="text">
            Wikipedia Assistant converts the text-based answer to speech via either the Google text-to-speech engine or an
            <a href="https://github.com/espnet" target="_blank">ESPnet model</a> hosted on the
            <a href="https://huggingface.co/espnet/kan-bayashi_ljspeech_joint_finetune_conformer_fastspeech2_hifigan" target="_blank">
            HuggingFace hub</a>.
            <br>
            <br>
    """
    st.markdown(tts, unsafe_allow_html=True)

    st.subheader("Tips")
    tips = """
        <div class="text">
            The LFQA task is far from solved. Wikipedia Assistant will sometimes generate an answer unrelated to the
            question asked, or even a downright wrong one. However, if the question is elaborate and specific, there
            is a decent chance of getting a legible answer. LFQA systems target ELI5-style non-factoid questions. A
            general guideline is that questions starting with why, what, and how are better suited than where and who
            questions. Be elaborate.
            <br><br>
            For example, for a history-based question, Wikipedia Assistant is better suited to answer
            "What was the objective of the German commando raid on Drvar in Bosnia during the Second World War?" than
            "Why did Germans raid Drvar?". A precise science question like "Why do airplane jet engines leave contrails
            in the sky?" has a good chance of getting a decent answer. Detailed and precise questions are more likely to
            match the half-dozen most relevant passages in a 20+ GB Wikipedia dump needed to construct a good answer.
        </div>
        <br>
    """
    st.markdown(tips, unsafe_allow_html=True)
    st.subheader("Technical details")
    technical_intro = """
        <div class="text technical-details-info">
            A question asked will be encoded with an <a href="https://huggingface.co/vblagoje/dpr-question_encoder-single-lfqa-wiki" target="_blank">encoder</a>
            and sent to a server to find the most relevant Wikipedia passages. The Wikipedia <a href="https://huggingface.co/datasets/kilt_wikipedia" target="_blank">passages</a>
            were previously encoded using a passage <a href="https://huggingface.co/vblagoje/dpr-ctx_encoder-single-lfqa-wiki" target="_blank">encoder</a> and
            stored in a <a href="https://github.com/facebookresearch/faiss" target="_blank">Faiss</a> index. The passages matching the question (a.k.a. context passages) are retrieved from the Faiss
            index and passed to a BART-based seq2seq <a href="https://huggingface.co/vblagoje/bart_lfqa" target="_blank">model</a> to
            synthesize an original answer to the question.
        </div>
    """
    st.markdown(technical_intro, unsafe_allow_html=True)
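The retrieval step described under "Technical details" runs server-side and is not part of this commit; a minimal sketch of what it entails, assuming the published DPR checkpoints and a prebuilt Faiss index over the encoded Wikipedia passages:

import faiss
from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer

MODEL = "vblagoje/dpr-question_encoder-single-lfqa-wiki"
tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(MODEL)
encoder = DPRQuestionEncoder.from_pretrained(MODEL)

def top_k_passages(question: str, index: faiss.Index, k: int = 6):
    # encode the question into the same vector space as the stored passages
    inputs = tokenizer(question, return_tensors="pt")
    question_vector = encoder(**inputs).pooler_output.detach().numpy()
    # nearest-neighbour search returns the ids of the best-matching passages
    _, ids = index.search(question_vector, k)
    return ids[0]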
pages/settings.py
ADDED
@@ -0,0 +1,95 @@
import streamlit as st

settings = {}


def app():
    st.markdown("""
        <style>
        div[data-testid="stForm"] {
            border: 0;
        }
        .footer-custom {
            position: fixed;
            bottom: 0;
            width: 100%;
            color: var(--text-color);
            max-width: 698px;
            font-size: 14px;
            height: 50px;
            padding: 10px 0;
            z-index: 50;
        }
        footer {
            display: none !important;
        }
        .footer-custom a {
            color: var(--text-color);
        }
        button[kind="formSubmit"] {
            margin-top: 40px;
            border-radius: 20px;
            padding: 5px 20px;
            font-size: 18px;
            background-color: var(--primary-color);
        }
        #lfqa-model-parameters {
            margin-bottom: 50px;
            font-size: 36px;
        }
        #tts-model-parameters {
            font-size: 36px;
            margin-top: 50px;
        }
        .stAlert {
            width: 250px;
            margin-top: 32px;
        }
        </style>
    """, unsafe_allow_html=True)

    with st.form("settings"):
        footer = """
            <div class="footer-custom">
                Streamlit app - <a href="https://www.linkedin.com/in/danijel-petkovic-573309144/" target="_blank">Danijel Petkovic</a> |
                LFQA/DPR models - <a href="https://www.linkedin.com/in/blagojevicvladimir/" target="_blank">Vladimir Blagojevic</a> |
                Guidance & Feedback - <a href="https://yjernite.github.io/" target="_blank">Yacine Jernite</a>
            </div>
        """
        st.markdown(footer, unsafe_allow_html=True)

        st.title("LFQA model parameters")

        settings["min_length"] = st.slider("Min length", 20, 80, st.session_state["min_length"],
                                           help="Min response length (words)")
        st.markdown("""<hr></hr>""", unsafe_allow_html=True)
        settings["max_length"] = st.slider("Max length", 128, 320, st.session_state["max_length"],
                                           help="Max response length (words)")
        st.markdown("""<hr></hr>""", unsafe_allow_html=True)
        col1, col2 = st.columns(2)
        with col1:
            settings["do_sample"] = st.checkbox("Use sampling", st.session_state["do_sample"],
                                                help="Whether or not to use sampling; use greedy decoding otherwise.")
        with col2:
            settings["early_stopping"] = st.checkbox("Early stopping", st.session_state["early_stopping"],
                                                     help="Whether to stop the beam search when at least num_beams sentences are finished per batch or not.")
        st.markdown("""<hr></hr>""", unsafe_allow_html=True)
        settings["num_beams"] = st.slider("Num beams", 1, 16, st.session_state["num_beams"],
                                          help="Number of beams for beam search. 1 means no beam search.")
        st.markdown("""<hr></hr>""", unsafe_allow_html=True)
        settings["temperature"] = st.slider("Temperature", 0.0, 1.0, st.session_state["temperature"], step=0.1,
                                            help="The value used to modulate the next token probabilities")

        st.title("TTS model parameters")
        settings["tts"] = st.selectbox(label="Engine", options=("Google", "HuggingFace"),
                                       index=["Google", "HuggingFace"].index(st.session_state["tts"]),
                                       help="Answer text-to-speech engine")

        # Every form must have a submit button; the extra columns keep it narrow.
        col3, col4, col5, col6 = st.columns(4)
        with col3:
            submitted = st.form_submit_button("Save")
        with col4:
            if submitted:
                for k, v in settings.items():
                    st.session_state[k] = v
                st.success('App settings saved successfully.')
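Every widget on this page reads its current value from st.session_state, so those keys must be seeded before the page first renders. A sketch of the seeding the (not included) entry point would need; the concrete default values here are assumptions, not part of this commit:

import streamlit as st

defaults = {"min_length": 64, "max_length": 256, "do_sample": True,
            "early_stopping": False, "num_beams": 8, "temperature": 1.0,
            "tts": "Google", "api_lfqa_selector": "HuggingFace"}
for key, value in defaults.items():
    if key not in st.session_state:
        st.session_state[key] = value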