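# Spanish biomedical long-form question answering demo.
# Pipeline: (optional) speech-to-text with wav2vec2 -> dense passage retrieval over a
# biomedical corpus -> passage re-ranking -> answer generation with mT5 -> optional TTS.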
from datasets import load_dataset
from transformers import (
    DPRQuestionEncoder,
    DPRQuestionEncoderTokenizer,
    MT5ForConditionalGeneration,
    AutoTokenizer,
    AutoModelForCTC,
    Wav2Vec2Tokenizer,
)
from general_utils import (
    embed_questions,
    transcript,
    remove_chars_to_tts,
    parse_final_answer,
)
from typing import List
import gradio as gr
from article_app import article, description, examples
from haystack.nodes import DensePassageRetriever
from haystack.document_stores import InMemoryDocumentStore
import numpy as np
from sentence_transformers import SentenceTransformer, util, CrossEncoder
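# Retrieval and filtering settings: passages fetched per query, minimum document
# length kept from the corpus, and minimum snippet length (in words) for a usable context.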
topk = 21
minchars = 200
min_snippet_length = 20
device = "cpu"
covidterms = ["covid19", "covid", "coronavirus", "covid-19", "sars-cov-2"]
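# Spanish speech-to-text model (wav2vec2 CTC) used to transcribe spoken questions.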
models = {
    "wav2vec2-iic": {
        "processor": Wav2Vec2Tokenizer.from_pretrained(
            "IIC/wav2vec2-spanish-multilibrispeech"
        ),
        "model": AutoModelForCTC.from_pretrained(
            "IIC/wav2vec2-spanish-multilibrispeech"
        ),
    },
}
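# Spanish text-to-speech model, loaded as a Gradio Interface from the Hugging Face Hub.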
tts_es = gr.Interface.load("huggingface/facebook/tts_transformer-es-css10")
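# Default generation parameters for the mT5 answer generator; several of these are
# overridden per request from the UI sliders (see search_and_answer).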
params_generate = {
    "min_length": 50,
    "max_length": 250,
    "do_sample": False,
    "early_stopping": True,
    "num_beams": 8,
    "temperature": 1.0,
    "top_k": None,
    "top_p": None,
    "no_repeat_ngram_size": 3,
    "num_return_sequences": 1,
}
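# Spanish DPR question/passage encoders wrapped in a Haystack retriever; used here to
# embed incoming questions for the FAISS nearest-neighbour search below.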
dpr = DensePassageRetriever(
    document_store=InMemoryDocumentStore(),
    query_embedding_model="IIC/dpr-spanish-question_encoder-allqa-base",
    passage_embedding_model="IIC/dpr-spanish-passage_encoder-allqa-base",
    max_seq_len_query=64,
    max_seq_len_passage=256,
    batch_size=512,
    use_gpu=False,
)
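# Answer generator (mT5 fine-tuned for Spanish LFQA) plus a bi-encoder and a
# cross-encoder used to re-rank the retrieved passages.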
mt5_tokenizer = AutoTokenizer.from_pretrained("IIC/mt5-base-lfqa-es")
mt5_lfqa = MT5ForConditionalGeneration.from_pretrained("IIC/mt5-base-lfqa-es")
similarity_model = SentenceTransformer(
    "distiluse-base-multilingual-cased", device="cpu"
)
crossencoder = CrossEncoder("IIC/roberta-base-bne-ranker", device="cpu")
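# Spanish biomedical corpus with a prebuilt FAISS index over its passage embeddings
# (presumably computed with the DPR passage encoder above, given the index file name).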
dataset = load_dataset("IIC/spanish_biomedical_crawled_corpus", split="train")
dataset = dataset.filter(lambda example: len(example["text"]) > minchars)
dataset.load_faiss_index(
    "embeddings",
    "dpr_index_bio_newdpr.faiss",
)
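# Retrieve the top-k passages for a question from the FAISS index and drop snippets
# that are too short to be useful as context.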
def query_index(question: str):
    question_embedding = dpr.embed_queries([question])[0]
    scores, closest_passages = dataset.get_nearest_examples(
        "embeddings", question_embedding, k=topk
    )
    contexts = [
        closest_passages["text"][i] for i in range(len(closest_passages["text"]))
    ]  # [:int(topk / 3)]
    return [
        context for context in contexts if len(context.split()) > min_snippet_length
    ]
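# Re-rank the retrieved passages by combining bi-encoder cosine similarity with
# cross-encoder relevance scores, keeping only the top `include_rank` passages.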
def sort_on_similarity(question, contexts, include_rank: int = 5):
    question_encoded = similarity_model.encode([question])[0]
    ctxs_encoded = similarity_model.encode(contexts)
    # util.cos_sim returns a (1, 1) tensor, so flatten each score to a plain float to
    # get a (n_contexts,) vector that lines up with the cross-encoder scores.
    sim_scores_ss = np.array(
        [float(util.cos_sim(question_encoded, ctx_encoded)) for ctx_encoded in ctxs_encoded]
    )
    text_pairs = [[question, ctx] for ctx in contexts]
    similarity_scores = crossencoder.predict(text_pairs)
    similarity_scores = sim_scores_ss * similarity_scores
    similarity_ranking_idx = np.flip(np.argsort(similarity_scores))
    return [contexts[idx] for idx in similarity_ranking_idx][:include_rank]
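# Helpers to turn the selected passages into the "question: ... context: ..." input
# format expected by the mT5 LFQA model.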
def create_context(contexts: List):
    return "<p>" + "<p>".join(contexts)


def create_model_input(question: str, context: str):
    return f"question: {question} context: {context}"
def generate_answer(model_input, update_params):
    model_input = mt5_tokenizer(
        model_input, truncation=True, padding=True, return_tensors="pt", max_length=1024
    )
    params_generate.update(update_params)
    answers_encoded = mt5_lfqa.generate(
        input_ids=model_input["input_ids"].to(device),
        attention_mask=model_input["attention_mask"].to(device),
        **params_generate,
    )
    answers = mt5_tokenizer.batch_decode(
        answers_encoded, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )
    results = [{"generated_text": answer} for answer in answers]
    return results
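# End-to-end pipeline behind the Gradio interface: transcribe the question if it was
# given as audio, retrieve and re-rank passages, generate the answer and, optionally,
# synthesize it as speech.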
def search_and_answer(
    question,
    audio_file,
    audio_array,
    min_length_answer,
    num_beams,
    no_repeat_ngram_size,
    temperature,
    max_answer_length,
    do_tts,
):
    update_params = {
        "min_length": min_length_answer,
        "max_length": max_answer_length,
        "num_beams": int(num_beams),
        "temperature": temperature,
        "no_repeat_ngram_size": no_repeat_ngram_size,
    }
    if not question:
        # No text question: transcribe the uploaded or recorded audio instead.
        s2t_model = models["wav2vec2-iic"]["model"]
        s2t_processor = models["wav2vec2-iic"]["processor"]
        question = transcript(
            audio_file, audio_array, processor=s2t_processor, model=s2t_model
        )
        print(f"Transcribed question: *** {question} ****")
    if any(
        any(term in word.lower() for term in covidterms)
        for word in question.split(" ")
    ):
        return (
            "Del COVID no queremos saber ya más nada, lo sentimos, pregúntame sobre otra cosa :P ",
            "ni contexto ni contexta.",
            "audio_troll.flac",
        )
    contexts = query_index(question)
    contexts = sort_on_similarity(question, contexts)
    context = create_context(contexts)
    model_input = create_model_input(question, context)
    answers = generate_answer(model_input, update_params)
    final_answer = answers[0]["generated_text"]
    if do_tts:
        audio_answer = tts_es(remove_chars_to_tts(final_answer))
    final_answer, documents = parse_final_answer(final_answer, contexts)
    return final_answer, documents, audio_answer if do_tts else "audio_troll.flac"
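# Gradio UI: the question comes in as text, an uploaded audio file or a microphone
# recording; the outputs are the generated answer (HTML), the supporting documents
# (HTML) and an audio version of the answer.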
if __name__ == "__main__":
    gr.Interface(
        search_and_answer,
        inputs=[
            gr.inputs.Textbox(
                lines=2,
                label="Pregúntame sobre BioMedicina o temas relacionados. Puedes simplemente preguntarme aquí y darle al botón verde de abajo que pone Enviar.",
                placeholder="Escribe aquí tu pregunta",
                optional=True,
            ),
            gr.inputs.Audio(
                source="upload",
                type="filepath",
                label="Sube un audio con tu pregunta aquí si quieres.",
                optional=True,
            ),
            gr.inputs.Audio(
                source="microphone",
                type="numpy",
                label="Graba aquí un audio con tu pregunta.",
                optional=True,
            ),
            gr.inputs.Slider(
                minimum=10,
                maximum=200,
                default=50,
                label="Minimum size for the answer",
                step=1,
            ),
            gr.inputs.Slider(
                minimum=4, maximum=12, default=8, label="number of beams", step=1
            ),
            gr.inputs.Slider(
                minimum=2, maximum=5, default=3, label="no repeat n-gram size", step=1
            ),
            gr.inputs.Slider(
                minimum=0.8, maximum=2.0, default=1.0, label="temperature", step=0.1
            ),
            gr.inputs.Slider(
                minimum=220,
                maximum=360,
                default=250,
                label="maximum answer length",
                step=1,
            ),
            gr.inputs.Checkbox(
                default=False, label="Text to Speech", optional=True
            ),
        ],
        outputs=[
            gr.outputs.HTML(label="Respuesta generada."),
            gr.outputs.HTML(label="Documentos utilizados."),
            gr.outputs.Audio(label="Respuesta en audio."),
        ],
        description=description,
        examples=examples,
        theme="grass",
        article=article,
        thumbnail="IIC_logoP.png",
        css="https://cdn.jsdelivr.net/npm/bootstrap@3.3.7/dist/css/bootstrap.min.css",
    ).launch()