Spaces:
Build error
Build error
Alejandro Vaca
commited on
Commit
•
6bf4ad7
1
Parent(s):
92d518a
initial commit
Browse files- .gitattributes +2 -0
- app.py +254 -0
- article_app.py +173 -0
- audio_troll.flac +3 -0
- dpr_index_bio.faiss +3 -0
- dpr_index_bio_newdpr.faiss +3 -0
- dpr_index_bio_prueba.faiss +3 -0
- dpr_index_bio_splitted.faiss +3 -0
- general_utils.py +138 -0
- packages.txt +2 -0
- requirements.txt +10 -0
- save_faiss_index.py +65 -0
- tmptdsnrh_8.flac +3 -0
- vacio.flac +3 -0
.gitattributes
CHANGED
@@ -25,3 +25,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
25 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
26 |
*.zstandard filter=lfs diff=lfs merge=lfs -text
|
27 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
25 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
26 |
*.zstandard filter=lfs diff=lfs merge=lfs -text
|
27 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.flac filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.faiss filter=lfs diff=lfs merge=lfs -text
|
app.py
ADDED
@@ -0,0 +1,254 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from datasets import load_dataset
|
2 |
+
from transformers import (
|
3 |
+
DPRQuestionEncoder,
|
4 |
+
DPRQuestionEncoderTokenizer,
|
5 |
+
MT5ForConditionalGeneration,
|
6 |
+
AutoTokenizer,
|
7 |
+
AutoModelForCTC,
|
8 |
+
Wav2Vec2Tokenizer,
|
9 |
+
)
|
10 |
+
from general_utils import (
|
11 |
+
embed_questions,
|
12 |
+
transcript,
|
13 |
+
remove_chars_to_tts,
|
14 |
+
parse_final_answer,
|
15 |
+
)
|
16 |
+
from typing import List
|
17 |
+
import gradio as gr
|
18 |
+
from article_app import article, description, examples
|
19 |
+
from haystack.nodes import DensePassageRetriever
|
20 |
+
from haystack.document_stores import InMemoryDocumentStore
|
21 |
+
import numpy as np
|
22 |
+
from sentence_transformers import SentenceTransformer, util, CrossEncoder
|
23 |
+
|
24 |
+
topk = 21
|
25 |
+
minchars = 200
|
26 |
+
min_snippet_length = 20
|
27 |
+
device = "cpu"
|
28 |
+
covidterms = ["covid19", "covid", "coronavirus", "covid-19", "sars-cov-2"]
|
29 |
+
|
30 |
+
models = {
|
31 |
+
"wav2vec2-iic": {
|
32 |
+
"processor": Wav2Vec2Tokenizer.from_pretrained(
|
33 |
+
"IIC/wav2vec2-spanish-multilibrispeech"
|
34 |
+
),
|
35 |
+
"model": AutoModelForCTC.from_pretrained(
|
36 |
+
"IIC/wav2vec2-spanish-multilibrispeech"
|
37 |
+
),
|
38 |
+
},
|
39 |
+
# "wav2vec2-jonatangrosman": {
|
40 |
+
# "processor": Wav2Vec2Tokenizer.from_pretrained(
|
41 |
+
# "jonatasgrosman/wav2vec2-large-xlsr-53-spanish"
|
42 |
+
# ),
|
43 |
+
# "model": AutoModelForCTC.from_pretrained(
|
44 |
+
# "jonatasgrosman/wav2vec2-large-xlsr-53-spanish"
|
45 |
+
# ),
|
46 |
+
# },
|
47 |
+
}
|
48 |
+
|
49 |
+
|
50 |
+
tts_es = gr.Interface.load("huggingface/facebook/tts_transformer-es-css10")
|
51 |
+
|
52 |
+
|
53 |
+
params_generate = {
|
54 |
+
"min_length": 50,
|
55 |
+
"max_length": 250,
|
56 |
+
"do_sample": False,
|
57 |
+
"early_stopping": True,
|
58 |
+
"num_beams": 8,
|
59 |
+
"temperature": 1.0,
|
60 |
+
"top_k": None,
|
61 |
+
"top_p": None,
|
62 |
+
"no_repeat_ngram_size": 3,
|
63 |
+
"num_return_sequences": 1,
|
64 |
+
}
|
65 |
+
|
66 |
+
dpr = DensePassageRetriever(
|
67 |
+
document_store=InMemoryDocumentStore(),
|
68 |
+
query_embedding_model="IIC/dpr-spanish-question_encoder-allqa-base",
|
69 |
+
passage_embedding_model="IIC/dpr-spanish-passage_encoder-allqa-base",
|
70 |
+
max_seq_len_query=64,
|
71 |
+
max_seq_len_passage=256,
|
72 |
+
batch_size=512,
|
73 |
+
use_gpu=False,
|
74 |
+
)
|
75 |
+
|
76 |
+
mt5_tokenizer = AutoTokenizer.from_pretrained("IIC/mt5-base-lfqa-es")
|
77 |
+
mt5_lfqa = MT5ForConditionalGeneration.from_pretrained("IIC/mt5-base-lfqa-es")
|
78 |
+
|
79 |
+
similarity_model = SentenceTransformer(
|
80 |
+
"distiluse-base-multilingual-cased", device="cpu"
|
81 |
+
)
|
82 |
+
|
83 |
+
crossencoder = CrossEncoder("avacaondata/roberta-base-bne-ranker", device="cpu")
|
84 |
+
|
85 |
+
dataset = load_dataset("IIC/spanish_biomedical_crawled_corpus", split="train")
|
86 |
+
|
87 |
+
dataset = dataset.filter(lambda example: len(example["text"]) > minchars)
|
88 |
+
|
89 |
+
dataset.load_faiss_index(
|
90 |
+
"embeddings",
|
91 |
+
"dpr_index_bio_newdpr.faiss",
|
92 |
+
)
|
93 |
+
|
94 |
+
|
95 |
+
def query_index(question: str):
|
96 |
+
question_embedding = dpr.embed_queries([question])[0]
|
97 |
+
scores, closest_passages = dataset.get_nearest_examples(
|
98 |
+
"embeddings", question_embedding, k=topk
|
99 |
+
)
|
100 |
+
contexts = [
|
101 |
+
closest_passages["text"][i] for i in range(len(closest_passages["text"]))
|
102 |
+
]
|
103 |
+
# [:int(topk / 3)]
|
104 |
+
return [
|
105 |
+
context for context in contexts if len(context.split()) > min_snippet_length
|
106 |
+
]
|
107 |
+
|
108 |
+
|
109 |
+
def sort_on_similarity(question, contexts, include_rank: int = 5):
|
110 |
+
# TODO: METER AQUÍ EL CROSSENCODER nuestro
|
111 |
+
question_encoded = similarity_model.encode([question])[0]
|
112 |
+
ctxs_encoded = similarity_model.encode(contexts)
|
113 |
+
similarity_scores = [
|
114 |
+
util.cos_sim(question_encoded, ctx_encoded) for ctx_encoded in ctxs_encoded
|
115 |
+
]
|
116 |
+
similarity_ranking_idx = np.flip(np.argsort(similarity_scores))
|
117 |
+
return [contexts[idx] for idx in similarity_ranking_idx][:include_rank]
|
118 |
+
|
119 |
+
|
120 |
+
def create_context(contexts: List):
|
121 |
+
return "<p>" + "<p>".join(contexts)
|
122 |
+
|
123 |
+
|
124 |
+
def create_model_input(question: str, context: str):
|
125 |
+
return f"question: {question} context: {context}"
|
126 |
+
|
127 |
+
|
128 |
+
def generate_answer(model_input, update_params):
|
129 |
+
model_input = mt5_tokenizer(
|
130 |
+
model_input, truncation=True, padding=True, return_tensors="pt", max_length=1024
|
131 |
+
)
|
132 |
+
params_generate.update(update_params)
|
133 |
+
answers_encoded = mt5_lfqa.generate(
|
134 |
+
input_ids=model_input["input_ids"].to(device),
|
135 |
+
attention_mask=model_input["attention_mask"].to(device),
|
136 |
+
**params_generate,
|
137 |
+
)
|
138 |
+
answers = mt5_tokenizer.batch_decode(
|
139 |
+
answers_encoded, skip_special_tokens=True, clean_up_tokenization_spaces=True
|
140 |
+
)
|
141 |
+
results = [{"generated_text": answer} for answer in answers]
|
142 |
+
return results
|
143 |
+
|
144 |
+
|
145 |
+
def search_and_answer(
|
146 |
+
question,
|
147 |
+
audio_file,
|
148 |
+
audio_array,
|
149 |
+
min_length_answer,
|
150 |
+
num_beams,
|
151 |
+
no_repeat_ngram_size,
|
152 |
+
temperature,
|
153 |
+
max_answer_length,
|
154 |
+
wav2vec2_name,
|
155 |
+
do_tts,
|
156 |
+
):
|
157 |
+
update_params = {
|
158 |
+
"min_length": min_length_answer,
|
159 |
+
"max_length": max_answer_length,
|
160 |
+
"num_beams": int(num_beams),
|
161 |
+
"temperature": temperature,
|
162 |
+
"no_repeat_ngram_size": no_repeat_ngram_size,
|
163 |
+
}
|
164 |
+
if not question:
|
165 |
+
s2t_model = models[wav2vec2_name]["model"]
|
166 |
+
s2t_processor = models[wav2vec2_name]["processor"]
|
167 |
+
question = transcript(
|
168 |
+
audio_file, audio_array, processor=s2t_processor, model=s2t_model
|
169 |
+
)
|
170 |
+
print(f"Transcripted question: *** {question} ****")
|
171 |
+
if any([any([term in word.lower() for term in covidterms]) for word in question.split(" ")]):
|
172 |
+
return "Del COVID no queremos saber ya más nada, lo sentimos, pregúntame sobre otra cosa :P ", "tmptdsnrh_8.flac"
|
173 |
+
contexts = query_index(question)
|
174 |
+
contexts = sort_on_similarity(question, contexts)
|
175 |
+
context = create_context(contexts)
|
176 |
+
model_input = create_model_input(question, context)
|
177 |
+
answers = generate_answer(model_input, update_params)
|
178 |
+
final_answer = answers[0]["generated_text"]
|
179 |
+
if do_tts:
|
180 |
+
audio_answer = tts_es(remove_chars_to_tts(final_answer))
|
181 |
+
final_answer = parse_final_answer(final_answer, contexts)
|
182 |
+
return final_answer, audio_answer if do_tts else "tmptdsnrh_8.flac"
|
183 |
+
|
184 |
+
|
185 |
+
if __name__ == "__main__":
|
186 |
+
gr.Interface(
|
187 |
+
search_and_answer,
|
188 |
+
inputs=[
|
189 |
+
gr.inputs.Textbox(
|
190 |
+
lines=2,
|
191 |
+
label="Question",
|
192 |
+
placeholder="Type your question (in spanish) to the system.",
|
193 |
+
optional=True,
|
194 |
+
),
|
195 |
+
gr.inputs.Audio(
|
196 |
+
source="upload",
|
197 |
+
type="filepath",
|
198 |
+
label="Upload your audio asking a question here.",
|
199 |
+
optional=True,
|
200 |
+
),
|
201 |
+
gr.inputs.Audio(
|
202 |
+
source="microphone",
|
203 |
+
type="numpy",
|
204 |
+
label="Record your audio asking a question.",
|
205 |
+
optional=True,
|
206 |
+
),
|
207 |
+
gr.inputs.Slider(
|
208 |
+
minimum=10,
|
209 |
+
maximum=200,
|
210 |
+
default=50,
|
211 |
+
label="Minimum size for the answer",
|
212 |
+
step=1,
|
213 |
+
),
|
214 |
+
gr.inputs.Slider(
|
215 |
+
minimum=4, maximum=12, default=8, label="number of beams", step=1
|
216 |
+
),
|
217 |
+
gr.inputs.Slider(
|
218 |
+
minimum=2, maximum=5, default=3, label="no repeat n-gram size", step=1
|
219 |
+
),
|
220 |
+
gr.inputs.Slider(
|
221 |
+
minimum=0.8, maximum=2.0, default=1.0, label="temperature", step=0.1
|
222 |
+
),
|
223 |
+
gr.inputs.Slider(
|
224 |
+
minimum=220,
|
225 |
+
maximum=360,
|
226 |
+
default=250,
|
227 |
+
label="maximum answer length",
|
228 |
+
step=1,
|
229 |
+
),
|
230 |
+
gr.inputs.Dropdown(
|
231 |
+
["wav2vec2-iic", "wav2vec2-jonatangrosman"],
|
232 |
+
type="value",
|
233 |
+
default=None,
|
234 |
+
label="Select the speech recognition model.",
|
235 |
+
optional=False,
|
236 |
+
),
|
237 |
+
gr.inputs.Checkbox(
|
238 |
+
default=False, label="Text to Speech", optional=True),
|
239 |
+
],
|
240 |
+
outputs=[
|
241 |
+
gr.outputs.HTML(
|
242 |
+
# type="str",
|
243 |
+
label="Answer from the system."
|
244 |
+
),
|
245 |
+
gr.outputs.Audio(label="Answer in audio"),
|
246 |
+
],
|
247 |
+
# title="Abstractive QA of BioMedical Domain in Spanish",
|
248 |
+
description=description,
|
249 |
+
examples=examples,
|
250 |
+
theme="grass",
|
251 |
+
article=article,
|
252 |
+
thumbnail="IIC_logoP.png",
|
253 |
+
css="https://cdn.jsdelivr.net/npm/bootstrap@3.3.7/dist/css/bootstrap.min.css",
|
254 |
+
).launch()
|
article_app.py
ADDED
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
article = """
|
2 |
+
<img src="https://www.iic.uam.es/wp-content/uploads/2017/12/IIC_logoP.png">
|
3 |
+
<img src="https://drive.google.com/uc?export=view&id=1S8v94q39QRCfmVTMvjLCACmhMe9lJQdc">
|
4 |
+
|
5 |
+
<p style="text-align: justify;"> This app is developed by <a href="https://www.iic.uam.es/">IIC - Instituto de Ingeniería del Conocimiento</a> as part of the <a href="https://www.eventbrite.com/e/registro-hackathon-de-pln-en-espanol-273014111557">Somos PLN Hackaton 2022.</a>
|
6 |
+
|
7 |
+
The objective of this app is to expand the existing tools regarding long form question answering in Spanish. In fact, multiple novel methods (in Spanish)
|
8 |
+
have been introduced to build this app.
|
9 |
+
The reason for including audio as a possible input and always as an output is because we wanted to make the App much more accessible to people that cannot read or write.
|
10 |
+
Below you can find all the pieces that form the system.
|
11 |
+
|
12 |
+
1. <a href="https://huggingface.co/IIC/wav2vec2-spanish-multilibrispeech">Speech2Text</a>: For this we finedtuned a multilingual Wav2Vec2, as explained in the attached link. We use this model to process audio questions.
|
13 |
+
2. <a href="https://huggingface.co/IIC/dpr-spanish-passage_encoder-allqa-base">Dense Passage Retrieval for Context</a>: Dense Passage Retrieval is a methodology <a href="https://arxiv.org/abs/2004.04906">developed by Facebook</a> which is currently the SoTA for Passage Retrieval,
|
14 |
+
that is, the task of getting the most relevant passages to answer a given question with. You can find details about how it was trained on the link attached to the name.
|
15 |
+
3. <a href="https://huggingface.co/IIC/dpr-spanish-question_encoder-allqa-base">Dense Passage Retrieval for Question</a>: It is actually part of the same thing as the above. For more details, go to the attached link.
|
16 |
+
4. <a href="https://huggingface.co/sentence-transformers/distiluse-base-multilingual-cased-v1">Sentence Encoder Ranker</a>: To rerank the candidate contexts retrieved by dpr for the generative model to see. This also selects the top 5 passages for the model to read, it is the final filter before the generative model.
|
17 |
+
5. <a href="https://huggingface.co/IIC/mt5-base-lfqa-es">Generative Long-Form Question Answering Model</a>: For this we used either mT5 (the one attached) or <a href="https://huggingface.co/IIC/mbart-large-lfqa-es">mBART</a>. This generative model receives the most relevant
|
18 |
+
passages and uses them to generate an answer to the question. In the attached link there are more details about how we trained it etc.
|
19 |
+
|
20 |
+
On the other hand, we uploaded, and in some cases created, datasets in Spanish to be able to build such a system.
|
21 |
+
|
22 |
+
1. <a href="https://huggingface.co/datasets/IIC/spanish_biomedical_crawled_corpus">Spanish Biomedical Crawled Corpus</a>. Used for finding answers to questions about biomedicine. (More info in the link.)
|
23 |
+
2. <a href="https://huggingface.co/datasets/IIC/lfqa_spanish">LFQA_Spanish</a>. Used for training the generative model. (More info in the link.)
|
24 |
+
3. <a href="https://huggingface.co/datasets/squad_es">SQUADES</a>. Used to train the DPR models. (More info in the link.)
|
25 |
+
4. <a href="https://huggingface.co/datasets/IIC/bioasq22_es">BioAsq22-Spanish</a>. Used to train the DPR models. (More info in the link.)
|
26 |
+
5. <a href="https://huggingface.co/datasets/PlanTL-GOB-ES/SQAC">SQAC (Spanish Question Answering Corpus)</a>. Used to train the DPR models. (More info in the link.)
|
27 |
+
</p>
|
28 |
+
"""
|
29 |
+
# height="100", width="1000"
|
30 |
+
description = """
|
31 |
+
<a href="https://www.iic.uam.es/">
|
32 |
+
<img src="https://drive.google.com/uc?export=view&id=1xNz4EuafyzvMKSMTEfwzELln155uN6_H" style="max-width: 100%; max-height: 10%; height: 250px; object-fit: fill">,
|
33 |
+
</a>
|
34 |
+
<h1> BioMedIA: Abstractive Question Answering of BioMedical Domain in Spanish </h1>
|
35 |
+
Esta aplicación consiste en sistemas de búsqueda del Estado del Arte en Español junto con un modelo generativo entrenado para componer una respuesta a preguntas a partir de una serie de contextos.
|
36 |
+
"""
|
37 |
+
|
38 |
+
|
39 |
+
examples = [
|
40 |
+
[
|
41 |
+
"¿Cuáles son los efectos secundarios más ampliamente reportados en el tratamiento de la enfermedad de Crohn?",
|
42 |
+
"vacio.flac",
|
43 |
+
"vacio.flac",
|
44 |
+
60,
|
45 |
+
8,
|
46 |
+
3,
|
47 |
+
1.0,
|
48 |
+
250,
|
49 |
+
"wav2vec2-iic",
|
50 |
+
False,
|
51 |
+
],
|
52 |
+
[
|
53 |
+
"¿Qué alternativas al Paracetamol existen para el dolor de cabeza?",
|
54 |
+
"vacio.flac",
|
55 |
+
"vacio.flac",
|
56 |
+
80,
|
57 |
+
8,
|
58 |
+
3,
|
59 |
+
1.0,
|
60 |
+
250,
|
61 |
+
"wav2vec2-iic",
|
62 |
+
False
|
63 |
+
],
|
64 |
+
[
|
65 |
+
"¿Cuáles son los principales tipos de disartria del trastorno del habla motor?",
|
66 |
+
"vacio.flac",
|
67 |
+
"vacio.flac",
|
68 |
+
50,
|
69 |
+
8,
|
70 |
+
3,
|
71 |
+
1.0,
|
72 |
+
250,
|
73 |
+
"wav2vec2-iic",
|
74 |
+
False
|
75 |
+
],
|
76 |
+
[
|
77 |
+
"¿Es la esclerosis tuberosa una enfermedad genética?",
|
78 |
+
"vacio.flac",
|
79 |
+
"vacio.flac",
|
80 |
+
50,
|
81 |
+
8,
|
82 |
+
3,
|
83 |
+
1.0,
|
84 |
+
250,
|
85 |
+
"wav2vec2-iic",
|
86 |
+
False
|
87 |
+
],
|
88 |
+
[
|
89 |
+
"¿Cuál es la función de la proteína Mis18?",
|
90 |
+
"vacio.flac",
|
91 |
+
"vacio.flac",
|
92 |
+
50,
|
93 |
+
8,
|
94 |
+
3,
|
95 |
+
1.0,
|
96 |
+
250,
|
97 |
+
"wav2vec2-iic",
|
98 |
+
False
|
99 |
+
],
|
100 |
+
[
|
101 |
+
"¿Qué deficiencia es la causa del síndrome de piernas inquietas??",
|
102 |
+
"vacio.flac",
|
103 |
+
"vacio.flac",
|
104 |
+
50,
|
105 |
+
8,
|
106 |
+
3,
|
107 |
+
1.0,
|
108 |
+
250,
|
109 |
+
"wav2vec2-iic",
|
110 |
+
False
|
111 |
+
],
|
112 |
+
[
|
113 |
+
"¿Cuál es la función del 6SRNA en las bacterias?",
|
114 |
+
"vacio.flac",
|
115 |
+
"vacio.flac",
|
116 |
+
60,
|
117 |
+
8,
|
118 |
+
3,
|
119 |
+
1.0,
|
120 |
+
250,
|
121 |
+
"wav2vec2-iic",
|
122 |
+
False,
|
123 |
+
],
|
124 |
+
[
|
125 |
+
"¿Por qué los humanos desarrollamos diabetes?",
|
126 |
+
"vacio.flac",
|
127 |
+
"vacio.flac",
|
128 |
+
50,
|
129 |
+
10,
|
130 |
+
3,
|
131 |
+
1.0,
|
132 |
+
250,
|
133 |
+
"wav2vec2-iic",
|
134 |
+
False,
|
135 |
+
],
|
136 |
+
[
|
137 |
+
"¿Qué factores de riesgo aumentan la probabilidad de sufrir un ataque al corazón?",
|
138 |
+
"vacio.flac",
|
139 |
+
"vacio.flac",
|
140 |
+
80,
|
141 |
+
8,
|
142 |
+
3,
|
143 |
+
1.0,
|
144 |
+
250,
|
145 |
+
"wav2vec2-iic",
|
146 |
+
False
|
147 |
+
],
|
148 |
+
[
|
149 |
+
"¿Cómo funcionan las vacunas?",
|
150 |
+
"vacio.flac",
|
151 |
+
"vacio.flac",
|
152 |
+
90,
|
153 |
+
8,
|
154 |
+
3,
|
155 |
+
1.0,
|
156 |
+
250,
|
157 |
+
"wav2vec2-iic",
|
158 |
+
False
|
159 |
+
],
|
160 |
+
[
|
161 |
+
"¿Tienen conciencia los animales?",
|
162 |
+
"vacio.flac",
|
163 |
+
"vacio.flac",
|
164 |
+
70,
|
165 |
+
8,
|
166 |
+
3,
|
167 |
+
1.0,
|
168 |
+
250,
|
169 |
+
"wav2vec2-iic",
|
170 |
+
False
|
171 |
+
],
|
172 |
+
|
173 |
+
]
|
audio_troll.flac
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:fd58f522978eb9ed242c9c9ff6b3e4dd0054e55f74ab125f4d8f1d821bacfb01
|
3 |
+
size 585520
|
dpr_index_bio.faiss
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:098cc186374dde469b419bb91cd71ca1e0ac2fab02adae13977689f6e249e0be
|
3 |
+
size 68619327
|
dpr_index_bio_newdpr.faiss
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:fcedce3fa1c9049abe6f5325ee16f937147d8a5b22b526969dbf77182ebc4c5b
|
3 |
+
size 59494679
|
dpr_index_bio_prueba.faiss
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:00aa65abf514ebe54aec3b25486589c6302a223ea07b5fc7c9f644ed081c9c6d
|
3 |
+
size 301101
|
dpr_index_bio_splitted.faiss
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:bcc178f8b1ec7795834dd209875847f3f6fc6e26bcebce1b08be38d6bdd27211
|
3 |
+
size 162144239
|
general_utils.py
ADDED
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import nltk
|
3 |
+
from scipy.io.wavfile import write
|
4 |
+
import librosa
|
5 |
+
import hashlib
|
6 |
+
from typing import List
|
7 |
+
|
8 |
+
|
9 |
+
def embed_questions(
|
10 |
+
question_model, question_tokenizer, questions, max_length=128, device="cpu"
|
11 |
+
):
|
12 |
+
query = question_tokenizer(
|
13 |
+
questions,
|
14 |
+
max_length=max_length,
|
15 |
+
padding="max_length",
|
16 |
+
truncation=True,
|
17 |
+
return_tensors="pt",
|
18 |
+
)
|
19 |
+
with torch.no_grad():
|
20 |
+
q_reps = question_model(
|
21 |
+
query["input_ids"].to(device), query["attention_mask"].to(device)
|
22 |
+
).pooler_output
|
23 |
+
return q_reps.cpu().numpy()
|
24 |
+
|
25 |
+
|
26 |
+
def embed_passages(ctx_model, ctx_tokenizer, passages, max_length=128, device="cpu"):
|
27 |
+
p = ctx_tokenizer(
|
28 |
+
passages["text"],
|
29 |
+
max_length=max_length,
|
30 |
+
padding="max_length",
|
31 |
+
truncation=True,
|
32 |
+
return_tensors="pt",
|
33 |
+
)
|
34 |
+
with torch.no_grad():
|
35 |
+
a_reps = ctx_model(
|
36 |
+
p["input_ids"].to(device), p["attention_mask"].to(device)
|
37 |
+
).pooler_output
|
38 |
+
return {"embeddings": a_reps.cpu().numpy()}
|
39 |
+
|
40 |
+
|
41 |
+
class Document:
|
42 |
+
def __init__(self, meta={}, content: str = "", id_: str = ""):
|
43 |
+
self.meta = meta
|
44 |
+
self.content = content
|
45 |
+
self.id = id_
|
46 |
+
|
47 |
+
|
48 |
+
def _alter_docs_for_haystack(passages):
|
49 |
+
return [Document(content=passage, id_=str(i)) for i, passage in enumerate(passages)]
|
50 |
+
|
51 |
+
|
52 |
+
def embed_passages_haystack(
|
53 |
+
dpr_model,
|
54 |
+
passages,
|
55 |
+
):
|
56 |
+
passages = _alter_docs_for_haystack(passages["text"])
|
57 |
+
embeddings = dpr_model.embed_documents(passages)
|
58 |
+
return {"embeddings": embeddings}
|
59 |
+
|
60 |
+
|
61 |
+
def correct_casing(input_sentence):
|
62 |
+
"""This function is for correcting the casing of the generated transcribed text"""
|
63 |
+
sentences = nltk.sent_tokenize(input_sentence)
|
64 |
+
return " ".join([s.replace(s[0], s[0].capitalize(), 1) for s in sentences])
|
65 |
+
|
66 |
+
|
67 |
+
def clean_transcript(text):
|
68 |
+
text = text.replace("[pad]".upper(), "")
|
69 |
+
return text
|
70 |
+
|
71 |
+
|
72 |
+
def add_question_symbols(text):
|
73 |
+
if text[0] != "¿":
|
74 |
+
text = "¿" + text
|
75 |
+
if text[-1] != "?":
|
76 |
+
text = text + "?"
|
77 |
+
return text
|
78 |
+
|
79 |
+
|
80 |
+
def remove_chars_to_tts(text):
|
81 |
+
text = text.replace(",", " ")
|
82 |
+
return text
|
83 |
+
|
84 |
+
|
85 |
+
def transcript(input_file, audio_array, processor, model):
|
86 |
+
if audio_array:
|
87 |
+
rate, sample = audio_array
|
88 |
+
write("temp.wav", rate, sample)
|
89 |
+
input_file = "temp.wav"
|
90 |
+
transcript = ""
|
91 |
+
# Ensure that the sample rate is 16k
|
92 |
+
sample_rate = librosa.get_samplerate(input_file)
|
93 |
+
|
94 |
+
# Stream over 10 seconds chunks rather than load the full file
|
95 |
+
stream = librosa.stream(
|
96 |
+
input_file,
|
97 |
+
block_length=20, # number of seconds to split the batch
|
98 |
+
frame_length=sample_rate, # 16000,
|
99 |
+
hop_length=sample_rate, # 16000
|
100 |
+
)
|
101 |
+
|
102 |
+
for speech in stream:
|
103 |
+
if len(speech.shape) > 1:
|
104 |
+
speech = speech[:, 0] + speech[:, 1]
|
105 |
+
if sample_rate != 16000:
|
106 |
+
speech = librosa.resample(speech, orig_sr=sample_rate, target_sr=16000)
|
107 |
+
input_values = processor(speech, return_tensors="pt").input_values
|
108 |
+
logits = model(input_values).logits
|
109 |
+
|
110 |
+
predicted_ids = torch.argmax(logits, dim=-1)
|
111 |
+
transcription = processor.decode(
|
112 |
+
predicted_ids[0],
|
113 |
+
clean_up_tokenization_spaces=True,
|
114 |
+
skip_special_tokens=True,
|
115 |
+
)
|
116 |
+
transcription = clean_transcript(transcription)
|
117 |
+
# transcript += transcription.lower()
|
118 |
+
transcript += correct_casing(transcription.lower()) + ". "
|
119 |
+
# transcript += " "
|
120 |
+
whole_text = transcript[:3800]
|
121 |
+
whole_text = add_question_symbols(whole_text)
|
122 |
+
return whole_text
|
123 |
+
|
124 |
+
|
125 |
+
def parse_final_answer(answer_text: str, contexts: List):
|
126 |
+
"""Parse the final answer into correct format"""
|
127 |
+
s = (
|
128 |
+
f"<b><em>Final Answer:</em> {answer_text}</b> \n\n\n"
|
129 |
+
+ "<p> Contexts Used: \n <p>"
|
130 |
+
+ "\n".join(
|
131 |
+
[
|
132 |
+
("""<p style="text-align: justify;">""" + context)[:300]
|
133 |
+
+ "[...]</p>"
|
134 |
+
for context in contexts[:5]
|
135 |
+
]
|
136 |
+
)
|
137 |
+
)
|
138 |
+
return s
|
packages.txt
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
libsndfile1
|
2 |
+
ffmpeg
|
requirements.txt
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
nltk==3.7
|
2 |
+
transformers==4.13.0
|
3 |
+
torch==1.10.2
|
4 |
+
librosa==0.9.1
|
5 |
+
numpy==1.21
|
6 |
+
gradio==2.8.13
|
7 |
+
jinja2==3.0.3
|
8 |
+
datasets==1.18.4
|
9 |
+
faiss-gpu==1.7.2
|
10 |
+
farm-haystack==1.3.0
|
save_faiss_index.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from datasets import load_dataset
|
2 |
+
from transformers import DPRContextEncoderTokenizer, DPRContextEncoder
|
3 |
+
from general_utils import embed_passages, embed_passages_haystack
|
4 |
+
import faiss
|
5 |
+
import argparse
|
6 |
+
import os
|
7 |
+
from haystack.nodes import DensePassageRetriever
|
8 |
+
from haystack.document_stores import InMemoryDocumentStore
|
9 |
+
|
10 |
+
|
11 |
+
os.environ["OMP_NUM_THREADS"] = "8"
|
12 |
+
|
13 |
+
|
14 |
+
def create_faiss_index(args):
|
15 |
+
minchars = 200
|
16 |
+
dims = 128
|
17 |
+
|
18 |
+
dpr = DensePassageRetriever(
|
19 |
+
document_store=InMemoryDocumentStore(),
|
20 |
+
query_embedding_model="IIC/dpr-spanish-question_encoder-allqa-base",
|
21 |
+
passage_embedding_model="IIC/dpr-spanish-question_encoder-allqa-base",
|
22 |
+
max_seq_len_query=64,
|
23 |
+
max_seq_len_passage=256,
|
24 |
+
batch_size=512,
|
25 |
+
)
|
26 |
+
|
27 |
+
dataset = load_dataset(
|
28 |
+
"IIC/spanish_biomedical_crawled_corpus", split="train"
|
29 |
+
)
|
30 |
+
|
31 |
+
dataset = dataset.filter(lambda example: len(example["text"]) > minchars)
|
32 |
+
|
33 |
+
def embed_passages_retrieval(examples):
|
34 |
+
return embed_passages_haystack(dpr, examples)
|
35 |
+
|
36 |
+
dataset = dataset.map(embed_passages_retrieval, batched=True, batch_size=8192)
|
37 |
+
|
38 |
+
dataset.add_faiss_index(
|
39 |
+
column="embeddings",
|
40 |
+
string_factory="OPQ64_128,IVF4898,PQ64x4fsr",
|
41 |
+
train_size=len(dataset),
|
42 |
+
)
|
43 |
+
dataset.save_faiss_index("embeddings", args.index_file_name)
|
44 |
+
|
45 |
+
|
46 |
+
if __name__ == "__main__":
|
47 |
+
parser = argparse.ArgumentParser(description="Creates Faiss Wikipedia index file")
|
48 |
+
|
49 |
+
parser.add_argument(
|
50 |
+
"--ctx_encoder_name",
|
51 |
+
default="IIC/dpr-spanish-passage_encoder-squades-base",
|
52 |
+
help="Encoding model to use for passage encoding",
|
53 |
+
)
|
54 |
+
|
55 |
+
parser.add_argument(
|
56 |
+
"--index_file_name",
|
57 |
+
default="dpr_index_bio_splitted.faiss",
|
58 |
+
help="Faiss index file with passage embeddings",
|
59 |
+
)
|
60 |
+
parser.add_argument(
|
61 |
+
"--device", default="cuda:0", help="The device to index data on."
|
62 |
+
)
|
63 |
+
|
64 |
+
main_args, _ = parser.parse_known_args()
|
65 |
+
create_faiss_index(main_args)
|
tmptdsnrh_8.flac
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:04f8d015f3597c6858e74d40c72fb70fe1caab7bf6b015ccca0eda5f53a49c71
|
3 |
+
size 389592
|
vacio.flac
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:04a9780650bebeb4e93ccdc6f9298e7338a53cc30b7c3a281cd1a01ff2bbb5c8
|
3 |
+
size 103880
|