File size: 3,759 Bytes
360c1a4
 
 
 
 
 
 
ab76106
 
 
 
360c1a4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c6358d1
360c1a4
 
 
 
 
 
 
 
 
 
 
 
 
 
c6358d1
360c1a4
 
 
 
 
 
 
 
 
471175f
360c1a4
 
 
 
 
ab76106
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import transformers 
from transformers import pipeline

import whisper

import datetime

import os
import gradio as gr
from pytube import YouTube

# Migrate any legacy transformers cache layout before loading models.
# NOTE(review): presumably a no-op once the cache is migrated — confirm it is
# still needed on every start-up.
transformers.utils.move_cache()

# Model identifiers, kept in one place so the loaders below cannot drift apart.
WHISPER_MODEL_NAME = "base"
SUMMARIZATION_MODEL_EN = "facebook/bart-large-cnn"
SUMMARIZATION_MODEL_VI = "VietAI/vit5-large-vietnews-summarization"

# ====================================
# Load speech recognition model (Whisper; its transcribe() output also
# carries the detected language used by asr_transcript).
speech_recognition_model = whisper.load_model(WHISPER_MODEL_NAME)

# ====================================
# Load text summarization model English
tokenizer_En = transformers.AutoTokenizer.from_pretrained(SUMMARIZATION_MODEL_EN)
text_summarization_model_En = transformers.AutoModelForSeq2SeqLM.from_pretrained(SUMMARIZATION_MODEL_EN)

# ====================================
# Load text summarization model Vietnamese
tokenizer_Vi = transformers.AutoTokenizer.from_pretrained(SUMMARIZATION_MODEL_VI)
text_summarization_model_Vi = transformers.AutoModelForSeq2SeqLM.from_pretrained(SUMMARIZATION_MODEL_VI)

# Whisper language codes that have a dedicated summarizer, mapped to the
# display names that text_summarize() dispatches on.
_LANGUAGE_NAMES = {"en": "English", "vi": "Vietnamese"}


def _format_timestamp(seconds):
    """Render *seconds* (rounded to whole seconds) as ``str(timedelta)``, e.g. ``0:01:05``."""
    return str(datetime.timedelta(seconds=round(seconds)))


def asr_transcript(input_file):
    """Transcribe a media file with the module-level Whisper model.

    Args:
        input_file: Path to an audio/video file, passed to ``whisper.load_audio``.

    Returns:
        A ``(text, lang, detail)`` tuple:

        * ``text``: the full transcript (``output['text']``).
        * ``lang``: ``"English"`` or ``"Vietnamese"``. Any other detected
          language falls back to ``"English"``.
        * ``detail``: one newline-terminated line per Whisper segment,
          formatted ``"<start>-<end> <segment text>"``.
    """
    audio = whisper.load_audio(input_file)
    output = speech_recognition_model.transcribe(audio)

    # Unknown languages deliberately fall back to "English", as before, so the
    # caller still gets a language it can route to a summarizer.
    lang = _LANGUAGE_NAMES.get(output["language"], "English")

    # Build the per-segment listing with one join instead of repeated +=.
    detail = "".join(
        f"{_format_timestamp(segment['start'])}-{_format_timestamp(segment['end'])} {segment['text']}\n"
        for segment in output["segments"]
    )
    return output["text"], lang, detail

def _generate_summary(tokenizer, model, text_input, max_length=256, min_length=20):
    """Summarize *text_input* with a seq2seq *model* and its matching *tokenizer*.

    The input is tokenized with ``truncation=True``. Text longer than the
    tokenizer's maximum length is cut, so only its beginning is summarized.

    Args:
        tokenizer: Hugging Face tokenizer paired with *model*.
        model: Hugging Face seq2seq model exposing ``generate``.
        text_input: Text to summarize.
        max_length: Maximum length (in tokens) of the generated summary.
        min_length: Minimum length (in tokens) of the generated summary.

    Returns:
        The decoded summary text. If ``generate`` returns several sequences,
        they are concatenated.
    """
    encoding = tokenizer(text_input, truncation=True, return_tensors="pt")
    outputs = model.generate(
        input_ids=encoding["input_ids"],
        attention_mask=encoding["attention_mask"],
        max_length=max_length,
        min_length=min_length,
        early_stopping=True,
    )
    return "".join(
        tokenizer.batch_decode(
            outputs, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )
    )


def text_summarize_en(text_input):
    """Summarize English text with the BART-large-CNN model loaded at module level."""
    return _generate_summary(tokenizer_En, text_summarization_model_En, text_input)


def text_summarize_vi(text_input):
    """Summarize Vietnamese text with the ViT5 vietnews model loaded at module level."""
    # NOTE(review): the ViT5 vietnews model card feeds "vietnews: " + text + " </s>".
    # The raw text is passed here unchanged — confirm whether the task prefix
    # should be added.
    return _generate_summary(tokenizer_Vi, text_summarization_model_Vi, text_input)

def text_summarize(text_input, lang):
    """Summarize *text_input* with the model matching *lang*.

    Texts of 100 characters or fewer are returned unchanged. Otherwise
    ``"English"`` and ``"Vietnamese"`` are routed to their summarizers, and
    any other language yields an empty string.
    """
    # Too short to be worth condensing: hand it back verbatim.
    if len(text_input) <= 100:
        return text_input

    if lang == 'English':
        summarizer = text_summarize_en
    elif lang == 'Vietnamese':
        summarizer = text_summarize_vi
    else:
        # No summarization model for this language.
        return ""
    return summarizer(text_input)

def load_video_url(url):
    """Download a YouTube video to ``<cwd>/audio/temp.mp4``.

    Args:
        url: YouTube video URL.

    Returns:
        Path of the downloaded file, ``os.path.join(os.getcwd(), "audio", "temp.mp4")``.

    Raises:
        gr.Error: ``"Connection Error"`` if pytube cannot open *url*, or
            ``"Download video error"`` if no stream is available or the
            download fails.
    """
    output_dir = os.path.join(os.getcwd(), "audio")
    file_url = os.path.join(output_dir, "temp.mp4")

    try:
        yt = YouTube(url)
    except Exception as err:  # not bare: let KeyboardInterrupt/SystemExit through
        print("Connection Error")
        raise gr.Error("Connection Error") from err

    try:
        # get_highest_resolution() itself selects among progressive streams
        # (audio + video in one file). The former filter(progressive=False)
        # did not restrict that choice, so it was dropped.
        stream = yt.streams.get_highest_resolution()
        if stream is None:
            raise ValueError(f"no downloadable stream found for {url}")
        # The same file name is reused for every request, so never skip the
        # download just because an older temp.mp4 is already there.
        # NOTE(review): concurrent requests share this path — consider a
        # unique file name if the app serves several users at once.
        stream.download(output_path=output_dir, filename="temp.mp4", skip_existing=False)
    except Exception as err:
        print("Download video error")
        raise gr.Error("Download video error") from err

    return file_url