BioMedIA / general_utils.py
avacaondata's picture
quitados algunos comments
4228c91
raw history blame
No virus
3.98 kB
import torch
import nltk
from scipy.io.wavfile import write
import librosa
import hashlib
from typing import List
def embed_questions(
    question_model, question_tokenizer, questions, max_length=128, device="cpu"
):
    """Encode a batch of questions and return their pooled embeddings.

    The questions are tokenized to exactly ``max_length`` tokens (padded or
    truncated) as PyTorch tensors. The ids and attention mask are moved to
    ``device`` and fed positionally to ``question_model`` without gradient
    tracking.

    Returns:
        ``question_model(...).pooler_output`` moved to CPU and converted to a
        NumPy array.
    """
    encoded = question_tokenizer(
        questions,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=max_length,
    )
    input_ids = encoded["input_ids"].to(device)
    attention_mask = encoded["attention_mask"].to(device)
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        model_output = question_model(input_ids, attention_mask)
    pooled = model_output.pooler_output
    return pooled.cpu().numpy()
def embed_passages(ctx_model, ctx_tokenizer, passages, max_length=128, device="cpu"):
    """Encode the ``"text"`` entries of ``passages`` with a context encoder.

    Each text is padded/truncated to ``max_length`` tokens. The resulting ids
    and attention mask are moved to ``device`` and passed positionally to
    ``ctx_model`` without gradient tracking.

    Returns:
        ``{"embeddings": pooled}``, where ``pooled`` is the model's
        ``pooler_output`` as a CPU NumPy array.
    """
    batch = ctx_tokenizer(
        passages["text"],
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=max_length,
    )
    ids_on_device = batch["input_ids"].to(device)
    mask_on_device = batch["attention_mask"].to(device)
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        encoded = ctx_model(ids_on_device, mask_on_device)
    embeddings = encoded.pooler_output.cpu().numpy()
    return {"embeddings": embeddings}
class Document:
    """Minimal document record (``meta``, ``content``, ``id``).

    ``_alter_docs_for_haystack`` builds these from raw passage strings so they
    can be handed to ``dpr_model.embed_documents``.
    """

    def __init__(self, meta=None, content: str = "", id_: str = ""):
        # Arbitrary metadata. A literal ``{}`` default would be one dict shared
        # by every instance, so mutating one document's meta would leak into
        # all others; allocate a fresh dict per instance instead.
        self.meta = {} if meta is None else meta
        # Raw text of the document.
        self.content = content
        # Identifier stored under the attribute name ``id``.
        self.id = id_
def _alter_docs_for_haystack(passages):
    """Wrap each passage string in a ``Document``.

    The document id is the passage's position in ``passages``, as a string.
    """
    documents = []
    for position, text in enumerate(passages):
        documents.append(Document(content=text, id_=str(position)))
    return documents
def embed_passages_haystack(
    dpr_model,
    passages,
):
    """Embed ``passages["text"]`` with ``dpr_model.embed_documents``.

    The texts are first wrapped into ``Document`` objects. Whatever
    ``embed_documents`` returns is placed under the ``"embeddings"`` key.
    """
    wrapped_docs = _alter_docs_for_haystack(passages["text"])
    return {"embeddings": dpr_model.embed_documents(wrapped_docs)}
def correct_casing(input_sentence):
    """Capitalize the first character of every sentence in the transcription.

    Sentences are split with ``nltk.sent_tokenize`` and re-joined with single
    spaces. Only the first character of each sentence is changed.
    """
    fixed_sentences = []
    for sentence in nltk.sent_tokenize(input_sentence):
        head, tail = sentence[0], sentence[1:]
        fixed_sentences.append(head.capitalize() + tail)
    return " ".join(fixed_sentences)
def clean_transcript(text):
    """Strip every upper-case ``[PAD]`` token from a decoded transcription."""
    pad_token = "[pad]".upper()
    return text.replace(pad_token, "")
