"""Wikipedia Assistant: a Streamlit app for long-form question answering (LFQA)."""

import colorsys
import json
import re
import time
from typing import Any, Dict, List, Optional

import jwt
import nltk
import numpy as np
import requests
import streamlit as st
from google.cloud import texttospeech
from google.oauth2 import service_account
from nltk import tokenize
from sentence_transformers import CrossEncoder, SentenceTransformer, util

nltk.download('punkt')

JWT_SECRET = st.secrets["api_secret"]
JWT_ALGORITHM = st.secrets["api_algorithm"]
INFERENCE_TOKEN = st.secrets["api_inference"]
CONTEXT_API_URL = st.secrets["api_context"]
LFQA_API_URL = st.secrets["api_lfqa"]

headers = {"Authorization": f"Bearer {INFERENCE_TOKEN}"}
API_URL = "https://api-inference.huggingface.co/models/vblagoje/bart_lfqa"
API_URL_TTS = "https://api-inference.huggingface.co/models/espnet/kan-bayashi_ljspeech_joint_finetune_conformer_fastspeech2_hifigan"
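
# Google TTS credentials are read from st.secrets ("private_key_id",
# "private_key", "client_email") in google_tts() below. Generation settings
# ("min_length", "max_length", "do_sample", "early_stopping", "num_beams",
# "temperature") and the backend/TTS selectors ("api_lfqa_selector", "tts")
# come from st.session_state and are assumed to be populated elsewhere in the
# app (e.g. a settings sidebar not shown in this file).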


def api_inference_lfqa(model_input: str):
    """Generate an answer by calling the hosted Hugging Face inference API."""
    payload = {
        "inputs": model_input,
        "parameters": {
            "truncation": "longest_first",
            "min_length": st.session_state["min_length"],
            "max_length": st.session_state["max_length"],
            "do_sample": st.session_state["do_sample"],
            "early_stopping": st.session_state["early_stopping"],
            "num_beams": st.session_state["num_beams"],
            "temperature": st.session_state["temperature"],
            "top_k": None,
            "top_p": None,
            "no_repeat_ngram_size": 3,
            "num_return_sequences": 1
        },
        "options": {
            "wait_for_model": True
        }
    }
    data = json.dumps(payload)
    response = requests.post(API_URL, headers=headers, data=data)
    return json.loads(response.content.decode("utf-8"))
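
# "wait_for_model" asks the inference API to block until the model is loaded
# instead of returning a 503 while it spins up, so cold starts surface as
# extra latency rather than errors.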


def inference_lfqa(model_input: str, header: dict):
    """Generate an answer by calling the self-hosted LFQA service."""
    payload = {
        "model_input": model_input,
        "parameters": {
            "min_length": st.session_state["min_length"],
            "max_length": st.session_state["max_length"],
            "do_sample": st.session_state["do_sample"],
            "early_stopping": st.session_state["early_stopping"],
            "num_beams": st.session_state["num_beams"],
            "temperature": st.session_state["temperature"],
            "top_k": None,
            "top_p": None,
            "no_repeat_ngram_size": 3,
            "num_return_sequences": 1
        }
    }
    data = json.dumps(payload)
    try:
        response = requests.post(LFQA_API_URL, headers=header, data=data)
        if response.status_code == 200:
            result = json.loads(response.content.decode("utf-8"))
        else:
            result = {"error": f"LFQA service unavailable, status code={response.status_code}"}
    except requests.exceptions.RequestException as e:
        result = {"error": str(e)}
    return result


def invoke_lfqa(service_backend: str, model_input: str, header: Optional[dict]):
    """Dispatch answer generation to the selected backend."""
    if service_backend == "HuggingFace":
        inference_response = api_inference_lfqa(model_input)
    else:
        inference_response = inference_lfqa(model_input, header)
    return inference_response


@st.cache(allow_output_mutation=True, show_spinner=False)
def hf_tts(text: str):
    """Synthesize speech via the hosted ESPnet model; returns raw audio bytes
    (saved as FLAC by the caller)."""
    payload = {
        "inputs": text,
        "parameters": {
            "vocoder_tag": "str_or_none(none)",
            "threshold": 0.5,
            "minlenratio": 0.0,
            "maxlenratio": 10.0,
            "use_att_constraint": False,
            "backward_window": 1,
            "forward_window": 3,
            "speed_control_alpha": 1.0,
            "noise_scale": 0.333,
            "noise_scale_dur": 0.333
        },
        "options": {
            "wait_for_model": True
        }
    }
    data = json.dumps(payload)
    response = requests.post(API_URL_TTS, headers=headers, data=data)
    return response.content


@st.cache(allow_output_mutation=True, show_spinner=False)
def google_tts(text: str, private_key_id: str, private_key: str, client_email: str):
    """Synthesize speech with Google Cloud Text-to-Speech (MP3).

    The "private_key" secret is expected without its PEM header/footer,
    which are wrapped around it here.
    """
    config = {
        "private_key_id": private_key_id,
        "private_key": f"-----BEGIN PRIVATE KEY-----\n{private_key}\n-----END PRIVATE KEY-----\n",
        "client_email": client_email,
        "token_uri": "https://oauth2.googleapis.com/token",
    }
    credentials = service_account.Credentials.from_service_account_info(config)
    client = texttospeech.TextToSpeechClient(credentials=credentials)

    synthesis_input = texttospeech.SynthesisInput(text=text)
    voice = texttospeech.VoiceSelectionParams(language_code="en-US",
                                              ssml_gender=texttospeech.SsmlVoiceGender.NEUTRAL)
    audio_config = texttospeech.AudioConfig(audio_encoding=texttospeech.AudioEncoding.MP3)

    response = client.synthesize_speech(input=synthesis_input, voice=voice, audio_config=audio_config)
    return response


def request_context_passages(question, header):
    """Retrieve supporting Wikipedia passages for the question from the context service."""
    try:
        response = requests.get(CONTEXT_API_URL + question, headers=header)
        if response.status_code == 200:
            result = json.loads(response.content.decode("utf-8"))
        else:
            result = {"error": f"Context passage service unavailable, status code={response.status_code}"}
    except requests.exceptions.RequestException as e:
        result = {"error": str(e)}

    return result


@st.cache(allow_output_mutation=True, show_spinner=False)
def get_sentence_transformer():
    return SentenceTransformer('all-MiniLM-L6-v2')


@st.cache(allow_output_mutation=True, show_spinner=False)
def get_sentence_transformer_encoding(sentences):
    model = get_sentence_transformer()
    return model.encode(sentences, convert_to_tensor=True)


def sign_jwt() -> str:
    """Create a short-lived bearer token for the backend services."""
    payload = {
        "expires": time.time() + 6000
    }
    token = jwt.encode(payload, JWT_SECRET, algorithm=JWT_ALGORITHM)
    return token
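
# Note: "expires" is a custom claim. PyJWT only auto-verifies the registered
# "exp" claim on decode, so the receiving service is assumed to validate this
# field itself.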


def extract_sentences_from_passages(passages):
    sentences = []
    for passage in passages:
        sentences.extend(tokenize.sent_tokenize(passage["text"]))
    return sentences


def similarity_color_picker(similarity: float):
    value = int(similarity * 75)
    rgb = colorsys.hsv_to_rgb(value / 300., 1.0, 1.0)
    return [round(255 * x) for x in rgb]


def rgb_to_hex(rgb):
    return '%02x%02x%02x' % tuple(rgb)


def similarity_to_hex(similarity: float):
    return rgb_to_hex(similarity_color_picker(similarity))
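
# The hue sweep maps low similarity to red and high similarity to green,
# e.g. similarity_to_hex(0.0) == "ff0000" and similarity_to_hex(1.0) == "80ff00".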


def rerank(question: str, passages: List[dict], include_rank: int = 4) -> List[dict]:
    """Rerank passages with a cross-encoder and keep the top include_rank."""
    ce = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
    question_passage_combinations = [[question, p["text"]] for p in passages]

    # Compute a relevance score for each (question, passage) pair
    similarity_scores = ce.predict(question_passage_combinations)

    # Sort indices by score, descending, and keep the best passages
    sim_ranking_idx = np.flip(np.argsort(similarity_scores))
    return [passages[rank_idx] for rank_idx in sim_ranking_idx[:include_rank]]
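
# Note: the cross-encoder is instantiated on every call; if latency matters it
# could be cached (e.g. with @st.cache, like get_sentence_transformer above).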


def answer_to_context_similarity(generated_answer, context_passages, topk=3):
    context_sentences = extract_sentences_from_passages(context_passages)
    context_sentences_e = get_sentence_transformer_encoding(context_sentences)
    answer_sentences = tokenize.sent_tokenize(generated_answer)
    answer_sentences_e = get_sentence_transformer_encoding(answer_sentences)
    search_result = util.semantic_search(answer_sentences_e, context_sentences_e, top_k=topk)
    result = []
    for idx, r in enumerate(search_result):
        context = []
        for idx_c in range(topk):
            context.append({"source": context_sentences[r[idx_c]["corpus_id"]], "score": r[idx_c]["score"]})
        result.append({"answer": answer_sentences[idx], "context": context})
    return result
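
# Returns one entry per answer sentence, each with its top-k supporting
# context sentences, e.g. (scores illustrative):
# [{"answer": "...", "context": [{"source": "...", "score": 0.87}, ...]}, ...]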


def post_process_answer(generated_answer):
    result = generated_answer

    # A complete sentence: starts capitalized and ends in punctuation followed
    # by either the end of the answer or the start of the next sentence
    regex = r"([A-Z][a-z].*?[.:!?](?=$| [A-Z]))"
    answer_sentences = tokenize.sent_tokenize(generated_answer)

    # If generation was cut off mid-sentence, drop the trailing fragment
    if len(answer_sentences) > len(re.findall(regex, generated_answer)):
        result = " ".join(answer_sentences[:-1])
    return result.strip()
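
# e.g. post_process_answer("The sky is blue. Because the") drops the dangling
# fragment and returns "The sky is blue."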


def format_score(value: float, precision=2):
    return f"{value:.{precision}f}"


@st.cache(allow_output_mutation=True, show_spinner=False)
def get_answer(question: str):
    """Full pipeline: retrieve context passages, rerank them, generate an answer,
    and post-process it."""
    if not question:
        return {}

    resp: Dict[str, Any] = {}
    if len(question.split()) > 3:
        header = {"Authorization": f"Bearer {sign_jwt()}"}
        context_passages = request_context_passages(question, header)
        if "error" in context_passages:
            resp = context_passages
        else:
            context_passages = rerank(question, context_passages)
            conditioned_context = "<P> " + " <P> ".join([d["text"] for d in context_passages])
            model_input = f'question: {question} context: {conditioned_context}'

            inference_response = invoke_lfqa(st.session_state["api_lfqa_selector"], model_input, header)
            if "error" in inference_response:
                resp = inference_response
            else:
                resp["context_passages"] = context_passages
                resp["answer"] = post_process_answer(inference_response[0]["generated_text"])
    else:
        resp = {"error": f"A longer, more descriptive question will receive a better answer. '{question}' is too short."}
    return resp


def app():
    with open('style.css') as f:
        st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
    footer = """
        <div class="footer-custom">
            Streamlit app - <a href="https://www.linkedin.com/in/danijel-petkovic-573309144/" target="_blank">Danijel Petkovic</a> |
            LFQA/DPR models - <a href="https://www.linkedin.com/in/blagojevicvladimir/" target="_blank">Vladimir Blagojevic</a> |
            Guidance & Feedback - <a href="https://yjernite.github.io/" target="_blank">Yacine Jernite</a> |
            <a href="https://towardsdatascience.com/long-form-qa-beyond-eli5-an-updated-dataset-and-approach-319cb841aabb" target="_blank">Blog</a>
        </div>
    """
    st.markdown(footer, unsafe_allow_html=True)

    st.title('Wikipedia Assistant')
    st.header('We are migrating to new backend infrastructure. ETA - 15.6.2022')

    # Question input is disabled during the backend migration announced above;
    # with question fixed to an empty string, the answer pipeline below is skipped.
    question = ""
    spinner = st.empty()
    if question != "":
        spinner.markdown(
            f"""
            <div class="loader-wrapper">
            <div class="loader">
            </div>
            <p>Generating answer for: <b>{question}</b></p>
            </div>
            <label class="loader-note">Answer generation may take up to 20 sec. Please stand by.</label>
            """,
            unsafe_allow_html=True,
        )

        question_response = get_answer(question)
        if question_response:
            if "error" in question_response:
                st.warning(question_response["error"])
            else:
                spinner.markdown("")
                generated_answer = question_response["answer"]
                context_passages = question_response["context_passages"]
                sentence_similarity = answer_to_context_similarity(generated_answer, context_passages, topk=3)
                sentences = "<div class='sentence-wrapper'>"
                for item in sentence_similarity:
                    sentences += '<span>'
                    score = item["context"][0]["score"]
                    support_sentence = item["context"][0]["source"]
                    sentences += "".join([
                        f' {item["answer"]}',
                        f'<span style="background-color: #{similarity_to_hex(score)}" class="tooltip">',
                        f'{format_score(score, precision=1)}',
                        f'<span class="tooltiptext"><b>Wikipedia source</b><br><br> {support_sentence} <br><br>Similarity: {format_score(score)}</span>'
                    ])
                    sentences += '</span>'  # close the tooltip span
                    sentences += '</span>'  # close the answer-sentence span
                sentences += '</div>'  # close the sentence wrapper
                st.markdown(sentences, unsafe_allow_html=True)

                with st.spinner("Generating audio..."):
                    if st.session_state["tts"] == "HuggingFace":
                        audio_file = hf_tts(generated_answer)
                        with open("out.flac", "wb") as f:
                            f.write(audio_file)
                    else:
                        audio_file = google_tts(generated_answer, st.secrets["private_key_id"],
                                                st.secrets["private_key"], st.secrets["client_email"])
                        with open("out.mp3", "wb") as f:
                            f.write(audio_file.audio_content)

                audio_file = "out.flac" if st.session_state["tts"] == "HuggingFace" else "out.mp3"
                st.audio(audio_file)

                st.markdown("""<hr></hr>""", unsafe_allow_html=True)

                model = get_sentence_transformer()

                col1, col2 = st.columns(2)

                with col1:
                    st.subheader("Context")
                with col2:
                    selection = st.selectbox(
                        label="",
                        options=('Paragraphs', 'Sentences', 'Answer Similarity'),
                        help="Context represents Wikipedia passages used to generate the answer")
                question_e = model.encode(question, convert_to_tensor=True)
                if selection == "Paragraphs":
                    # Score each passage as a whole (not its individual sentences)
                    # against the question
                    passage_texts = [passage["text"] for passage in context_passages]
                    context_e = get_sentence_transformer_encoding(passage_texts)
                    scores = util.cos_sim(question_e, context_e)
                    similarity_scores = scores[0].squeeze().tolist()
                    for idx, node in enumerate(context_passages):
                        node["answer_similarity"] = "{0:.2f}".format(similarity_scores[idx])
                    context_passages = sorted(context_passages,
                                              key=lambda x: float(x["answer_similarity"]), reverse=True)
                    st.json(context_passages)
                elif selection == "Sentences":
                    sentences = extract_sentences_from_passages(context_passages)
                    sentences_e = get_sentence_transformer_encoding(sentences)
                    scores = util.cos_sim(question_e, sentences_e)
                    sentence_similarity_scores = scores[0].squeeze().tolist()
                    result = []
                    for idx, sentence in enumerate(sentences):
                        result.append(
                            {"text": sentence, "answer_similarity": "{0:.2f}".format(sentence_similarity_scores[idx])})
                    context_sentences = sorted(result, key=lambda x: float(x["answer_similarity"]), reverse=True)
                    st.json(context_sentences)
                else:
                    st.json(sentence_similarity)
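

# Assumed entry point: under `streamlit run` the script body executes top to
# bottom, so the UI is invoked here.
app()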