Alejandro Vaca committed on
Commit 6bf4ad7
1 Parent(s): 92d518a

initial commit

.gitattributes CHANGED
@@ -25,3 +25,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zstandard filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.flac filter=lfs diff=lfs merge=lfs -text
+*.faiss filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,254 @@
+from datasets import load_dataset
+from transformers import (
+    DPRQuestionEncoder,
+    DPRQuestionEncoderTokenizer,
+    MT5ForConditionalGeneration,
+    AutoTokenizer,
+    AutoModelForCTC,
+    Wav2Vec2Tokenizer,
+)
+from general_utils import (
+    embed_questions,
+    transcript,
+    remove_chars_to_tts,
+    parse_final_answer,
+)
+from typing import List
+import gradio as gr
+from article_app import article, description, examples
+from haystack.nodes import DensePassageRetriever
+from haystack.document_stores import InMemoryDocumentStore
+import numpy as np
+from sentence_transformers import SentenceTransformer, util, CrossEncoder
+
+topk = 21
+minchars = 200
+min_snippet_length = 20
+device = "cpu"
+covidterms = ["covid19", "covid", "coronavirus", "covid-19", "sars-cov-2"]
+
+# Speech-to-text models, keyed by the name shown in the UI dropdown.
+models = {
+    "wav2vec2-iic": {
+        "processor": Wav2Vec2Tokenizer.from_pretrained(
+            "IIC/wav2vec2-spanish-multilibrispeech"
+        ),
+        "model": AutoModelForCTC.from_pretrained(
+            "IIC/wav2vec2-spanish-multilibrispeech"
+        ),
+    },
+    # "wav2vec2-jonatangrosman": {
+    #     "processor": Wav2Vec2Tokenizer.from_pretrained(
+    #         "jonatasgrosman/wav2vec2-large-xlsr-53-spanish"
+    #     ),
+    #     "model": AutoModelForCTC.from_pretrained(
+    #         "jonatasgrosman/wav2vec2-large-xlsr-53-spanish"
+    #     ),
+    # },
+}
+
+tts_es = gr.Interface.load("huggingface/facebook/tts_transformer-es-css10")
+
+# Default generation parameters; the UI sliders override a subset of these.
+params_generate = {
+    "min_length": 50,
+    "max_length": 250,
+    "do_sample": False,
+    "early_stopping": True,
+    "num_beams": 8,
+    "temperature": 1.0,
+    "top_k": None,
+    "top_p": None,
+    "no_repeat_ngram_size": 3,
+    "num_return_sequences": 1,
+}
+
+dpr = DensePassageRetriever(
+    document_store=InMemoryDocumentStore(),
+    query_embedding_model="IIC/dpr-spanish-question_encoder-allqa-base",
+    passage_embedding_model="IIC/dpr-spanish-passage_encoder-allqa-base",
+    max_seq_len_query=64,
+    max_seq_len_passage=256,
+    batch_size=512,
+    use_gpu=False,
+)
+
+mt5_tokenizer = AutoTokenizer.from_pretrained("IIC/mt5-base-lfqa-es")
+mt5_lfqa = MT5ForConditionalGeneration.from_pretrained("IIC/mt5-base-lfqa-es")
+
+similarity_model = SentenceTransformer(
+    "distiluse-base-multilingual-cased", device="cpu"
+)
+
+crossencoder = CrossEncoder("avacaondata/roberta-base-bne-ranker", device="cpu")
+
+dataset = load_dataset("IIC/spanish_biomedical_crawled_corpus", split="train")
+
+dataset = dataset.filter(lambda example: len(example["text"]) > minchars)
+
+dataset.load_faiss_index(
+    "embeddings",
+    "dpr_index_bio_newdpr.faiss",
+)
+
+
+def query_index(question: str):
+    """Embed the question with DPR and retrieve the top-k passages from FAISS."""
+    question_embedding = dpr.embed_queries([question])[0]
+    scores, closest_passages = dataset.get_nearest_examples(
+        "embeddings", question_embedding, k=topk
+    )
+    contexts = [
+        closest_passages["text"][i] for i in range(len(closest_passages["text"]))
+    ]
+    return [
+        context for context in contexts if len(context.split()) > min_snippet_length
+    ]
+
+
+def sort_on_similarity(question, contexts, include_rank: int = 5):
+    """Rerank the retrieved contexts by cosine similarity, keeping the top `include_rank`."""
+    # TODO: plug our cross-encoder in here.
+    question_encoded = similarity_model.encode([question])[0]
+    ctxs_encoded = similarity_model.encode(contexts)
+    # .item() yields plain floats so np.argsort returns a flat integer ranking.
+    similarity_scores = [
+        util.cos_sim(question_encoded, ctx_encoded).item()
+        for ctx_encoded in ctxs_encoded
+    ]
+    similarity_ranking_idx = np.flip(np.argsort(similarity_scores))
+    return [contexts[idx] for idx in similarity_ranking_idx][:include_rank]
+
+
+def create_context(contexts: List):
+    return "<p>" + "<p>".join(contexts)
+
+
+def create_model_input(question: str, context: str):
+    return f"question: {question} context: {context}"
+
+
+def generate_answer(model_input, update_params):
+    model_input = mt5_tokenizer(
+        model_input, truncation=True, padding=True, return_tensors="pt", max_length=1024
+    )
+    params_generate.update(update_params)
+    answers_encoded = mt5_lfqa.generate(
+        input_ids=model_input["input_ids"].to(device),
+        attention_mask=model_input["attention_mask"].to(device),
+        **params_generate,
+    )
+    answers = mt5_tokenizer.batch_decode(
+        answers_encoded, skip_special_tokens=True, clean_up_tokenization_spaces=True
+    )
+    results = [{"generated_text": answer} for answer in answers]
+    return results
+
+
+def search_and_answer(
+    question,
+    audio_file,
+    audio_array,
+    min_length_answer,
+    num_beams,
+    no_repeat_ngram_size,
+    temperature,
+    max_answer_length,
+    wav2vec2_name,
+    do_tts,
+):
+    update_params = {
+        "min_length": min_length_answer,
+        "max_length": max_answer_length,
+        "num_beams": int(num_beams),
+        "temperature": temperature,
+        "no_repeat_ngram_size": no_repeat_ngram_size,
+    }
+    if not question:
+        # No text question: transcribe the uploaded or recorded audio instead.
+        s2t_model = models[wav2vec2_name]["model"]
+        s2t_processor = models[wav2vec2_name]["processor"]
+        question = transcript(
+            audio_file, audio_array, processor=s2t_processor, model=s2t_model
+        )
+        print(f"Transcribed question: *** {question} ****")
+    if any(
+        any(term in word.lower() for term in covidterms)
+        for word in question.split(" ")
+    ):
+        return (
+            "Del COVID no queremos saber ya más nada, lo sentimos, pregúntame sobre otra cosa :P ",
+            "tmptdsnrh_8.flac",
+        )
+    contexts = query_index(question)
+    contexts = sort_on_similarity(question, contexts)
+    context = create_context(contexts)
+    model_input = create_model_input(question, context)
+    answers = generate_answer(model_input, update_params)
+    final_answer = answers[0]["generated_text"]
+    if do_tts:
+        audio_answer = tts_es(remove_chars_to_tts(final_answer))
+    final_answer = parse_final_answer(final_answer, contexts)
+    return final_answer, (audio_answer if do_tts else "tmptdsnrh_8.flac")
+
+
+if __name__ == "__main__":
+    gr.Interface(
+        search_and_answer,
+        inputs=[
+            gr.inputs.Textbox(
+                lines=2,
+                label="Question",
+                placeholder="Type your question (in Spanish) to the system.",
+                optional=True,
+            ),
+            gr.inputs.Audio(
+                source="upload",
+                type="filepath",
+                label="Upload your audio asking a question here.",
+                optional=True,
+            ),
+            gr.inputs.Audio(
+                source="microphone",
+                type="numpy",
+                label="Record your audio asking a question.",
+                optional=True,
+            ),
+            gr.inputs.Slider(
+                minimum=10,
+                maximum=200,
+                default=50,
+                label="Minimum size for the answer",
+                step=1,
+            ),
+            gr.inputs.Slider(
+                minimum=4, maximum=12, default=8, label="number of beams", step=1
+            ),
+            gr.inputs.Slider(
+                minimum=2, maximum=5, default=3, label="no repeat n-gram size", step=1
+            ),
+            gr.inputs.Slider(
+                minimum=0.8, maximum=2.0, default=1.0, label="temperature", step=0.1
+            ),
+            gr.inputs.Slider(
+                minimum=220,
+                maximum=360,
+                default=250,
+                label="maximum answer length",
+                step=1,
+            ),
+            gr.inputs.Dropdown(
+                # NOTE: only "wav2vec2-iic" exists in `models` above; the second
+                # entry is commented out there and would fail if selected.
+                ["wav2vec2-iic", "wav2vec2-jonatangrosman"],
+                type="value",
+                default=None,
+                label="Select the speech recognition model.",
+                optional=False,
+            ),
+            gr.inputs.Checkbox(default=False, label="Text to Speech", optional=True),
+        ],
+        outputs=[
+            gr.outputs.HTML(label="Answer from the system."),
+            gr.outputs.Audio(label="Answer in audio"),
+        ],
+        description=description,
+        examples=examples,
+        theme="grass",
+        article=article,
+        thumbnail="IIC_logoP.png",
+        css="https://cdn.jsdelivr.net/npm/bootstrap@3.3.7/dist/css/bootstrap.min.css",
+    ).launch()
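
A minimal sketch of calling the pipeline directly, bypassing Gradio (hypothetical call; it assumes the models and FAISS index above loaded successfully, and the positional argument order matches the `inputs` list of the interface):

    answer_html, audio_path = search_and_answer(
        "¿Cómo funcionan las vacunas?",  # text question, so no transcription happens
        None,            # audio_file (unused when a text question is given)
        None,            # audio_array (unused here)
        50,              # min_length_answer
        8,               # num_beams
        3,               # no_repeat_ngram_size
        1.0,             # temperature
        250,             # max_answer_length
        "wav2vec2-iic",  # speech model name (unused for text questions)
        False,           # do_tts: skip text-to-speech
    )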
article_app.py ADDED
@@ -0,0 +1,173 @@
+article = """
+<img src="https://www.iic.uam.es/wp-content/uploads/2017/12/IIC_logoP.png">
+<img src="https://drive.google.com/uc?export=view&id=1S8v94q39QRCfmVTMvjLCACmhMe9lJQdc">
+
+<p style="text-align: justify;"> This app is developed by <a href="https://www.iic.uam.es/">IIC - Instituto de Ingeniería del Conocimiento</a> as part of the <a href="https://www.eventbrite.com/e/registro-hackathon-de-pln-en-espanol-273014111557">Somos PLN Hackathon 2022.</a>
+
+The objective of this app is to expand the existing tools for long-form question answering in Spanish; several methods that are new for Spanish were introduced to build it.
+We include audio as a possible input, and always as an output, to make the app accessible to people who cannot read or write.
+Below are all the pieces that form the system.
+
+1. <a href="https://huggingface.co/IIC/wav2vec2-spanish-multilibrispeech">Speech2Text</a>: For this we fine-tuned a multilingual Wav2Vec2, as explained in the attached link. We use this model to process audio questions.
+2. <a href="https://huggingface.co/IIC/dpr-spanish-passage_encoder-allqa-base">Dense Passage Retrieval for Context</a>: Dense Passage Retrieval is a methodology <a href="https://arxiv.org/abs/2004.04906">developed by Facebook</a> which is currently the SoTA for passage retrieval, that is, the task of retrieving the passages most relevant for answering a given question. The attached link gives details about how it was trained.
+3. <a href="https://huggingface.co/IIC/dpr-spanish-question_encoder-allqa-base">Dense Passage Retrieval for Question</a>: the question-encoder counterpart of the passage encoder above. For more details, see the attached link.
+4. <a href="https://huggingface.co/sentence-transformers/distiluse-base-multilingual-cased-v1">Sentence Encoder Ranker</a>: reranks the candidate contexts retrieved by DPR and selects the top 5 passages for the generative model to read; it is the final filter before generation.
+5. <a href="https://huggingface.co/IIC/mt5-base-lfqa-es">Generative Long-Form Question Answering Model</a>: For this we used either mT5 (the one attached) or <a href="https://huggingface.co/IIC/mbart-large-lfqa-es">mBART</a>. This generative model receives the most relevant passages and uses them to generate an answer to the question. The attached link has more details about how we trained it.
+
+In addition, we uploaded, and in some cases created, the following Spanish datasets to be able to build such a system.
+
+1. <a href="https://huggingface.co/datasets/IIC/spanish_biomedical_crawled_corpus">Spanish Biomedical Crawled Corpus</a>. Used for finding answers to questions about biomedicine. (More info in the link.)
+2. <a href="https://huggingface.co/datasets/IIC/lfqa_spanish">LFQA_Spanish</a>. Used for training the generative model. (More info in the link.)
+3. <a href="https://huggingface.co/datasets/squad_es">SQUADES</a>. Used to train the DPR models. (More info in the link.)
+4. <a href="https://huggingface.co/datasets/IIC/bioasq22_es">BioAsq22-Spanish</a>. Used to train the DPR models. (More info in the link.)
+5. <a href="https://huggingface.co/datasets/PlanTL-GOB-ES/SQAC">SQAC (Spanish Question Answering Corpus)</a>. Used to train the DPR models. (More info in the link.)
+</p>
+"""
+
+description = """
+<a href="https://www.iic.uam.es/">
+<img src="https://drive.google.com/uc?export=view&id=1xNz4EuafyzvMKSMTEfwzELln155uN6_H" style="max-width: 100%; max-height: 10%; height: 250px; object-fit: fill">
+</a>
+<h1> BioMedIA: Abstractive Question Answering of BioMedical Domain in Spanish </h1>
+This app combines state-of-the-art search systems for Spanish with a generative model trained to compose answers to questions from a series of retrieved contexts.
+"""
+
+
+examples = [
+    [
+        "¿Cuáles son los efectos secundarios más ampliamente reportados en el tratamiento de la enfermedad de Crohn?",
+        "vacio.flac",
+        "vacio.flac",
+        60,
+        8,
+        3,
+        1.0,
+        250,
+        "wav2vec2-iic",
+        False,
+    ],
+    [
+        "¿Qué alternativas al Paracetamol existen para el dolor de cabeza?",
+        "vacio.flac",
+        "vacio.flac",
+        80,
+        8,
+        3,
+        1.0,
+        250,
+        "wav2vec2-iic",
+        False,
+    ],
+    [
+        "¿Cuáles son los principales tipos de disartria del trastorno del habla motor?",
+        "vacio.flac",
+        "vacio.flac",
+        50,
+        8,
+        3,
+        1.0,
+        250,
+        "wav2vec2-iic",
+        False,
+    ],
+    [
+        "¿Es la esclerosis tuberosa una enfermedad genética?",
+        "vacio.flac",
+        "vacio.flac",
+        50,
+        8,
+        3,
+        1.0,
+        250,
+        "wav2vec2-iic",
+        False,
+    ],
+    [
+        "¿Cuál es la función de la proteína Mis18?",
+        "vacio.flac",
+        "vacio.flac",
+        50,
+        8,
+        3,
+        1.0,
+        250,
+        "wav2vec2-iic",
+        False,
+    ],
+    [
+        "¿Qué deficiencia es la causa del síndrome de piernas inquietas?",
+        "vacio.flac",
+        "vacio.flac",
+        50,
+        8,
+        3,
+        1.0,
+        250,
+        "wav2vec2-iic",
+        False,
+    ],
+    [
+        "¿Cuál es la función del 6SRNA en las bacterias?",
+        "vacio.flac",
+        "vacio.flac",
+        60,
+        8,
+        3,
+        1.0,
+        250,
+        "wav2vec2-iic",
+        False,
+    ],
+    [
+        "¿Por qué los humanos desarrollamos diabetes?",
+        "vacio.flac",
+        "vacio.flac",
+        50,
+        10,
+        3,
+        1.0,
+        250,
+        "wav2vec2-iic",
+        False,
+    ],
+    [
+        "¿Qué factores de riesgo aumentan la probabilidad de sufrir un ataque al corazón?",
+        "vacio.flac",
+        "vacio.flac",
+        80,
+        8,
+        3,
+        1.0,
+        250,
+        "wav2vec2-iic",
+        False,
+    ],
+    [
+        "¿Cómo funcionan las vacunas?",
+        "vacio.flac",
+        "vacio.flac",
+        90,
+        8,
+        3,
+        1.0,
+        250,
+        "wav2vec2-iic",
+        False,
+    ],
+    [
+        "¿Tienen conciencia los animales?",
+        "vacio.flac",
+        "vacio.flac",
+        70,
+        8,
+        3,
+        1.0,
+        250,
+        "wav2vec2-iic",
+        False,
+    ],
+]
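
The five components the article describes map directly onto functions defined in app.py; a condensed sketch of the flow (all names taken from app.py):

    # 1. Speech2Text (only when the user supplies audio instead of text)
    question = transcript(audio_file, audio_array, processor=s2t_processor, model=s2t_model)
    # 2-3. DPR question/passage encoders + FAISS nearest-neighbor retrieval
    contexts = query_index(question)
    # 4. Sentence-encoder reranking, keeping the top 5 passages
    contexts = sort_on_similarity(question, contexts)
    # 5. mT5 long-form answer generation over the concatenated contexts
    answers = generate_answer(create_model_input(question, create_context(contexts)), {})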
audio_troll.flac ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fd58f522978eb9ed242c9c9ff6b3e4dd0054e55f74ab125f4d8f1d821bacfb01
+size 585520
dpr_index_bio.faiss ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:098cc186374dde469b419bb91cd71ca1e0ac2fab02adae13977689f6e249e0be
+size 68619327
dpr_index_bio_newdpr.faiss ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fcedce3fa1c9049abe6f5325ee16f937147d8a5b22b526969dbf77182ebc4c5b
+size 59494679
dpr_index_bio_prueba.faiss ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:00aa65abf514ebe54aec3b25486589c6302a223ea07b5fc7c9f644ed081c9c6d
+size 301101
dpr_index_bio_splitted.faiss ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bcc178f8b1ec7795834dd209875847f3f6fc6e26bcebce1b08be38d6bdd27211
+size 162144239
general_utils.py ADDED
@@ -0,0 +1,138 @@
+import torch
+import nltk
+from scipy.io.wavfile import write
+import librosa
+from typing import List
+
+
+def embed_questions(
+    question_model, question_tokenizer, questions, max_length=128, device="cpu"
+):
+    query = question_tokenizer(
+        questions,
+        max_length=max_length,
+        padding="max_length",
+        truncation=True,
+        return_tensors="pt",
+    )
+    with torch.no_grad():
+        q_reps = question_model(
+            query["input_ids"].to(device), query["attention_mask"].to(device)
+        ).pooler_output
+    return q_reps.cpu().numpy()
+
+
+def embed_passages(ctx_model, ctx_tokenizer, passages, max_length=128, device="cpu"):
+    p = ctx_tokenizer(
+        passages["text"],
+        max_length=max_length,
+        padding="max_length",
+        truncation=True,
+        return_tensors="pt",
+    )
+    with torch.no_grad():
+        a_reps = ctx_model(
+            p["input_ids"].to(device), p["attention_mask"].to(device)
+        ).pooler_output
+    return {"embeddings": a_reps.cpu().numpy()}
+
+
+class Document:
+    """Minimal stand-in for haystack's Document, used to feed passages to the retriever."""
+
+    def __init__(self, meta=None, content: str = "", id_: str = ""):
+        self.meta = meta if meta is not None else {}
+        self.content = content
+        self.id = id_
+
+
+def _alter_docs_for_haystack(passages):
+    return [Document(content=passage, id_=str(i)) for i, passage in enumerate(passages)]
+
+
+def embed_passages_haystack(
+    dpr_model,
+    passages,
+):
+    passages = _alter_docs_for_haystack(passages["text"])
+    embeddings = dpr_model.embed_documents(passages)
+    return {"embeddings": embeddings}
+
+
+def correct_casing(input_sentence):
+    """Correct the casing of the transcribed text by capitalizing each sentence."""
+    sentences = nltk.sent_tokenize(input_sentence)
+    return " ".join([s.replace(s[0], s[0].capitalize(), 1) for s in sentences])
+
+
+def clean_transcript(text):
+    text = text.replace("[pad]".upper(), "")
+    return text
+
+
+def add_question_symbols(text):
+    if text[0] != "¿":
+        text = "¿" + text
+    if text[-1] != "?":
+        text = text + "?"
+    return text
+
+
+def remove_chars_to_tts(text):
+    text = text.replace(",", " ")
+    return text
+
+
+def transcript(input_file, audio_array, processor, model):
+    if audio_array:
+        # Microphone input arrives as a (rate, samples) tuple; dump it to a temp file.
+        rate, sample = audio_array
+        write("temp.wav", rate, sample)
+        input_file = "temp.wav"
+    transcript = ""
+    # Determine the native sample rate so we can resample to 16 kHz if needed.
+    sample_rate = librosa.get_samplerate(input_file)
+
+    # Stream over ~20-second chunks rather than loading the full file.
+    stream = librosa.stream(
+        input_file,
+        block_length=20,  # frames per block (~20 s with these frame/hop settings)
+        frame_length=sample_rate,
+        hop_length=sample_rate,
+    )
+
+    for speech in stream:
+        if len(speech.shape) > 1:
+            speech = speech[:, 0] + speech[:, 1]  # mix stereo down to mono
+        if sample_rate != 16000:
+            speech = librosa.resample(speech, orig_sr=sample_rate, target_sr=16000)
+        input_values = processor(speech, return_tensors="pt").input_values
+        logits = model(input_values).logits
+
+        predicted_ids = torch.argmax(logits, dim=-1)
+        transcription = processor.decode(
+            predicted_ids[0],
+            clean_up_tokenization_spaces=True,
+            skip_special_tokens=True,
+        )
+        transcription = clean_transcript(transcription)
+        transcript += correct_casing(transcription.lower()) + ". "
+    whole_text = transcript[:3800]
+    whole_text = add_question_symbols(whole_text)
+    return whole_text
+
+
+def parse_final_answer(answer_text: str, contexts: List):
+    """Parse the final answer into the HTML shown by the app."""
+    s = (
+        f"<b><em>Final Answer:</em> {answer_text}</b> \n\n\n"
+        + "<p> Contexts Used: \n <p>"
+        + "\n".join(
+            [
+                ("""<p style="text-align: justify;">""" + context)[:300]
+                + "[...]</p>"
+                for context in contexts[:5]
+            ]
+        )
+    )
+    return s
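
A hedged usage sketch of the transcription helper above (the file name is hypothetical; the processor and model come from the `models` dict defined in app.py):

    from general_utils import transcript

    s2t = models["wav2vec2-iic"]  # defined in app.py
    question = transcript(
        "pregunta.wav",  # hypothetical audio file path
        None,            # no microphone (rate, samples) tuple, so the file is used
        processor=s2t["processor"],
        model=s2t["model"],
    )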
packages.txt ADDED
@@ -0,0 +1,2 @@
+libsndfile1
+ffmpeg
requirements.txt ADDED
@@ -0,0 +1,10 @@
+nltk==3.7
+transformers==4.13.0
+torch==1.10.2
+librosa==0.9.1
+numpy==1.21
+gradio==2.8.13
+jinja2==3.0.3
+datasets==1.18.4
+faiss-gpu==1.7.2
+farm-haystack==1.3.0
save_faiss_index.py ADDED
@@ -0,0 +1,65 @@
+from datasets import load_dataset
+from transformers import DPRContextEncoderTokenizer, DPRContextEncoder
+from general_utils import embed_passages, embed_passages_haystack
+import faiss
+import argparse
+import os
+from haystack.nodes import DensePassageRetriever
+from haystack.document_stores import InMemoryDocumentStore
+
+
+os.environ["OMP_NUM_THREADS"] = "8"
+
+
+def create_faiss_index(args):
+    minchars = 200
+
+    dpr = DensePassageRetriever(
+        document_store=InMemoryDocumentStore(),
+        query_embedding_model="IIC/dpr-spanish-question_encoder-allqa-base",
+        # The passage side must use the passage encoder (the original script
+        # passed the question encoder here; app.py uses the passage encoder).
+        passage_embedding_model="IIC/dpr-spanish-passage_encoder-allqa-base",
+        max_seq_len_query=64,
+        max_seq_len_passage=256,
+        batch_size=512,
+    )
+
+    dataset = load_dataset("IIC/spanish_biomedical_crawled_corpus", split="train")
+
+    dataset = dataset.filter(lambda example: len(example["text"]) > minchars)
+
+    def embed_passages_retrieval(examples):
+        return embed_passages_haystack(dpr, examples)
+
+    dataset = dataset.map(embed_passages_retrieval, batched=True, batch_size=8192)
+
+    dataset.add_faiss_index(
+        column="embeddings",
+        string_factory="OPQ64_128,IVF4898,PQ64x4fsr",
+        train_size=len(dataset),
+    )
+    dataset.save_faiss_index("embeddings", args.index_file_name)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="Creates a FAISS index file for the biomedical corpus"
+    )
+
+    # NOTE: --ctx_encoder_name and --device are parsed but not currently used above.
+    parser.add_argument(
+        "--ctx_encoder_name",
+        default="IIC/dpr-spanish-passage_encoder-squades-base",
+        help="Encoding model to use for passage encoding",
+    )
+
+    parser.add_argument(
+        "--index_file_name",
+        default="dpr_index_bio_splitted.faiss",
+        help="Faiss index file with passage embeddings",
+    )
+    parser.add_argument(
+        "--device", default="cuda:0", help="The device to index data on."
+    )
+
+    main_args, _ = parser.parse_known_args()
+    create_faiss_index(main_args)
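
A hedged reading of the string_factory above, following the standard FAISS index-factory grammar (the flat-index alternative at the end is an illustration, not part of the app):

    # "OPQ64_128,IVF4898,PQ64x4fsr" decomposes as:
    #   OPQ64_128 - learned OPQ rotation projecting embeddings to 128 dims,
    #               arranged into 64 subspaces for the PQ stage
    #   IVF4898   - inverted-file coarse quantizer with 4898 clusters
    #   PQ64x4fsr - product quantizer with 64 sub-quantizers of 4 bits each;
    #               "fs" = SIMD fast-scan layout, "r" = encode residuals
    # For a small corpus, an exact, uncompressed index would be a simpler choice:
    dataset.add_faiss_index(column="embeddings", string_factory="Flat")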
tmptdsnrh_8.flac ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:04f8d015f3597c6858e74d40c72fb70fe1caab7bf6b015ccca0eda5f53a49c71
+size 389592
vacio.flac ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:04a9780650bebeb4e93ccdc6f9298e7338a53cc30b7c3a281cd1a01ff2bbb5c8
+size 103880