# BioMedIA / general_utils.py
import librosa
import nltk
import torch
from scipy.io.wavfile import write
from typing import List

def embed_questions(
question_model, question_tokenizer, questions, max_length=128, device="cpu"
):
query = question_tokenizer(
questions,
max_length=max_length,
padding="max_length",
truncation=True,
return_tensors="pt",
)
with torch.no_grad():
q_reps = question_model(
query["input_ids"].to(device), query["attention_mask"].to(device)
).pooler_output
return q_reps.cpu().numpy()
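
# Example usage (a sketch, not confirmed by this repo; assumes a DPR
# question-encoder checkpoint from `transformers`, with an illustrative name):
#
#     from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer
#     name = "facebook/dpr-question_encoder-single-nq-base"
#     q_tok = DPRQuestionEncoderTokenizer.from_pretrained(name)
#     q_model = DPRQuestionEncoder.from_pretrained(name)
#     vecs = embed_questions(q_model, q_tok, ["¿Qué es la diabetes?"])
#     # vecs has shape (num_questions, hidden_size)
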
def embed_passages(ctx_model, ctx_tokenizer, passages, max_length=128, device="cpu"):
p = ctx_tokenizer(
passages["text"],
max_length=max_length,
padding="max_length",
truncation=True,
return_tensors="pt",
)
with torch.no_grad():
a_reps = ctx_model(
p["input_ids"].to(device), p["attention_mask"].to(device)
).pooler_output
return {"embeddings": a_reps.cpu().numpy()}
class Document:
    def __init__(self, meta=None, content: str = "", id_: str = ""):
        # Avoid a shared mutable default for `meta`.
        self.meta = meta if meta is not None else {}
        self.content = content
        self.id = id_

def _alter_docs_for_haystack(passages):
return [Document(content=passage, id_=str(i)) for i, passage in enumerate(passages)]

def embed_passages_haystack(dpr_model, passages):
passages = _alter_docs_for_haystack(passages["text"])
embeddings = dpr_model.embed_documents(passages)
return {"embeddings": embeddings}
def correct_casing(input_sentence):
    """Correct the casing of the generated transcribed text."""
    sentences = nltk.sent_tokenize(input_sentence)
    # Capitalize the first character of each sentence; skip empty strings.
    return " ".join([s[0].upper() + s[1:] for s in sentences if s])
def clean_transcript(text):
    # Strip leftover "[PAD]" tokens from the CTC-decoded transcription.
    return text.replace("[PAD]", "")

def add_question_symbols(text):
    """Wrap the text in Spanish question marks if they are missing."""
    if not text:
        return text
    if text[0] != "¿":
        text = "¿" + text
    if text[-1] != "?":
        text = text + "?"
    return text
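
# For example:
#     add_question_symbols("qué es la diabetes")  # -> "¿qué es la diabetes?"
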
def remove_chars_to_tts(text):
text = text.replace(",", " ")
return text

def transcript(input_file, audio_array, processor, model):
    # `audio_array` is an optional (sample_rate, samples) tuple (e.g. from a
    # microphone widget); persist it so librosa can stream it from disk.
    if audio_array is not None:
        rate, sample = audio_array
        write("temp.wav", rate, sample)
        input_file = "temp.wav"
    full_transcript = ""
    # Read the file's native sample rate; resampling to 16 kHz happens below.
    sample_rate = librosa.get_samplerate(input_file)
    # Stream over ~20-second chunks rather than loading the full file: each
    # frame is one second of audio (frame_length == sample_rate) and each
    # block holds 20 frames.
    stream = librosa.stream(
        input_file,
        block_length=20,
        frame_length=sample_rate,
        hop_length=sample_rate,
    )
    for speech in stream:
        if len(speech.shape) > 1:
            # Downmix stereo to mono by averaging the two channels
            # (librosa.stream yields mono by default, so this is a safeguard).
            speech = (speech[:, 0] + speech[:, 1]) / 2
        if sample_rate != 16000:
            speech = librosa.resample(speech, orig_sr=sample_rate, target_sr=16000)
        input_values = processor(
            speech, sampling_rate=16000, return_tensors="pt"
        ).input_values
        with torch.no_grad():
            logits = model(input_values).logits
        predicted_ids = torch.argmax(logits, dim=-1)
        transcription = processor.decode(
            predicted_ids[0],
            clean_up_tokenization_spaces=True,
            skip_special_tokens=True,
        )
        transcription = clean_transcript(transcription)
        full_transcript += correct_casing(transcription.lower()) + ". "
    # Truncate the transcript (presumably to fit a downstream input limit)
    # and make sure it reads as a question.
    whole_text = full_transcript[:3800]
    whole_text = add_question_symbols(whole_text)
    return whole_text
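
# Example usage (a sketch; assumes a Wav2Vec2-style CTC `processor`/`model`
# pair from `transformers` — the checkpoint name is illustrative, not taken
# from this repo):
#
#     from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
#     processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
#     model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
#     question = transcript("question.wav", None, processor, model)
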
def parse_final_answer(answer_text: str, contexts: List):
    """Parse the final answer and its supporting contexts into HTML fragments."""
    answer = f"<p><b>{answer_text}</b></p> \n\n\n"
    # Show at most 5 supporting contexts, each truncated to 250 characters.
    docs = "\n".join(
        [
            """<p style="text-align: justify;">""" + context[:250] + "[...]</p>"
            for context in contexts[:5]
        ]
    )
    return answer, docs
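
# For example:
#     answer_html, docs_html = parse_final_answer(
#         "La diabetes es una enfermedad crónica...",
#         ["contexto uno", "contexto dos"],
#     )
#     # answer_html == "<p><b>La diabetes es una enfermedad crónica...</b></p> \n\n\n"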