import os
import re
import io
import torch
import librosa
import zipfile
import requests
import torchaudio
import numpy as np
import gradio as gr
from uroman import uroman
import concurrent.futures
from pydub import AudioSegment
from datasets import load_dataset
from IPython.display import Audio
from scipy.signal import butter, lfilter
from speechbrain.pretrained import EncoderClassifier
from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan

# Variables
spk_model_name = "speechbrain/spkrec-xvect-voxceleb"
dataset_name = "truong-xuan-linh/vi-xvector-speechbrain"
cache_dir = "temp/"
default_model_name = "truong-xuan-linh/speecht5-vietnamese-voiceclone-lsvsc"
speaker_id = "speech_dataset_denoised"

# Active device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load models and datasets
speaker_model = EncoderClassifier.from_hparams(
    source=spk_model_name,
    run_opts={"device": device},
    savedir=os.path.join("/tmp", spk_model_name),
)

dataset = load_dataset(
    dataset_name,
    download_mode="force_redownload",
    verification_mode="no_checks",
    cache_dir=cache_dir,
    revision="5ea5e4345258333cbc6d1dd2544f6c658e66a634",
)
dataset = dataset["train"].to_list()

# Map each speaker_id to its precomputed x-vector embedding
dataset_dict = {}
for rc in dataset:
    dataset_dict[rc["speaker_id"]] = rc["embedding"]

vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")

# Model utility functions
def remove_special_characters(sentence):
    # Keep letters (including the Vietnamese range), whitespace, commas, and periods;
    # replace every other character with " ,"
    sentence_after_removal = re.sub(r'[^a-zA-Z\s,.\u00C0-\u1EF9]', ' ,', sentence)
    return sentence_after_removal

def create_speaker_embedding(waveform):
    with torch.no_grad():
        speaker_embeddings = speaker_model.encode_batch(waveform)
        speaker_embeddings = torch.nn.functional.normalize(speaker_embeddings, dim=-1)
    return speaker_embeddings

def butter_bandpass(lowcut, highcut, fs, order=5):
    nyq = 0.5 * fs
    low = lowcut / nyq
    high = highcut / nyq
    b, a = butter(order, [low, high], btype='band')
    return b, a

def butter_bandpass_filter(data, lowcut, highcut, fs, order=5):
    b, a = butter_bandpass(lowcut, highcut, fs, order=order)
    y = lfilter(b, a, data)
    return y

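# Note: the bandpass helpers above (and moving_average further below) are not called
# anywhere in this script; they appear to be kept as optional post-processing utilities.
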
def korean_splitter(string):
    pattern = re.compile('[가-힣]+')
    matches = pattern.findall(string)
    return matches

def uroman_normalization(string):
    # Romanize any Korean substrings so the tokenizer only sees Latin/Vietnamese text
    korean_inputs = korean_splitter(string)
    for korean_input in korean_inputs:
        korean_roman = uroman(korean_input)
        string = string.replace(korean_input, korean_roman)
    return string

# Model class
class Model:
    def __init__(self, model_name, speaker_url=""):
        self.model_name = model_name
        self.processor = SpeechT5Processor.from_pretrained(model_name)
        self.model = SpeechT5ForTextToSpeech.from_pretrained(model_name)
        self.model.eval()
        self.speaker_url = speaker_url
        if speaker_url:
            print("Downloading speaker reference audio from speaker_url")
            response = requests.get(speaker_url)
            audio_stream = io.BytesIO(response.content)
            audio_segment = AudioSegment.from_file(audio_stream, format="wav")
            audio_segment = audio_segment.set_channels(1)
            audio_segment = audio_segment.set_frame_rate(16000)
            audio_segment = audio_segment.set_sample_width(2)
            wavform, _ = torchaudio.load(audio_segment.export())
            self.speaker_embeddings = create_speaker_embedding(wavform)[0]
        else:
            self.speaker_embeddings = None
        if model_name in ("truong-xuan-linh/speecht5-vietnamese-commonvoice", "truong-xuan-linh/speecht5-irmvivoice"):
            self.speaker_embeddings = torch.zeros((1, 512))  # or load xvectors from a file

    def inference(self, text, speaker_id=None):
        if "voiceclone" in self.model_name:
            if not self.speaker_url:
                # Use the precomputed x-vector for the requested speaker
                self.speaker_embeddings = torch.tensor(dataset_dict[speaker_id])
        with torch.no_grad():
            full_speech = []
            separators = r";|\.|!|\?|\n"
            text = uroman_normalization(text)
            text = remove_special_characters(text)
            text = text.replace(" ", "▁")
            # Synthesize each clause separately, then concatenate the waveforms
            split_texts = re.split(separators, text)
            for split_text in split_texts:
                if split_text != "▁":
                    split_text = split_text.lower() + "▁"
                    print(split_text)
                    inputs = self.processor.tokenizer(text=split_text, return_tensors="pt")
                    speech = self.model.generate_speech(inputs["input_ids"], threshold=0.5, speaker_embeddings=self.speaker_embeddings, vocoder=vocoder)
                    full_speech.append(speech.numpy())
            return np.concatenate(full_speech)

def moving_average(data, window_size):
    return np.convolve(data, np.ones(window_size) / window_size, mode='same')

# Initialize model
model = Model(
    model_name=default_model_name,
    speaker_url=""
)
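# Example usage (hypothetical input text; output is a 16 kHz mono float waveform):
#   wav = model.inference(text="xin chào các bạn", speaker_id="speech_dataset_denoised")
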
# Audio processing functions
def read_srt(file_path):
    subtitles = []
    with open(file_path, 'r', encoding='utf-8') as file:
        lines = file.readlines()
        for i in range(0, len(lines), 4):
            if i + 2 < len(lines):
                start_time, end_time = lines[i + 1].strip().split('-->')
                start_time = start_time.strip()
                end_time = end_time.strip()
                text = lines[i + 2].strip()
                # Delete trailing dots
                while text.endswith('.'):
                    text = text[:-1]
                subtitles.append((start_time, end_time, text))
    return subtitles
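# Note: read_srt assumes fixed four-line SRT blocks (index, timing, one text line, blank line), e.g.:
#   1
#   00:00:01,000 --> 00:00:04,000
#   Xin chào các bạn
#   (blank line)
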
def is_valid_srt(file_path):
    try:
        read_srt(file_path)
        return True
    except Exception:
        return False

def time_to_seconds(time_str):
    # Convert an SRT timestamp "HH:MM:SS,mmm" to seconds as a float
    h, m, s = time_str.split(':')
    seconds = int(h) * 3600 + int(m) * 60 + float(s.replace(',', '.'))
    return seconds
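# e.g. time_to_seconds("00:01:02,500") == 62.5
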
def closest_speedup_factor(factor, allowed_factors):
    # Snap to the nearest allowed factor and add 0.1 (currently unused; see the commented-out call below)
    return min(allowed_factors, key=lambda x: abs(x - factor)) + 0.1

def generate_audio_with_pause(srt_file_path, speaker_id, speed_of_non_edit_speech):
    subtitles = read_srt(srt_file_path)
    audio_clips = []
    # allowed_factors = [1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]
    for i, (start_time, end_time, text) in enumerate(subtitles):
        # print("=====================================")
        # print("Text number:", i)
        # print(f"Start: {start_time}, End: {end_time}, Text: {text}")
        # Generate initial audio
        audio_data = model.inference(text=text, speaker_id=speaker_id)
        audio_data = audio_data / np.max(np.abs(audio_data))
        # Calculate required duration
        desired_duration = time_to_seconds(end_time) - time_to_seconds(start_time)
        current_duration = len(audio_data) / 16000
        # print(f"Time to seconds: {time_to_seconds(start_time)}, {time_to_seconds(end_time)}")
        # print(f"Desired duration: {desired_duration}, Current duration: {current_duration}")
        # If the clip is too long, speed it up to fit the subtitle slot
        if current_duration > desired_duration:
            raw_speedup_factor = current_duration / desired_duration
            # speedup_factor = closest_speedup_factor(raw_speedup_factor, allowed_factors)
            speedup_factor = raw_speedup_factor
            audio_data = librosa.effects.time_stretch(
                y=audio_data,
                rate=speedup_factor,
                n_fft=1024,
                hop_length=256
            )
            audio_data = audio_data / np.max(np.abs(audio_data))
            audio_data = audio_data * 1.2
        # If the clip is too short, optionally speed it up, then pad with leading silence
        if current_duration < desired_duration:
            if speed_of_non_edit_speech != 1:
                audio_data = librosa.effects.time_stretch(
                    y=audio_data,
                    rate=speed_of_non_edit_speech,
                    n_fft=1024,
                    hop_length=256
                )
                audio_data = audio_data / np.max(np.abs(audio_data))
                audio_data = audio_data * 1.2
            current_duration = len(audio_data) / 16000
            padding = int((desired_duration - current_duration) * 16000)
            audio_data = np.concatenate([np.zeros(padding), audio_data])
        # print(f"Final audio duration: {len(audio_data) / 16000}")
        # print("=====================================")
        audio_clips.append(audio_data)
        # Add silence between this subtitle and the next one
        if i < len(subtitles) - 1:
            next_start_time = subtitles[i + 1][0]
            pause_duration = time_to_seconds(next_start_time) - time_to_seconds(end_time)
            if pause_duration > 0:
                pause_samples = int(pause_duration * 16000)
                audio_clips.append(np.zeros(pause_samples))
    final_audio = np.concatenate(audio_clips)
    return final_audio
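# Example (hypothetical file name; 16 kHz output, matching the model's sample rate):
#   audio = generate_audio_with_pause("subtitles.srt", "speech_dataset_denoised", 1.2)
#   torchaudio.save("subtitles.wav", torch.tensor(audio).float().unsqueeze(0), 16000)
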
def check_input_files(srt_files):
    if not srt_files:
        return None
    invalid_files = []
    for srt_file in srt_files:
        if not is_valid_srt(srt_file.name):
            invalid_files.append(srt_file.name)
    if invalid_files:
        # gr.Warning is called (not raised) so it is shown as a toast in the UI
        gr.Warning(f"Invalid SRT files: {', '.join(invalid_files)}")

def srt_to_audio_multi(srt_files, speaker_id, speed_of_non_edit_speech):
    output_paths = []
    invalid_files = []

    def process_file(srt_file):
        if not is_valid_srt(srt_file.name):
            invalid_files.append(srt_file.name)
            return None
        audio_data = generate_audio_with_pause(srt_file.name, speaker_id, speed_of_non_edit_speech)
        output_path = os.path.join(cache_dir, f'output_{os.path.basename(srt_file.name)}.wav')
        # Cast to float32 before saving
        torchaudio.save(output_path, torch.tensor(audio_data).float().unsqueeze(0), 16000)
        return output_path

    # Process the uploaded SRT files concurrently
    with concurrent.futures.ThreadPoolExecutor() as executor:
        futures = [executor.submit(process_file, srt_file) for srt_file in srt_files]
        for future in concurrent.futures.as_completed(futures):
            result = future.result()
            if result:
                output_paths.append(result)
    if invalid_files:
        gr.Warning(f"Invalid SRT files: {', '.join(invalid_files)}")
    return output_paths

def download_all(outputs):
    # If no outputs, warn and return None
    if not outputs:
        gr.Warning("No files available for download.")
        return None
    zip_path = os.path.join(cache_dir, "all_outputs.zip")
    with zipfile.ZipFile(zip_path, 'w') as zipf:
        for file_path in outputs:
            zipf.write(file_path, os.path.basename(file_path))
    return zip_path

# UI display
css = '''
#title{text-align: center}
#container{display: flex; justify-content: space-between; align-items: center;}
#setting-box{padding: 10px; border: 1px solid #ccc; border-radius: 5px;}
#setting-heading{margin-bottom: 10px; text-align: center;}
'''

with gr.Blocks(css=css) as demo:
    title = gr.HTML(
        """<h1>SRT to Audio Tool</h1>""",
        elem_id="title",
    )
    with gr.Column(elem_id="setting-box"):
        heading = gr.HTML("<h2>Settings</h2>", elem_id="setting-heading")
        with gr.Row():
            speaker_id = gr.Dropdown(
                label="Speaker ID",
                choices=list(dataset_dict.keys()),
                value=speaker_id
            )
            speed_of_non_edit_speech = gr.Slider(
                label="Speed of non-edit speech",
                minimum=1,
                maximum=2.0,
                step=0.1,
                value=1.2
            )
    with gr.Row(elem_id="container"):
        inp_srt = gr.File(
            label="Upload SRT files",
            file_count="multiple",
            type="filepath",
            file_types=[".srt"],
            height=600
        )
        out = gr.File(
            label="Generated Audio Files",
            file_count="multiple",
            type="filepath",
            height=600,
            interactive=False
        )
    btn = gr.Button("Generate")
    download_btn = gr.Button("Download All")
    download_out = gr.File(
        label="Download ZIP",
        interactive=False,
        height=100
    )
    inp_srt.change(check_input_files, inputs=inp_srt)
    btn.click(
        fn=srt_to_audio_multi,
        inputs=[inp_srt, speaker_id, speed_of_non_edit_speech],
        outputs=out
    )
    download_btn.click(fn=download_all, inputs=out, outputs=download_out)

if __name__ == "__main__":
    demo.launch()