minnehwg committed on
Commit
b222b37
1 Parent(s): 91c119e

Update util.py

Files changed (1): util.py +139 -0
util.py CHANGED
@@ -0,0 +1,139 @@
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+ from youtube_transcript_api import YouTubeTranscriptApi
+ from deepmultilingualpunctuation import PunctuationModel
+ from googletrans import Translator
+ import time
+ import torch
+ import re
+
+
+ def load_model(cp):
+     """Load the ViT5 tokenizer and a fine-tuned checkpoint `cp`."""
+     tokenizer = AutoTokenizer.from_pretrained("VietAI/vit5-base")
+     model = AutoModelForSeq2SeqLM.from_pretrained(cp)
+     return tokenizer, model
+
+
+ def summarize(text, model, tokenizer, num_beams=4, device='cpu'):
+     """Summarize `text` with beam search on the given seq2seq model."""
+     model.to(device)
+     inputs = tokenizer.encode(text, return_tensors="pt", max_length=1024, truncation=True, padding=True).to(device)
+
+     with torch.no_grad():
+         summary_ids = model.generate(inputs, max_length=256, num_beams=num_beams)
+         summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
+
+     return summary
+
+
+ def processed(text):
+     """Normalize text: collapse newlines into spaces and lowercase."""
+     processed_text = text.replace('\n', ' ')
+     processed_text = processed_text.lower()
+     return processed_text
+
+
+ def get_subtitles(video_url):
+     """Fetch the English transcript for a YouTube URL.
+     Returns (transcript, joined_text), or ([], error_message) on failure."""
+     try:
+         video_id = video_url.split("v=")[1]
+         transcript = YouTubeTranscriptApi.get_transcript(video_id, languages=['en'])
+         subs = " ".join(entry['text'] for entry in transcript)
+
+         return transcript, subs
+
+     except Exception as e:
+         return [], f"An error occurred: {e}"
+
+
+ def restore_punctuation(text):
+     """Re-insert punctuation into the raw transcript text."""
+     model = PunctuationModel()
+     result = model.restore_punctuation(text)
+     return result
+
+
+ def translate_long(text, language='vi'):
+     """Translate long text by splitting it into chunks that stay under
+     the ~5000-character googletrans request limit."""
+     translator = Translator()
+     limit = 4700
+     chunks = []
+     current_chunk = ''
+
+     # Split on sentence boundaries while skipping abbreviations like "e.g." or "Dr.".
+     sentences = re.split(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s', text)
+
+     for sentence in sentences:
+         if len(current_chunk) + len(sentence) <= limit:
+             current_chunk += sentence.strip() + ' '
+         else:
+             chunks.append(current_chunk.strip())
+             current_chunk = sentence.strip() + ' '
+
+     if current_chunk:
+         chunks.append(current_chunk.strip())
+
+     translated_text = ''
+
+     for chunk in chunks:
+         try:
+             time.sleep(1)  # throttle requests to avoid rate limiting
+             translation = translator.translate(chunk, dest=language)
+             translated_text += translation.text + ' '
+         except Exception:
+             translated_text += chunk + ' '  # fall back to the untranslated chunk
+
+     return translated_text.strip()
+
+
+ def split_into_chunks(text, max_words=800, overlap_sentences=2):
+     """Split text into chunks of at most `max_words` words, carrying the
+     last `overlap_sentences` sentences into the next chunk for context."""
+     sentences = re.split(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s', text)
+
+     chunks = []
+     current_chunk = []
+     current_word_count = 0
+
+     for sentence in sentences:
+         word_count = len(sentence.split())
+         if current_word_count + word_count <= max_words:
+             current_chunk.append(sentence)
+             current_word_count += word_count
+         else:
+             chunks.append(' '.join(current_chunk))
+             # Start the next chunk with the overlapping sentences plus the new one.
+             current_chunk = current_chunk[-overlap_sentences:] + [sentence]
+             current_word_count = sum(len(sent.split()) for sent in current_chunk)
+     if current_chunk:
+         chunks.append(' '.join(current_chunk))
+
+     return chunks
+
+
+ def post_processing(text):
+     """Capitalize the first letter of each sentence."""
+     sentences = re.split(r'(?<=[.!?])\s*', text)
+     for i in range(len(sentences)):
+         if sentences[i]:
+             sentences[i] = sentences[i][0].upper() + sentences[i][1:]
+     text = " ".join(sentences)
+     return text
+
+
+ def display(text):
+     """Drop duplicate sentences (preserving order) and bullet each one."""
+     sentences = re.split(r'(?<=[.!?])\s*', text)
+     # Drop the trailing empty element left by the split.
+     unique_sentences = list(dict.fromkeys(sentences[:-1]))
+     formatted_sentences = [f"• {sentence}" for sentence in unique_sentences]
+     return formatted_sentences
+
+
+ def pipeline(url, model_aug, tokenizer):
+     """End-to-end flow: transcript -> punctuation -> Vietnamese translation
+     -> chunking -> per-chunk summarization -> bulleted output."""
+     trans, sub = get_subtitles(url)
+     sub = restore_punctuation(sub)
+     vie_sub = translate_long(sub)
+     vie_sub = processed(vie_sub)
+     chunks = split_into_chunks(vie_sub, 700, 3)
+     sum_para = []
+     for chunk in chunks:
+         sum_para.append(summarize(chunk, model_aug, tokenizer, num_beams=4))
+     # Join with spaces so sentence boundaries survive post-processing.
+     summary = ' '.join(sum_para)
+     del sub, vie_sub, sum_para, chunks
+     summary = post_processing(summary)
+     result = display(summary)
+     return result
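
For reference, a minimal usage sketch (not part of this commit), assuming the module is importable as `util`; the checkpoint path and video URL below are placeholders for a fine-tuned ViT5 checkpoint and a real YouTube link:

    from util import load_model, pipeline

    # Placeholder path; substitute the actual fine-tuned ViT5 checkpoint.
    tokenizer, model = load_model("path/to/vit5-checkpoint")
    # Placeholder URL; any video with English subtitles should work.
    bullets = pipeline("https://www.youtube.com/watch?v=VIDEO_ID", model, tokenizer)
    for line in bullets:
        print(line)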