# core.py — speech recognition (Whisper) and text summarization (BART / ViT5) backend
import transformers
from transformers import pipeline
import whisper
import datetime
import os
import gradio as gr
from pytube import YouTube
# Migrate any legacy transformers cache to the current layout.
transformers.utils.move_cache()
# ====================================
# Load speech recognition model (Whisper "base")
# Alternative: speech_recognition_pipeline = pipeline("automatic-speech-recognition", model="facebook/wav2vec2-base-960h")
speech_recognition_model = whisper.load_model("base")
# ====================================
# Load English text summarization model (facebook/bart-large-cnn)
# Alternative: text_summarization_pipeline_En = pipeline("summarization", model="facebook/bart-large-cnn")
tokenizer_En = transformers.AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
text_summarization_model_En = transformers.AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn")
# ====================================
# Load Vietnamese text summarization model (VietAI/vit5-large-vietnews-summarization)
tokenizer_Vi = transformers.AutoTokenizer.from_pretrained("VietAI/vit5-large-vietnews-summarization")
text_summarization_model_Vi = transformers.AutoModelForSeq2SeqLM.from_pretrained("VietAI/vit5-large-vietnews-summarization")
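# All models above load on CPU by default. A hypothetical tweak, assuming torch
# is installed and a CUDA device is present, would move the summarizers to GPU:
#
#     import torch
#     if torch.cuda.is_available():
#         text_summarization_model_En.to("cuda")
#         text_summarization_model_Vi.to("cuda")
#
# The generate() inputs below would then also need .to("cuda").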
def asr_transcript(input_file):
    """Transcribe an audio file with Whisper; return (text, language, timestamped detail)."""
    audio = whisper.load_audio(input_file)
    output = speech_recognition_model.transcribe(audio)
    text = output["text"]
    # Map Whisper's detected language code to a display name; default to English.
    lang = "English"
    if output["language"] == "en":
        lang = "English"
    elif output["language"] == "vi":
        lang = "Vietnamese"
    # Build a per-segment transcript with start-end timestamps.
    detail = ""
    for segment in output["segments"]:
        start = str(datetime.timedelta(seconds=round(segment["start"])))
        end = str(datetime.timedelta(seconds=round(segment["end"])))
        detail += f"{start}-{end} {segment['text']}\n"
    return text, lang, detail
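# Example (hypothetical path and output, for illustration only):
#     text, lang, detail = asr_transcript("audio/temp.mp4")
#     # text   -> the full transcript as one string
#     # lang   -> "English" or "Vietnamese"
#     # detail -> one "H:MM:SS-H:MM:SS <segment text>" line per segment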
def text_summarize_en(text_input):
    """Summarize English text with facebook/bart-large-cnn."""
    # Truncate the input to the model's maximum length.
    encoding = tokenizer_En(text_input, truncation=True, return_tensors="pt")
    input_ids, attention_masks = encoding["input_ids"], encoding["attention_mask"]
    outputs = text_summarization_model_En.generate(
        input_ids=input_ids,
        attention_mask=attention_masks,
        max_length=256,
        min_length=20,
        early_stopping=True,
    )
    # Decode the generated token ids back into text.
    text = ""
    for output in outputs:
        text += tokenizer_En.decode(output, skip_special_tokens=True, clean_up_tokenization_spaces=True)
    return text
def text_summarize_vi(text_input):
    """Summarize Vietnamese text with VietAI/vit5-large-vietnews-summarization."""
    encoding = tokenizer_Vi(text_input, truncation=True, return_tensors="pt")
    input_ids, attention_masks = encoding["input_ids"], encoding["attention_mask"]
    outputs = text_summarization_model_Vi.generate(
        input_ids=input_ids,
        attention_mask=attention_masks,
        max_length=256,
        min_length=20,
        early_stopping=True,
    )
    text = ""
    for output in outputs:
        text += tokenizer_Vi.decode(output, skip_special_tokens=True, clean_up_tokenization_spaces=True)
    return text
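# The two summarizers above differ only in which tokenizer/model they use.
# A shared helper (hypothetical refactor, behavior unchanged) could look like:
#
#     def _summarize(text_input, tokenizer, model):
#         encoding = tokenizer(text_input, truncation=True, return_tensors="pt")
#         outputs = model.generate(
#             input_ids=encoding["input_ids"],
#             attention_mask=encoding["attention_mask"],
#             max_length=256,
#             min_length=20,
#             early_stopping=True,
#         )
#         return "".join(
#             tokenizer.decode(o, skip_special_tokens=True, clean_up_tokenization_spaces=True)
#             for o in outputs
#         )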
def text_summarize(text_input, lang):
    """Route text to the summarizer for its language."""
    # Inputs of 100 characters or fewer are returned unchanged (too short to summarize).
    if len(text_input) <= 100:
        return text_input
    if lang == "English":
        return text_summarize_en(text_input)
    elif lang == "Vietnamese":
        return text_summarize_vi(text_input)
    else:
        return ""
def load_video_url(url):
    """Download a YouTube video's audio track to audio/temp.mp4 and return the local path."""
    current_dir = os.getcwd()
    try:
        yt = YouTube(url)
    except Exception:
        print("Connection Error")
        raise gr.Error("Connection Error")
    try:
        # Pick the highest-bitrate audio-only stream (MP4 container).
        audio_stream = yt.streams.get_audio_only()
        file_url = os.path.join(current_dir, "audio", "temp.mp4")
        audio_stream.download(output_path=os.path.join(current_dir, "audio"), filename="temp.mp4")
    except Exception:
        print("Download video error")
        raise gr.Error("Download video error")
    return file_url
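# Minimal end-to-end smoke test; runs only when this file is executed directly.
# The URL is a placeholder, and a real run needs network access plus enough
# memory for the models loaded above.
if __name__ == "__main__":
    sample_url = "https://www.youtube.com/watch?v=<VIDEO_ID>"  # placeholder, not a real video
    audio_path = load_video_url(sample_url)
    transcript, lang, detail = asr_transcript(audio_path)
    print("Detected language:", lang)
    print("Summary:", text_summarize(transcript, lang))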