# video-to-subs / app.py
# Author: abnerh. Last commit: "Update app.py" (bc815d6), 4.19 kB.
import os, sys, re
import shutil
import subprocess
import soundfile
from process_audio import segment_audio
from write_srt import write_to_file
from clean_text import clean_english, clean_german, clean_spanish
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
import torch
import gradio as gr
def _load_wav2vec2(checkpoint):
    """Load the Wav2Vec2 processor and CTC model for a Hugging Face checkpoint."""
    processor = Wav2Vec2Processor.from_pretrained(checkpoint)
    model = Wav2Vec2ForCTC.from_pretrained(checkpoint)
    return processor, model


# One Wav2Vec2 checkpoint per supported language. get_subs looks these objects
# up by name ("<language>_tokenizer" / "<language>_asr_model"), so the
# module-level names must stay exactly as they are.
english_model = "facebook/wav2vec2-large-960h-lv60-self"
english_tokenizer, english_asr_model = _load_wav2vec2(english_model)

german_model = "jonatasgrosman/wav2vec2-large-xlsr-53-german"
german_tokenizer, german_asr_model = _load_wav2vec2(german_model)

spanish_model = "jonatasgrosman/wav2vec2-large-xlsr-53-spanish"
spanish_tokenizer, spanish_asr_model = _load_wav2vec2(spanish_model)
# Fetch the TextBlob corpora (NLTK data; originally noted as needed for the
# German corpus). sys.executable is the interpreter running this app, so
# textblob is importable and this works even where no bare `python` command
# exists on PATH. Best effort: a non-zero exit status is deliberately ignored.
subprocess.run([sys.executable, "-m", "textblob.download_corpora"])

# Running SRT entry number; transcribe_audio increments it once per subtitle
# it writes.
line_count = 0
def sort_alphanumeric(data):
    """Return a new list of the strings in ``data`` in natural sort order.

    Each string is split into alternating non-digit / digit runs; digit runs
    compare as integers and the other runs compare case-insensitively, so
    ``"seg_2"`` sorts before ``"seg_10"``.
    """
    def natural_key(name):
        chunks = re.split('([0-9]+)', name)
        return [int(chunk) if chunk.isdigit() else chunk.lower() for chunk in chunks]

    return sorted(data, key=natural_key)
def transcribe_audio(tokenizer, asr_model, audio_file, file_handle, language=None):
    """Transcribe one VAD audio segment with Wav2Vec2 and write it as an SRT entry.

    Args:
        tokenizer: Wav2Vec2Processor that featurizes the waveform and decodes
            the CTC predictions.
        asr_model: Wav2Vec2ForCTC model matching ``tokenizer``.
        audio_file: Path to a segment produced by ``segment_audio``. The last
            ``_``-separated part of its file stem is split on ``-`` and passed
            to ``write_to_file`` as the subtitle limits.
        file_handle: Open, writable SRT file.
        language: 'english', 'german' or 'spanish'; selects the text cleaner
            (any other value leaves the text uncleaned). Defaults to the
            module-level ``lang`` published by ``get_subs``, for backward
            compatibility.

    Segments whose decoded text is at most one character long are skipped
    and do not consume an SRT index.
    """
    global line_count
    if language is None:
        language = lang  # legacy path: get_subs sets the language globally
    speech, _rate = soundfile.read(audio_file)
    # NOTE(review): 16 kHz matches the ffmpeg extraction in get_subs; assumes
    # segment_audio keeps that rate. TODO confirm.
    input_values = tokenizer(speech, sampling_rate=16000, return_tensors="pt", padding="longest").input_values
    # Inference only: disabling autograd avoids building a graph, saving memory.
    with torch.no_grad():
        logits = asr_model(input_values).logits
    prediction = torch.argmax(logits, dim=-1)
    inferred_text = tokenizer.batch_decode(prediction)[0].lower()
    if len(inferred_text) <= 1:
        return
    cleaners = {
        "english": clean_english,
        "german": clean_german,
        "spanish": clean_spanish,
    }
    cleaner = cleaners.get(language)
    if cleaner is not None:
        inferred_text = cleaner(inferred_text)
    print(inferred_text)
    stem = os.path.splitext(os.path.basename(audio_file))[0]
    limits = stem.split("_")[-1].split("-")
    line_count += 1
    write_to_file(file_handle, inferred_text, line_count, limits)
def get_subs(input_file, language):
    """Generate an SRT subtitle file for a video.

    Extracts a mono 16 kHz WAV track with ffmpeg, splits it on silence with
    ``segment_audio``, transcribes each segment (in natural filename order)
    with the Wav2Vec2 model for ``language`` and writes one SRT entry per
    non-empty segment.

    Args:
        input_file: Path to the uploaded video file.
        language: 'English', 'German' or 'Spanish' (case-insensitive).

    Returns:
        Name of the written ``<video name>.srt`` file, relative to the current
        working directory. The file is overwritten on every call.

    Raises:
        ValueError: if ``language`` is missing or unsupported.
        subprocess.CalledProcessError: if ffmpeg fails to extract the audio.
    """
    global lang, line_count
    key = (language or "").lower()
    if key not in ("english", "german", "spanish"):
        raise ValueError(f"Unsupported language: {language!r}")
    lang = key  # read by transcribe_audio to pick the text cleaner
    tokenizer = globals()[key + "_tokenizer"]
    asr_model = globals()[key + "_asr_model"]
    # Each SRT file numbers its entries from 1; without this reset the count
    # would continue from earlier requests.
    line_count = 0

    audio_directory = os.path.join(os.getcwd(), "audio")
    if os.path.isdir(audio_directory):
        shutil.rmtree(audio_directory)
    os.mkdir(audio_directory)

    audio_file = os.path.join(audio_directory, "temp.wav")
    audio_file_name = os.path.basename(audio_file)
    video_name = os.path.splitext(os.path.basename(input_file))[0]
    srt_file_name = video_name + ".srt"

    try:
        command = ["ffmpeg", "-i", input_file, "-ac", "1", "-ar", "16000", "-vn", "-f", "wav", audio_file]
        subprocess.run(command, check=True)
        # Split audio file based on VAD silent segments.
        segment_audio(audio_file)
        os.remove(audio_file)
        # "w" rather than append: re-running the same video must not duplicate
        # entries. UTF-8 keeps German/Spanish characters intact on any locale.
        with open(srt_file_name, "w", encoding="utf-8") as file_handle:
            for file_name in sort_alphanumeric(os.listdir(audio_directory)):
                if file_name == audio_file_name:
                    continue
                segment_path = os.path.join(audio_directory, file_name)
                transcribe_audio(tokenizer, asr_model, segment_path, file_handle)
    finally:
        # Remove the temporary segments even if extraction or transcription failed.
        shutil.rmtree(audio_directory, ignore_errors=True)
    return srt_file_name
# Web front end: upload a video, choose a language, download the generated SRT.
gradio_ui = gr.Interface(
    fn=get_subs,
    title="Video to Subtitle",
    description="Get subtitles (SRT file) for your videos. Inference speed is about 10s/per 1min of video BUT the speed of uploading your video depends on your internet connection.",
    inputs=[gr.inputs.Video(label="Upload Video File"),
            gr.inputs.Radio(label="Choose Language", choices=['English', 'German', 'Spanish'])],
    outputs=gr.outputs.File(label="Auto-Transcript"),
    # NOTE: per the Gradio 2.x docs, this serves requests through a queue,
    # intended for long-running inference that would otherwise time out.
    enable_queue=True,
)
gradio_ui.launch()