import torch
import nltk
from scipy.io.wavfile import write
import librosa
from typing import List


def embed_questions(
    question_model, question_tokenizer, questions, max_length=128, device="cpu"
):
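    """Tokenize and encode a batch of questions with the question encoder,
    returning the pooled embeddings as a NumPy array."""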
    query = question_tokenizer(
        questions,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        q_reps = question_model(
            query["input_ids"].to(device), query["attention_mask"].to(device)
        ).pooler_output
    return q_reps.cpu().numpy()


def embed_passages(ctx_model, ctx_tokenizer, passages, max_length=128, device="cpu"):
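    """Encode passages["text"] with the context encoder and return the pooled
    embeddings as a NumPy array under the "embeddings" key."""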
    p = ctx_tokenizer(
        passages["text"],
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        a_reps = ctx_model(
            p["input_ids"].to(device), p["attention_mask"].to(device)
        ).pooler_output
    return {"embeddings": a_reps.cpu().numpy()}


class Document:
    """Minimal stand-in for a Haystack Document (meta, content, id)."""

    def __init__(self, meta=None, content: str = "", id_: str = ""):
        # Avoid a shared mutable default for `meta`
        self.meta = meta if meta is not None else {}
        self.content = content
        self.id = id_


def _alter_docs_for_haystack(passages):
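    """Wrap raw passage strings in Document objects with string ids."""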
    return [Document(content=passage, id_=str(i)) for i, passage in enumerate(passages)]


def embed_passages_haystack(
    dpr_model,
    passages,
):
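    """Embed passages["text"] via the retriever's embed_documents method."""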
    passages = _alter_docs_for_haystack(passages["text"])
    embeddings = dpr_model.embed_documents(passages)
    return {"embeddings": embeddings}


def correct_casing(input_sentence):
    """This function is for correcting the casing of the generated transcribed text"""
    sentences = nltk.sent_tokenize(input_sentence)
    return " ".join([s.replace(s[0], s[0].capitalize(), 1) for s in sentences])


def clean_transcript(text):
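    """Strip leftover [PAD] tokens from the decoded transcription."""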
    text = text.replace("[pad]".upper(), "")
    return text


def add_question_symbols(text):
    """Wrap the text in Spanish question marks (¿ ... ?) if they are missing."""
    if not text:
        return text
    if text[0] != "¿":
        text = "¿" + text
    if text[-1] != "?":
        text = text + "?"
    return text


def remove_chars_to_tts(text):
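    """Replace commas with spaces before the text is sent to TTS."""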
    text = text.replace(",", " ")
    return text


def transcript(input_file, audio_array, processor, model):
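    """Transcribe audio with the given ASR processor/model.

    `audio_array` may be a (rate, samples) tuple; if given, it is written to a
    temporary wav file and used instead of `input_file`. The audio is streamed
    in blocks and resampled to 16 kHz before inference.
    """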
    if audio_array:
        rate, sample = audio_array
        write("temp.wav", rate, sample)
        input_file = "temp.wav"
    transcript_text = ""
    # Read the source sample rate; blocks are resampled to 16 kHz inside the loop
    sample_rate = librosa.get_samplerate(input_file)

    # Stream the audio in ~20 second blocks rather than loading the full file
    stream = librosa.stream(
        input_file,
        block_length=20,  # frames per block
        frame_length=sample_rate,  # one second of audio per frame
        hop_length=sample_rate,  # non-overlapping frames
    )

    for speech in stream:
        # Downmix stereo blocks to a single channel
        if len(speech.shape) > 1:
            speech = speech[:, 0] + speech[:, 1]
        if sample_rate != 16000:
            speech = librosa.resample(speech, orig_sr=sample_rate, target_sr=16000)
        input_values = processor(speech, return_tensors="pt").input_values
        with torch.no_grad():
            logits = model(input_values).logits

        predicted_ids = torch.argmax(logits, dim=-1)
        transcription = processor.decode(
            predicted_ids[0],
            clean_up_tokenization_spaces=True,
            skip_special_tokens=True,
        )
        transcription = clean_transcript(transcription)
        transcript_text += correct_casing(transcription.lower()) + ". "
    # Cap the length and make sure the result reads as a question
    whole_text = transcript_text[:3800]
    whole_text = add_question_symbols(whole_text)
    return whole_text


def parse_final_answer(answer_text: str, contexts: List):
    """Format the answer and its top supporting contexts as HTML snippets."""
    answer = f"<p><b>{answer_text}</b></p> \n\n\n"
    # Show up to 5 contexts, each truncated to roughly 250 characters
    docs = "\n".join(
        [
            ("""<p style="text-align: justify;">""" + context)[:250] + "[...]</p>"
            for context in contexts[:5]
        ]
    )
    return answer, docs