|
import torch |
|
import nltk |
|
from scipy.io.wavfile import write |
|
import librosa |
|
import hashlib |
|
from typing import List |
|
|
|
|
|
def embed_questions(
    question_model, question_tokenizer, questions, max_length=128, device="cpu"
):
    """Encode questions with the question encoder and return its pooled output.

    Args:
        question_model: Callable taking ``(input_ids, attention_mask)`` and
            returning an object with a ``pooler_output`` tensor.
        question_tokenizer: Callable tokenizer invoked with the questions and
            the padding/truncation options below; must return a mapping with
            ``"input_ids"`` and ``"attention_mask"`` tensors.
        questions: Question text(s) handed to the tokenizer unchanged.
        max_length: Length every sequence is padded/truncated to.
        device: Device the token tensors are moved to before the forward pass.

    Returns:
        The model's ``pooler_output`` moved to CPU and converted to NumPy.
    """
    encoded = question_tokenizer(
        questions,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    input_ids = encoded["input_ids"].to(device)
    attention_mask = encoded["attention_mask"].to(device)

    # Pure inference: no autograd bookkeeping is needed.
    with torch.no_grad():
        model_output = question_model(input_ids, attention_mask)

    return model_output.pooler_output.cpu().numpy()
|
|
|
|
|
def embed_passages(ctx_model, ctx_tokenizer, passages, max_length=128, device="cpu"):
    """Encode a batch of passages with the context encoder.

    Args:
        ctx_model: Callable taking ``(input_ids, attention_mask)`` and
            returning an object with a ``pooler_output`` tensor.
        ctx_tokenizer: Callable tokenizer; must return a mapping with
            ``"input_ids"`` and ``"attention_mask"`` tensors.
        passages: Mapping whose ``"text"`` entry holds the passage texts.
        max_length: Length every sequence is padded/truncated to.
        device: Device the token tensors are moved to before the forward pass.

    Returns:
        ``{"embeddings": array}`` where ``array`` is the model's
        ``pooler_output`` on CPU as a NumPy array.
    """
    encoded = ctx_tokenizer(
        passages["text"],
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    input_ids = encoded["input_ids"].to(device)
    attention_mask = encoded["attention_mask"].to(device)

    # Pure inference: no autograd bookkeeping is needed.
    with torch.no_grad():
        model_output = ctx_model(input_ids, attention_mask)

    embeddings = model_output.pooler_output.cpu().numpy()
    return {"embeddings": embeddings}
|
|
|
|
|
class Document:
    """Lightweight document record exposing ``meta``, ``content`` and ``id``.

    NOTE(review): presumably a stand-in for a haystack document object, going
    by the helper that builds these -- confirm against the retriever in use.

    Args:
        meta: Optional metadata mapping. When omitted, each instance gets its
            own empty dict.
        content: Text of the document.
        id_: Identifier, stored on the instance as ``id``.
    """

    def __init__(self, meta=None, content: str = "", id_: str = ""):
        # A literal `{}` default would be created once and shared by every
        # Document built without `meta`, so mutations would leak across
        # instances. Allocate a fresh dict per instance instead.
        self.meta = {} if meta is None else meta
        self.content = content
        self.id = id_
|
|
|
|
|
def _alter_docs_for_haystack(passages):
    """Wrap each passage in a ``Document`` whose id is its position as a string."""
    documents = []
    for index, passage in enumerate(passages):
        documents.append(Document(content=passage, id_=str(index)))
    return documents
|
|
|
|
|
def embed_passages_haystack(
    dpr_model,
    passages,
):
    """Embed a batch of passages through ``dpr_model.embed_documents``.

    Args:
        dpr_model: Object exposing ``embed_documents(documents)``.
        passages: Mapping whose ``"text"`` entry holds the passage texts.

    Returns:
        ``{"embeddings": ...}`` holding whatever ``embed_documents`` returned.
    """
    # The model is fed Document objects rather than raw strings.
    documents = _alter_docs_for_haystack(passages["text"])
    return {"embeddings": dpr_model.embed_documents(documents)}
|
|
|
|
|
def correct_casing(input_sentence):
    """Capitalize the first character of every sentence in transcribed text.

    The text is split into sentences with ``nltk.sent_tokenize``; the
    sentences are re-joined with single spaces.
    """
    cased_sentences = []
    for sentence in nltk.sent_tokenize(input_sentence):
        # Only the leading character is touched; the rest is kept verbatim.
        cased_sentences.append(sentence[0].capitalize() + sentence[1:])
    return " ".join(cased_sentences)
|
|
|
|
|
def clean_transcript(text):
    """Strip every upper-cased pad marker from the decoded text."""
    pad_marker = "[pad]".upper()
    # Splitting on the marker and gluing the pieces back removes all occurrences.
    return "".join(text.split(pad_marker))
|
|
|
|
|
def add_question_symbols(text):
    """Wrap ``text`` in Spanish question marks unless already present.

    Prepends ``"¿"`` when the text does not start with it and appends ``"?"``
    when it does not end with it.

    Args:
        text: Question text; may be empty.

    Returns:
        The text delimited by ``¿`` and ``?``. An empty input yields ``"¿?"``.
    """
    # startswith/endswith are safe on an empty string, whereas indexing
    # text[0] / text[-1] raises IndexError.
    if not text.startswith("¿"):
        text = "¿" + text
    if not text.endswith("?"):
        text = text + "?"
    return text
|
|
|
|
|
def remove_chars_to_tts(text):
    """Swap every comma for a space before the text goes to text-to-speech."""
    pieces = text.split(",")
    return " ".join(pieces)
|
|
|
|
|
def transcript(input_file, audio_array, processor, model):
    """Transcribe audio into a question string.

    The audio is streamed in blocks; each block is run through ``model``,
    greedily decoded with ``processor``, cleaned, re-cased and suffixed with
    ``". "``. The concatenation is cut to 3800 characters and wrapped in
    question marks.

    Args:
        input_file: Path of the audio file to read. Ignored when
            ``audio_array`` is truthy.
        audio_array: Optional ``(rate, sample)`` pair. When given, it is
            written to ``temp.wav`` in the working directory and that file is
            transcribed instead of ``input_file``.
        processor: Callable producing an object with ``input_values`` from raw
            audio, and exposing ``decode`` for predicted ids.
        model: Callable returning an object with ``logits`` for the input.

    Returns:
        The transcription, truncated to 3800 characters and delimited by
        ``¿`` and ``?``.
    """
    if audio_array:
        # In-memory audio is dumped to disk so it can be streamed like a file.
        rate, sample = audio_array
        write("temp.wav", rate, sample)
        input_file = "temp.wav"

    sample_rate = librosa.get_samplerate(input_file)

    # frame_length == hop_length == sample_rate gives one-second,
    # non-overlapping frames; 20 frames per block means ~20 s per block.
    stream = librosa.stream(
        input_file,
        block_length=20,
        frame_length=sample_rate,
        hop_length=sample_rate,
    )

    # Collected per-block text; joined once at the end instead of growing a
    # string with += (which also avoided shadowing this function's name).
    pieces = []
    for speech in stream:
        # NOTE(review): librosa.stream defaults to mono=True, so blocks are
        # presumably always 1-D and this branch is likely unused -- confirm
        # the axis layout before relying on it for multi-channel input.
        if len(speech.shape) > 1:
            speech = speech[:, 0] + speech[:, 1]
        # NOTE(review): the model presumably expects 16 kHz audio -- confirm.
        if sample_rate != 16000:
            speech = librosa.resample(speech, orig_sr=sample_rate, target_sr=16000)

        input_values = processor(speech, return_tensors="pt").input_values
        # Inference only: skip autograd tracking, as the embed_* helpers do.
        with torch.no_grad():
            logits = model(input_values).logits

        predicted_ids = torch.argmax(logits, dim=-1)
        transcription = processor.decode(
            predicted_ids[0],
            clean_up_tokenization_spaces=True,
            skip_special_tokens=True,
        )
        transcription = clean_transcript(transcription)
        pieces.append(correct_casing(transcription.lower()) + ". ")

    whole_text = "".join(pieces)[:3800]
    return add_question_symbols(whole_text)
|
|
|
|
|
def parse_final_answer(answer_text: str, contexts: List):
    """Format the answer and its supporting contexts as HTML snippets.

    Args:
        answer_text: Answer to render in a bold paragraph.
        contexts: Context strings; only the first five are used.

    Returns:
        A ``(answer, docs)`` tuple. ``answer`` is the bold paragraph followed
        by a space and three newlines. ``docs`` holds one justified paragraph
        per context, newline-separated; each paragraph (opening tag included)
        is cut to 250 characters before ``"[...]</p>"`` is appended.
    """
    opening_tag = """<p style="text-align: justify;">"""

    snippets = []
    for context in contexts[:5]:
        # The 250-character cut is applied to tag + text together.
        truncated = (opening_tag + context)[:250]
        snippets.append(truncated + "[...]</p>")

    answer = f"<p><b>{answer_text}</b></p> \n\n\n"
    return answer, "\n".join(snippets)
|
|