import os
import re
import io
import torch
import requests
import torchaudio
import numpy as np
import gradio as gr
from uroman import uroman
import concurrent.futures
from pydub import AudioSegment
from datasets import load_dataset
from IPython.display import Audio
from scipy.signal import butter, lfilter
from speechbrain.pretrained import EncoderClassifier
from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
# Variables
spk_model_name = "speechbrain/spkrec-xvect-voxceleb"
dataset_name = "truong-xuan-linh/vi-xvector-speechbrain"
cache_dir = "temp/"
default_model_name = "truong-xuan-linh/speecht5-vietnamese-voiceclone-lsvsc"
speaker_id = "speech_dataset_denoised"

# Active device
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load models and datasets
speaker_model = EncoderClassifier.from_hparams(
    source=spk_model_name,
    run_opts={"device": device},
    savedir=os.path.join("/tmp", spk_model_name),
)

dataset = load_dataset(
    dataset_name,
    download_mode="force_redownload",
    verification_mode="no_checks",
    cache_dir=cache_dir,
    revision="5ea5e4345258333cbc6d1dd2544f6c658e66a634",
)
dataset = dataset["train"].to_list()

dataset_dict = {}
for rc in dataset:
    dataset_dict[rc["speaker_id"]] = rc["embedding"]

vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")
# Model utility functions
def remove_special_characters(sentence):
    # Keep ASCII letters, whitespace, commas, periods, and Vietnamese characters
    # (U+00C0-U+1EF9); everything else is replaced with " ,"
    sentence_after_removal = re.sub(r'[^a-zA-Z\s,.\u00C0-\u1EF9]', ' ,', sentence)
    return sentence_after_removal
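# Illustrative example (the input below is hypothetical, not from the original source):
#   remove_special_characters("Xin chào! Bạn khỏe không?")
#   -> "Xin chào , Bạn khỏe không ,"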
def create_speaker_embedding(waveform):
    with torch.no_grad():
        speaker_embeddings = speaker_model.encode_batch(waveform)
        speaker_embeddings = torch.nn.functional.normalize(speaker_embeddings, dim=-1)
    return speaker_embeddings
def butter_bandpass(lowcut, highcut, fs, order=5):
    nyq = 0.5 * fs
    low = lowcut / nyq
    high = highcut / nyq
    b, a = butter(order, [low, high], btype='band')
    return b, a

def butter_bandpass_filter(data, lowcut, highcut, fs, order=5):
    b, a = butter_bandpass(lowcut, highcut, fs, order=order)
    y = lfilter(b, a, data)
    return y
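# Note: the band-pass helpers above are defined but never called later in this
# script. A minimal usage sketch on the synthesized 16 kHz audio could look like
# the line below; the 80-7500 Hz cutoffs are illustrative assumptions, not values
# taken from the original source.
#   filtered = butter_bandpass_filter(audio_data, lowcut=80, highcut=7500, fs=16000, order=5)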
def korean_splitter(string):
    pattern = re.compile('[가-힣]+')
    matches = pattern.findall(string)
    return matches

def uroman_normalization(string):
    korean_inputs = korean_splitter(string)
    for korean_input in korean_inputs:
        korean_roman = uroman(korean_input)
        string = string.replace(korean_input, korean_roman)
    return string
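# Illustrative example (assuming uroman romanizes the Hangul roughly as shown;
# only the Korean substrings are replaced, other characters pass through):
#   uroman_normalization("안녕 Việt Nam")  ->  "annyeong Việt Nam"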
# Model class
class Model:
    def __init__(self, model_name, speaker_url=""):
        self.model_name = model_name
        self.processor = SpeechT5Processor.from_pretrained(model_name)
        self.model = SpeechT5ForTextToSpeech.from_pretrained(model_name)
        self.model.eval()
        self.speaker_url = speaker_url
        if speaker_url:
            print(f"Downloading speaker audio from {speaker_url}")
            response = requests.get(speaker_url)
            audio_stream = io.BytesIO(response.content)
            audio_segment = AudioSegment.from_file(audio_stream, format="wav")
            audio_segment = audio_segment.set_channels(1)
            audio_segment = audio_segment.set_frame_rate(16000)
            audio_segment = audio_segment.set_sample_width(2)
            waveform, _ = torchaudio.load(audio_segment.export())
            self.speaker_embeddings = create_speaker_embedding(waveform)[0]
        else:
            self.speaker_embeddings = None
        if model_name in ("truong-xuan-linh/speecht5-vietnamese-commonvoice", "truong-xuan-linh/speecht5-irmvivoice"):
            self.speaker_embeddings = torch.zeros((1, 512))  # or load xvectors from a file
    def inference(self, text, speaker_id=None):
        if "voiceclone" in self.model_name:
            if not self.speaker_url:
                self.speaker_embeddings = torch.tensor(dataset_dict[speaker_id])
        with torch.no_grad():
            full_speech = []
            separators = r";|\.|!|\?|\n"
            text = uroman_normalization(text)
            text = remove_special_characters(text)
            text = text.replace(" ", "▁")
            split_texts = re.split(separators, text)
            for split_text in split_texts:
                if split_text != "▁":
                    split_text = split_text.lower() + "▁"
                    print(split_text)
                    inputs = self.processor.tokenizer(text=split_text, return_tensors="pt")
                    speech = self.model.generate_speech(inputs["input_ids"], threshold=0.5, speaker_embeddings=self.speaker_embeddings, vocoder=vocoder)
                    full_speech.append(speech.numpy())
            return np.concatenate(full_speech)
def moving_average(data, window_size):
    return np.convolve(data, np.ones(window_size) / window_size, mode='same')
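# Note: moving_average is not called elsewhere in this script. A minimal usage
# sketch for smoothing the synthesized waveform could look like the line below;
# the 5-sample window is an illustrative assumption.
#   smoothed = moving_average(audio_data, window_size=5)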
# Initialize model
model = Model(
    model_name=default_model_name,
    speaker_url=""
)
# Audio processing functions
def read_srt(file_path):
    subtitles = []
    with open(file_path, 'r', encoding='utf-8') as file:
        lines = file.readlines()
        for i in range(0, len(lines), 4):
            if i + 2 < len(lines):
                start_time, end_time = lines[i + 1].strip().split('-->')
                start_time = start_time.strip()
                end_time = end_time.strip()
                text = lines[i + 2].strip()
                subtitles.append((start_time, end_time, text))
    return subtitles
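# read_srt assumes each cue occupies exactly four lines (index, timing, a single
# text line, blank line), as in this illustrative snippet; multi-line cue text is
# not supported by this parser:
#   1
#   00:00:01,000 --> 00:00:03,500
#   Xin chào các bạn
#   (blank line)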
def is_valid_srt(file_path):
    try:
        read_srt(file_path)
        return True
    except Exception:
        return False
def time_to_seconds(time_str):
    h, m, s = time_str.split(':')
    seconds = int(h) * 3600 + int(m) * 60 + float(s.replace(',', '.'))
    return seconds
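# Worked example: "00:01:02,500" -> 0*3600 + 1*60 + 2.5 = 62.5
#   time_to_seconds("00:01:02,500")  ->  62.5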
def generate_audio_with_pause(srt_file_path):
    subtitles = read_srt(srt_file_path)
    audio_clips = []
    for i, (start_time, end_time, text) in enumerate(subtitles):
        audio_data = model.inference(text=text, speaker_id=speaker_id)
        audio_data = audio_data / np.max(np.abs(audio_data))
        audio_clips.append(audio_data)
        if i < len(subtitles) - 1:
            next_start_time = subtitles[i + 1][0]
            pause_duration = time_to_seconds(next_start_time) - time_to_seconds(end_time)
            if pause_duration > 0:
                pause_samples = int(pause_duration * 16000)
                audio_clips.append(np.zeros(pause_samples))
    final_audio = np.concatenate(audio_clips)
    return final_audio
def srt_to_audio_multi(srt_files):
    output_paths = []
    invalid_files = []

    def process_file(srt_file):
        if not is_valid_srt(srt_file.name):
            invalid_files.append(srt_file.name)
            return None
        audio_data = generate_audio_with_pause(srt_file.name)
        output_path = os.path.join(cache_dir, f'output_{os.path.basename(srt_file.name)}.wav')
        # Cast to float32 so torchaudio writes a standard float WAV
        torchaudio.save(output_path, torch.tensor(audio_data, dtype=torch.float32).unsqueeze(0), 16000)
        return output_path

    with concurrent.futures.ThreadPoolExecutor() as executor:
        futures = [executor.submit(process_file, srt_file) for srt_file in srt_files]
        for future in concurrent.futures.as_completed(futures):
            result = future.result()
            if result:
                output_paths.append(result)

    if invalid_files:
        raise ValueError(f"Invalid SRT files: {', '.join(invalid_files)}")
    return output_paths
# UI display
css = '''
#title{text-align: center}
#container{display: flex; justify-content: space-between; align-items: center;}
'''

with gr.Blocks(css=css) as demo:
    title = gr.HTML(
        """<h1>SRT to Audio Tool</h1>""",
        elem_id="title",
    )
    with gr.Row(elem_id="container"):
        inp = gr.File(label="Upload SRT files", file_count="multiple", type="filepath")
        out = gr.File(label="Generated Audio Files", file_count="multiple", type="filepath")
    btn = gr.Button("Generate")
    btn.click(fn=srt_to_audio_multi, inputs=inp, outputs=out)

if __name__ == "__main__":
    demo.launch()
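# A minimal sketch of running this Space locally (file names are assumptions,
# not taken from the original source):
#   pip install -r requirements.txt   # hypothetical dependency list for the Space
#   python app.py                     # hypothetical script name; open the printed Gradio URL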