File size: 5,307 Bytes
d03e31e
72aa5a1
 
 
 
 
 
 
d03e31e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72aa5a1
d03e31e
 
 
 
 
 
 
ff5ad78
d03e31e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72aa5a1
d03e31e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72aa5a1
d03e31e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21324a6
 
d03e31e
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
from transformers import pipeline, AutoTokenizer, T5ForConditionalGeneration
import torch
from pyannote.audio import Pipeline
import os
from faster_whisper import WhisperModel
from pydub import AudioSegment
from random import randint
import gradio as gr


# Russian abstractive summarization: ruT5-base fine-tuned on the Gazeta dataset.
model_name = "IlyaGusev/rut5_base_sum_gazeta"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model_sum = T5ForConditionalGeneration.from_pretrained(model_name)


def summ_mT5_G(text):
    """Summarize Russian *text* with the ruT5-Gazeta model.

    Returns the decoded summary as a string.
    """
    # Truncate to the model's supported window.  The original padded to
    # max_length but fed only input_ids to generate(), so the model
    # attended to pad tokens; a single-sequence batch needs no padding,
    # and we pass the attention mask through for correctness.
    inputs = tokenizer(
        [text],
        max_length=600,
        add_special_tokens=True,
        truncation=True,
        return_tensors="pt",
    )
    with torch.no_grad():  # inference only — no autograd graph needed
        output_ids = model_sum.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            no_repeat_ngram_size=4,
        )[0]
    return tokenizer.decode(output_ids, skip_special_tokens=True)


# Punctuation / capitalization restoration: Silero text-enhancement model.
# NOTE: the original unpacked the 4th tuple element into a name `punct`
# that was immediately shadowed by the function below — use a throwaway
# name so the shadowing cannot bite a future reader.
model_punc, example_texts, languages, _punct_marks, apply_te = torch.hub.load(
    repo_or_dir='snakers4/silero-models', model='silero_te')


def punct(text):
    """Restore punctuation and case in Russian *text* (lower-cased first)."""
    return apply_te(text.lower(), lan='ru')


# Speaker-diarization pipeline.  The Hugging Face auth token is read from
# the "s1" environment variable; str() turns a missing variable into the
# literal string "None" instead of failing here —
# NOTE(review): consider failing fast when s1 is unset.
pipeline_a = Pipeline.from_pretrained(
    "pyannote/speaker-diarization-3.1",
    use_auth_token=str(os.getenv("s1")))


# speaker diarization / verification
def speackers_list(audio_f : str):
    """Diarize *audio_f* and return a list of [start_s, end_s, speaker] turns.

    Runs the pyannote pipeline, collects every speaker turn, then merges
    an adjacent pair of turns by the same speaker into a single span.
    """
    # # send pipeline to GPU (when available)
    # import torch
    # pipeline.to(torch.device("cuda"))

    # apply pretrained pipeline
    diarization = pipeline_a(audio_f)
    speackers_list = []
    # collect every diarized turn as [start_seconds, end_seconds, label]
    for turn, _, speaker in diarization.itertracks(yield_label=True):
        # print(f"start={turn.start:.1f}s stop={turn.end:.1f}s speaker_{speaker}")
        speackers_list.append([turn.start, turn.end, speaker])

    # speackers_list.sort(key = lambda x: x[0])
    # print(speackers_list)
    # Merge neighbouring turns by the same speaker: extend the current
    # turn's end to swallow the next one and drop it.  Note the merge
    # mutates speackers_list[i][1], which feeds the guard on the next
    # iteration — the loop is deliberately order-dependent.
    # NOTE(review): the (i==0 or previous-end < current-end) guard seems
    # meant to stop merging across an overlapping earlier turn — confirm
    # the intent with the author before touching this.
    i = 0
    while i<len(speackers_list)-1:
        if speackers_list[i][-1] == speackers_list[i+1][-1] and (i==0 or (speackers_list[i-1][1]<speackers_list[i][1])):
              speackers_list[i][1] = speackers_list[i+1][1]
              speackers_list.pop(i+1)
        else:
          i+=1
    return speackers_list


# speackers = speackers_list(name_of_file)

# Speech-to-text model (faster-whisper).  Device / compute-type overrides
# were left commented out by the author; defaults are used.
model_size = "large-v3"
# Run on GPU with FP16
model_tts = WhisperModel(model_size)  #, lan = "ru") #,  device="cpu", compute_type="int8") #, device="cuda", compute_type="float16")


