|
from datasets import load_dataset |
|
from transformers import ( |
|
DPRQuestionEncoder, |
|
DPRQuestionEncoderTokenizer, |
|
MT5ForConditionalGeneration, |
|
AutoTokenizer, |
|
AutoModelForCTC, |
|
Wav2Vec2Tokenizer, |
|
) |
|
from general_utils import ( |
|
embed_questions, |
|
transcript, |
|
remove_chars_to_tts, |
|
parse_final_answer, |
|
) |
|
from typing import List |
|
import gradio as gr |
|
from article_app import article, description, examples |
|
from haystack.nodes import DensePassageRetriever |
|
from haystack.document_stores import InMemoryDocumentStore |
|
import numpy as np |
|
from sentence_transformers import SentenceTransformer, util, CrossEncoder |
|
|
|
# Retrieval knobs:
#   topk               - nearest passages fetched from the FAISS index per question
#   minchars           - corpus rows whose text has <= this many characters are dropped
#   min_snippet_length - retrieved passages must have more than this many words
topk, minchars, min_snippet_length = 21, 200, 20

# torch device that generation inputs are moved to.
device = "cpu"

# Lower-case substrings that trigger the canned "no COVID questions" reply.
covidterms = [
    "covid19",
    "covid",
    "coronavirus",
    "covid-19",
    "sars-cov-2",
]
|
|
|
# Speech-to-text model used by search_and_answer to transcribe spoken
# questions; processor and CTC model are loaded from the same checkpoint.
models = {
    "wav2vec2-iic": dict(
        processor=Wav2Vec2Tokenizer.from_pretrained(
            "IIC/wav2vec2-spanish-multilibrispeech"
        ),
        model=AutoModelForCTC.from_pretrained(
            "IIC/wav2vec2-spanish-multilibrispeech"
        ),
    ),
}

# Spanish text-to-speech interface, loaded via gr.Interface.load; it is called
# on the (cleaned) answer text in search_and_answer when TTS is requested.
tts_es = gr.Interface.load(
    "huggingface/facebook/tts_transformer-es-css10"
)
|
|
|
|
|
# Default keyword arguments for mt5_lfqa.generate. generate_answer applies the
# per-request settings coming from the UI sliders on top of these.
params_generate = dict(
    min_length=50,
    max_length=250,
    do_sample=False,
    early_stopping=True,
    num_beams=8,
    temperature=1.0,
    top_k=None,
    top_p=None,
    no_repeat_ngram_size=3,
    num_return_sequences=1,
)
|
|
|
# Dense passage retriever. In this file only its question encoder is used
# (query_index calls dpr.embed_queries); the InMemoryDocumentStore is never
# populated here.
dpr = DensePassageRetriever(
    document_store=InMemoryDocumentStore(),
    query_embedding_model="IIC/dpr-spanish-question_encoder-allqa-base",
    passage_embedding_model="IIC/dpr-spanish-passage_encoder-allqa-base",
    max_seq_len_query=64,
    max_seq_len_passage=256,
    batch_size=512,
    # Derived from `device` so that changing that one constant moves every
    # model consistently instead of leaving some pinned to the CPU.
    use_gpu=device.startswith("cuda"),
)

# MT5 long-form QA generator. generate_answer moves its input tensors to
# `device`, so the model must live on the same device.
mt5_tokenizer = AutoTokenizer.from_pretrained("IIC/mt5-base-lfqa-es")
mt5_lfqa = MT5ForConditionalGeneration.from_pretrained("IIC/mt5-base-lfqa-es").to(
    device
)

# Bi-encoder and cross-encoder used together by sort_on_similarity to re-rank
# the retrieved passages.
similarity_model = SentenceTransformer(
    "distiluse-base-multilingual-cased", device=device
)

crossencoder = CrossEncoder("IIC/roberta-base-bne-ranker", device=device)
|
|
|
# Spanish biomedical corpus. Rows whose text has `minchars` characters or fewer
# are removed, then the prebuilt FAISS index is attached under "embeddings".
# NOTE(review): FAISS results map to row positions, so the .faiss file must
# have been built over this same filtered dataset — confirm when rebuilding.
dataset = load_dataset(
    "IIC/spanish_biomedical_crawled_corpus", split="train"
).filter(lambda example: len(example["text"]) > minchars)
dataset.load_faiss_index("embeddings", "dpr_index_bio_newdpr.faiss")
|
|
|
|
|
def query_index(question: str):
    """Retrieve candidate passages for ``question`` from the FAISS index.

    The question is embedded with the DPR question encoder (``dpr``), and the
    ``topk`` nearest rows of ``dataset``'s ``"embeddings"`` index are fetched.

    Args:
        question: The user question.

    Returns:
        The ``"text"`` field of the retrieved rows, in the order returned by
        ``get_nearest_examples``. Only passages with more than
        ``min_snippet_length`` whitespace-separated words are kept.
    """
    question_embedding = dpr.embed_queries([question])[0]
    # Retrieval scores are not used here; search_and_answer re-ranks the
    # passages afterwards with sort_on_similarity.
    _scores, closest_passages = dataset.get_nearest_examples(
        "embeddings", question_embedding, k=topk
    )
    return [
        text
        for text in closest_passages["text"]
        if len(text.split()) > min_snippet_length
    ]
