minnehwg committed on
Commit
a4a5aff
1 Parent(s): 8908a57

Update util.py

Files changed (1)
  1. util.py +142 -2
util.py CHANGED
@@ -1,2 +1,142 @@
- def update(name):
-     return f"Welcome to Gradio, {name}!"
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+ from youtube_transcript_api import YouTubeTranscriptApi
+ from deepmultilingualpunctuation import PunctuationModel
+ from googletrans import Translator
+ import time
+ import torch
+ import re
+
+
+ def load_model(cp):
+     # Load the ViT5 tokenizer and a fine-tuned seq2seq checkpoint.
+     tokenizer = AutoTokenizer.from_pretrained("VietAI/vit5-base")
+     model = AutoModelForSeq2SeqLM.from_pretrained(cp)
+     return tokenizer, model
+
+
+ def summarize(text, model, tokenizer, num_beams=4, device='cpu'):
+     # Encode the input (truncated to 1024 tokens) and generate a beam-search summary.
+     model.to(device)
+     inputs = tokenizer.encode(text, return_tensors="pt", max_length=1024, truncation=True).to(device)
+
+     with torch.no_grad():
+         summary_ids = model.generate(inputs, max_length=256, num_beams=num_beams)
+     summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
+
+     return summary
+
+
+ def processed(text):
+     # Normalize whitespace and lowercase the text before summarization.
+     processed_text = text.replace('\n', ' ')
+     processed_text = processed_text.lower()
+     return processed_text
+
+
+ def get_subtitles(video_url):
+     # Fetch the English transcript for a YouTube URL (expects a "v=" query parameter).
+     try:
+         video_id = video_url.split("v=")[1]
+         transcript = YouTubeTranscriptApi.get_transcript(video_id, languages=['en'])
+         subs = " ".join(entry['text'] for entry in transcript)
+
+         return transcript, subs
+
+     except Exception as e:
+         return [], f"An error occurred: {e}"
+
+
+ def restore_punctuation(text):
+     # Re-insert punctuation into the raw transcript with a pretrained model.
+     model = PunctuationModel()
+     result = model.restore_punctuation(text)
+     return result
+
+
+ def translate_long(text, language='vi'):
+     # Translate long text chunk by chunk, staying under the ~5000-character
+     # request limit of the Google Translate API.
+     translator = Translator()
+     limit = 4700
+     chunks = []
+     current_chunk = ''
+
+     sentences = re.split(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s', text)
+
+     for sentence in sentences:
+         if len(current_chunk) + len(sentence) <= limit:
+             current_chunk += sentence.strip() + ' '
+         else:
+             chunks.append(current_chunk.strip())
+             current_chunk = sentence.strip() + ' '
+
+     if current_chunk:
+         chunks.append(current_chunk.strip())
+
+     translated_text = ''
+
+     for chunk in chunks:
+         try:
+             time.sleep(1)  # throttle requests to avoid rate limiting
+             translation = translator.translate(chunk, dest=language)
+             translated_text += translation.text + ' '
+         except Exception:
+             # Fall back to the untranslated chunk if the API call fails.
+             translated_text += chunk + ' '
+
+     return translated_text.strip()
+
+ def split_into_chunks(text, max_words=800, overlap_sentences=2):
+     # Split text into chunks of at most max_words words, repeating the last
+     # overlap_sentences sentences of each chunk at the start of the next so
+     # context carries across chunk boundaries.
+     sentences = re.split(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s', text)
+
+     chunks = []
+     current_chunk = []
+     current_word_count = 0
+
+     for sentence in sentences:
+         word_count = len(sentence.split())
+         if current_word_count + word_count <= max_words:
+             current_chunk.append(sentence)
+             current_word_count += word_count
+         else:
+             if current_chunk:
+                 chunks.append(' '.join(current_chunk))
+             current_chunk = current_chunk[-overlap_sentences:] + [sentence]
+             current_word_count = sum(len(sent.split()) for sent in current_chunk)
+
+     if current_chunk:
+         chunks.append(' '.join(current_chunk))
+
+     return chunks
+
+
+ def post_processing(text):
+     # Re-capitalize the first letter of every sentence (the text was lowercased).
+     sentences = re.split(r'(?<=[.!?])\s*', text)
+     for i in range(len(sentences)):
+         if sentences[i]:
+             sentences[i] = sentences[i][0].upper() + sentences[i][1:]
+     text = " ".join(sentences)
+     return text
+
+
+ def display(text):
+     # Deduplicate sentences (order-preserving) and format them as bullet points.
+     sentences = re.split(r'(?<=[.!?])\s*', text)
+     unique_sentences = list(dict.fromkeys(s for s in sentences if s))
+     formatted_sentences = [f"• {sentence}" for sentence in unique_sentences]
+     return formatted_sentences
+
+
+
+ def pipeline(url, model, tokenizer):
+     # Full flow: fetch subtitles, restore punctuation, translate to Vietnamese,
+     # summarize chunk by chunk, then clean up and format the output.
+     _, sub = get_subtitles(url)
+     sub = restore_punctuation(sub)
+     vie_sub = translate_long(sub)
+     vie_sub = processed(vie_sub)
+     chunks = split_into_chunks(vie_sub, 700, 2)
+     sum_para = []
+     for chunk in chunks:
+         sum_para.append(summarize(chunk, model, tokenizer, num_beams=3))
+     suma = ' '.join(sum_para)
+     del sub, vie_sub, sum_para, chunks  # free the large intermediates
+     suma = post_processing(suma)
+     result = display(suma)  # avoid naming this "re": it would shadow the regex module
+     return result
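
A minimal usage sketch of how these utilities compose, assuming the file is importable as util and that a fine-tuned ViT5 summarization checkpoint is available (the checkpoint path and video URL below are placeholders, not published artifacts):

from util import load_model, pipeline

# Placeholder checkpoint path; substitute your own fine-tuned ViT5 model.
tokenizer, model = load_model("path/to/vit5-finetuned-checkpoint")

# Placeholder YouTube URL; any watch URL with a "v=" parameter works.
bullets = pipeline("https://www.youtube.com/watch?v=VIDEO_ID", model, tokenizer)
for line in bullets:
    print(line)

Note that pipeline returns a list of bullet-formatted sentences rather than a single string, so callers that need plain text should join the list themselves.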