def speach_to_text(file_name):
    """Transcribe the audio file *file_name* and return its full text.

    Segment texts from faster-whisper are concatenated in order; the
    per-segment timing and detected-language info are discarded.
    """
    segments, _info = model_tts.transcribe(file_name, beam_size=5)
    # join at C speed instead of growing a string with repeated +=
    return "".join(segment.text for segment in segments)


class Segment_text:
    """One diarized speech segment: punctuated text, timing, speaker, summary."""

    text : str           # segment text after punctuation restoration
    start : float        # segment start time, seconds
    end : float          # segment end time, seconds
    speacker : str       # diarized speaker label
    summarization : str  # abstract of the segment text (or the raw text)

    def __init__(self, text : str, start : float, end : float, speacker : str ):
        """Store the segment, restoring punctuation and summarizing long text."""
        self.start = start
        self.end = end
        self.speacker = speacker
        self.text = punct(text)
        self.value = len(text)
        # only segments longer than 200 characters are worth summarizing
        self.summarization = text if self.value <= 200 else summ_mT5_G(text)

    def get_text(self):
        """Punctuated segment text."""
        return self.text

    def get_time(self):
        """(start, end) pair, in seconds."""
        return (self.start, self.end)

    def get_summarization(self):
        """Summary, or the raw text for short segments."""
        return self.summarization

    def get_speacker(self):
        """Speaker label."""
        return self.speacker


def init_segments(speackers, name_of_file):
    """Cut *name_of_file* into per-turn clips and transcribe each one.

    speackers    -- list of [start_s, end_s, speaker] turns (see speackers_list)
    name_of_file -- path of the full recording, readable by pydub

    Returns a list of Segment_text objects, one per turn.
    """
    list_of_segments = []
    audio = AudioSegment.from_file(name_of_file)
    for ind, (start, end, speaker) in enumerate(speackers):
        # pydub slices in milliseconds; turn boundaries are in seconds
        temp_seg = audio[start * 1000 : end * 1000]
        name_of_seg = f"seg{ind}.mp3"
        temp_seg.export(name_of_seg, format="mp3")
        try:
            temp_text = speach_to_text(name_of_seg)
        finally:
            # the clip exists only to be transcribed — the original leaked
            # one segN.mp3 per turn into the working directory
            os.remove(name_of_seg)
        segment_text = Segment_text(text=temp_text, start=start, end=end, speacker=speaker)
        list_of_segments.append(segment_text)
    return list_of_segments


def get_text_to_out(list_of_segments : list):
    """Render segments as two strings: full transcript and summary transcript.

    Returns (transcript, summary), each one "speaker: text" line per segment.
    """
    transcript = "".join(
        f"{seg.get_speacker()} : {seg.get_text()}\n" for seg in list_of_segments
    )
    summary = "".join(
        seg.get_speacker() + ":  " + seg.get_summarization() + "\n"
        for seg in list_of_segments
    )
    return transcript, summary


def do_smth(file):
    """Gradio handler: diarize, transcribe and summarize an uploaded recording.

    file -- path to the uploaded audio (gr.Audio with type="filepath")

    Returns (full transcript, summarized transcript) as two strings.
    """
    # from_file sniffs the container format, so non-WAV uploads also work
    # (the original from_wav rejected anything that was not WAV)
    audio = AudioSegment.from_file(file)
    # random name keeps concurrent requests from clobbering each other
    name_of_file = "f" + str(randint(1, 10**8))
    audio.export(name_of_file, format="mp3")
    try:
        speackers = speackers_list(name_of_file)
        list_of_segments = init_segments(speackers, name_of_file)
        out_text, out_sum = get_text_to_out(list_of_segments)
    finally:
        # remove the re-encoded temp copy (the original leaked one per call)
        os.remove(name_of_file)
    return out_text, out_sum


# Gradio UI: one audio input (passed to do_smth as a filepath) and two
# text outputs — the full transcript and its summarized version.
demo = gr.Interface(
    do_smth,
    gr.Audio(type="filepath"),
    [
        gr.Textbox(value="", label="Исходный текст"),
        gr.Textbox(value="", label="Сокращенный текст")
    ]
)
demo.launch()