|
|
|
|
|
def sort_on_similarity(question, contexts, include_rank: int = 5):
    """Re-rank retrieved passages and keep the best ``include_rank`` of them.

    Each passage gets a combined score. The score is the cosine similarity
    between the bi-encoder (``similarity_model``) embeddings of the question
    and the passage, multiplied by the cross-encoder (``crossencoder``) score
    for the ``[question, passage]`` pair.

    Args:
        question: The user question.
        contexts: Candidate passages, e.g. the output of ``query_index``.
        include_rank: Maximum number of passages to return.

    Returns:
        Up to ``include_rank`` passages, highest combined score first. Returns
        an empty list when ``contexts`` is empty.
    """
    if not contexts:
        # Nothing to rank; also avoids encoding/predicting on an empty batch.
        return []

    question_encoded = similarity_model.encode([question])[0]
    ctxs_encoded = similarity_model.encode(contexts)
    # A single batched cos_sim call yields a (1, n) matrix; flatten it to one
    # score per passage. Calling cos_sim once per passage produces (1, 1)
    # tensors that stack to shape (n, 1, 1). That shape broadcasts against the
    # (n,) cross-encoder scores into (n, 1, n) and corrupts the ranking.
    bi_scores = util.cos_sim(question_encoded, ctxs_encoded).reshape(-1).cpu().numpy()

    # NOTE(review): assumes the ranker has a single output label, so predict
    # returns one score per pair (shape (n,)) — TODO confirm for the model.
    cross_scores = np.asarray(
        crossencoder.predict([[question, ctx] for ctx in contexts])
    )

    combined_scores = bi_scores * cross_scores
    # argsort is ascending; reverse it for best-first order, then truncate.
    ranking = np.argsort(combined_scores)[::-1][:include_rank]
    return [contexts[int(idx)] for idx in ranking]
|
|
|
|
|
def create_context(contexts: List):
    """Join passages into one context string, each one prefixed by ``<p>``.

    An empty ``contexts`` still yields the single marker ``"<p>"``.
    """
    return f"<p>{'<p>'.join(contexts)}"
|
|
|
|
|
def create_model_input(question: str, context: str):
    """Build the MT5 prompt in the ``question: ... context: ...`` layout."""
    template = "question: {} context: {}"
    return template.format(question, context)
|
|
|
|
|
def generate_answer(model_input, update_params=None):
    """Generate answer text with the MT5 long-form QA model.

    Args:
        model_input: Prompt string (as built by ``create_model_input``) or a
            list of prompts; truncated to 1024 tokens.
        update_params: Optional mapping of ``generate`` keyword arguments. It
            overrides the module-level ``params_generate`` defaults for this
            call only.

    Returns:
        A list of ``{"generated_text": str}`` dicts, one per decoded sequence.
    """
    encoded = mt5_tokenizer(
        model_input, truncation=True, padding=True, return_tensors="pt", max_length=1024
    )
    # Build a fresh dict rather than calling params_generate.update(). Mutating
    # the shared defaults would leak one request's settings into later calls.
    generation_kwargs = {**params_generate, **(update_params or {})}
    answers_encoded = mt5_lfqa.generate(
        input_ids=encoded["input_ids"].to(device),
        attention_mask=encoded["attention_mask"].to(device),
        **generation_kwargs,
    )
    answers = mt5_tokenizer.batch_decode(
        answers_encoded, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )
    return [{"generated_text": answer} for answer in answers]