def add_question_symbols(text):
    """Ensure ``text`` is wrapped in Spanish question marks.

    Prepends "¿" if it is not already the first character and appends "?" if
    it is not already the last one.

    An empty string becomes "¿?". The previous ``text[0]`` / ``text[-1]``
    indexing raised IndexError on empty input, which happens when the
    transcription yields no text.
    """
    if not text.startswith("¿"):
        text = "¿" + text
    if not text.endswith("?"):
        text += "?"
    return text
def remove_chars_to_tts(text):
    """Replace every comma in ``text`` with a space before it is sent to TTS."""
    comma_to_space = str.maketrans({",": " "})
    return text.translate(comma_to_space)
def transcript(input_file, audio_array, processor, model):
    """Transcribe an audio file (or in-memory audio) into a Spanish question.

    Args:
        input_file: path to an audio file readable by librosa. Ignored when
            ``audio_array`` is truthy.
        audio_array: optional ``(rate, samples)`` pair. When truthy it is
            written to ``temp.wav`` and that file is transcribed instead.
        processor: callable whose result exposes ``.input_values`` and which
            provides ``.decode`` for the predicted ids.
        model: callable on ``input_values`` whose result exposes ``.logits``.

    Returns:
        The chunk transcriptions (each sentence-cased and terminated with ". "),
        truncated to 3800 characters, stripped of the trailing separator and
        wrapped in "¿…?".
    """
    if audio_array:
        rate, sample = audio_array
        # NOTE(review): fixed filename in the working directory; concurrent
        # calls would overwrite each other's audio.
        write("temp.wav", rate, sample)
        input_file = "temp.wav"

    sample_rate = librosa.get_samplerate(input_file)
    # Stream the file in blocks instead of loading it all at once. Frame and
    # hop length are both one second of samples, so a block of 20 frames
    # covers about 20 seconds of audio.
    stream = librosa.stream(
        input_file,
        block_length=20,
        frame_length=sample_rate,
        hop_length=sample_rate,
    )

    pieces = []
    for speech in stream:
        # Defensive: fold a 2-channel block into one channel (librosa.stream
        # downmixes to mono by default, so this is normally not taken).
        if len(speech.shape) > 1:
            speech = speech[:, 0] + speech[:, 1]
        # Ensure the processor always receives 16 kHz audio.
        if sample_rate != 16000:
            speech = librosa.resample(speech, orig_sr=sample_rate, target_sr=16000)
        input_values = processor(speech, return_tensors="pt").input_values
        # Inference only: avoid building an autograd graph for every chunk.
        with torch.no_grad():
            logits = model(input_values).logits
        predicted_ids = torch.argmax(logits, dim=-1)
        chunk_text = processor.decode(
            predicted_ids[0],
            clean_up_tokenization_spaces=True,
            skip_special_tokens=True,
        )
        chunk_text = clean_transcript(chunk_text)
        pieces.append(correct_casing(chunk_text.lower()) + ". ")

    whole_text = "".join(pieces)[:3800]
    # Each chunk ends with ". "; drop that trailing separator so the result
    # reads "…?" rather than "…. ?".
    whole_text = whole_text.rstrip(" .")
    return add_question_symbols(whole_text)
def parse_final_answer(answer_text: str, contexts: List):
    """Format an answer and its supporting contexts as HTML snippets.

    Returns:
        A pair ``(answer, docs)``:

        - ``answer`` is ``answer_text`` in bold inside a paragraph, followed
          by " \\n\\n\\n".
        - ``docs`` holds up to the first 5 contexts, one justified paragraph
          per line. Each paragraph (opening tag included) is cut to 250
          characters and closed with "[...]</p>".
    """
    answer = "<p><b>{}</b></p> \n\n\n".format(answer_text)
    opening_tag = """<p style="text-align: justify;">"""
    snippets = []
    for context in contexts[:5]:
        # The 250-char cut is applied after prefixing the tag, as before.
        truncated = (opening_tag + context)[:250]
        snippets.append(truncated + "[...]</p>")
    docs = "\n".join(snippets)
    return answer, docs