|
|
|
|
|
def search_and_answer(
    question,
    audio_file,
    audio_array,
    min_length_answer,
    num_beams,
    no_repeat_ngram_size,
    temperature,
    max_answer_length,
    do_tts,
):
    """Answer a biomedical question end to end (the Gradio callback).

    When ``question`` is empty, it is transcribed from ``audio_file`` /
    ``audio_array`` with the wav2vec2 model. Questions containing a COVID term
    get a canned reply. Otherwise, passages are retrieved (``query_index``),
    re-ranked (``sort_on_similarity``) and given to the MT5 generator.

    Args:
        question: Typed question; may be empty when audio is supplied instead.
        audio_file: Uploaded audio (Gradio ``filepath`` input).
        audio_array: Microphone recording (Gradio ``numpy`` input).
        min_length_answer: Minimum generated length (``min_length``).
        num_beams: Beam count for generation.
        no_repeat_ngram_size: N-gram size that may not repeat in the answer.
        temperature: Generation temperature.
        max_answer_length: Maximum generated length (``max_length``).
        do_tts: Whether to synthesize the answer as speech.

    Returns:
        ``(answer, documents, audio)`` for the three Gradio outputs. ``audio``
        is the synthesized answer when ``do_tts`` is set; otherwise it is
        ``"audio_troll.flac"``.
    """
    # transformers' logits processors reject non-int min_length /
    # no_repeat_ngram_size values. Slider values may arrive as floats, so every
    # integer-valued setting is cast, not just num_beams.
    update_params = {
        "min_length": int(min_length_answer),
        "max_length": int(max_answer_length),
        "num_beams": int(num_beams),
        "temperature": temperature,
        "no_repeat_ngram_size": int(no_repeat_ngram_size),
    }

    if not question:
        s2t = models["wav2vec2-iic"]
        question = transcript(
            audio_file, audio_array, processor=s2t["processor"], model=s2t["model"]
        )
        print(f"Transcripted question: *** {question} ****")

    # No COVID term contains a space, so a substring test on the lower-cased
    # question matches exactly when some space-separated word contains a term.
    lowered_question = question.lower()
    if any(term in lowered_question for term in covidterms):
        return (
            "Del COVID no queremos saber ya más nada, lo sentimos, pregúntame sobre otra cosa :P ",
            "ni contexto ni contexta.",
            "audio_troll.flac",
        )

    contexts = sort_on_similarity(question, query_index(question))
    model_input = create_model_input(question, create_context(contexts))
    final_answer = generate_answer(model_input, update_params)[0]["generated_text"]

    audio_answer = "audio_troll.flac"
    if do_tts:
        audio_answer = tts_es(remove_chars_to_tts(final_answer))

    final_answer, documents = parse_final_answer(final_answer, contexts)
    return final_answer, documents, audio_answer
|
|
|
|
|
if __name__ == "__main__":
    # Input order must match search_and_answer's parameters:
    # question, audio_file, audio_array, min_length_answer, num_beams,
    # no_repeat_ngram_size, temperature, max_answer_length, do_tts.
    gr.Interface(
        search_and_answer,
        inputs=[
            gr.inputs.Textbox(
                lines=2,
                label="Pregúntame sobre BioMedicina o temas relacionados. Puedes simplemente preguntarme aquí y darle al botón verde de abajo que pone Enviar.",
                placeholder="Escribe aquí tu pregunta",
                optional=True,
            ),
            gr.inputs.Audio(
                source="upload",
                type="filepath",
                # This input carries the user's *question* (it is transcribed
                # when the textbox is empty), not an answer.
                label="Sube un audio con tu pregunta aquí si quieres.",
                optional=True,
            ),
            gr.inputs.Audio(
                source="microphone",
                type="numpy",
                label="Graba aquí un audio con tu pregunta.",
                optional=True,
            ),
            gr.inputs.Slider(
                minimum=10,
                maximum=200,
                default=50,
                label="Minimum size for the answer",
                step=1,
            ),
            gr.inputs.Slider(
                minimum=4, maximum=12, default=8, label="number of beams", step=1
            ),
            gr.inputs.Slider(
                minimum=2, maximum=5, default=3, label="no repeat n-gram size", step=1
            ),
            gr.inputs.Slider(
                minimum=0.8, maximum=2.0, default=1.0, label="temperature", step=0.1
            ),
            gr.inputs.Slider(
                minimum=220,
                maximum=360,
                default=250,
                label="maximum answer length",
                step=1,
            ),
            gr.inputs.Checkbox(default=False, label="Text to Speech", optional=True),
        ],
        outputs=[
            gr.outputs.HTML(label="Respuesta generada."),
            gr.outputs.HTML(label="Documentos utilizados."),
            gr.outputs.Audio(label="Respuesta en audio."),
        ],
        description=description,
        examples=examples,
        theme="grass",
        article=article,
        thumbnail="IIC_logoP.png",
        css="https://cdn.jsdelivr.net/npm/bootstrap@3.3.7/dist/css/bootstrap.min.css",
    ).launch()
|
|