"""Gradio demo for MegaTTS3 voice cloning, running inference on the CPU."""

# NOTE(review): `spaces` is kept first, as in the original ordering; Hugging
# Face ZeroGPU expects it to be imported before torch — confirm whether it is
# still needed for this CPU-only app.
import spaces
import torch
import os

# huggingface_hub reads this flag when it is first imported (directly or via
# gradio), so it must be set before those imports below.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

import gc
import io
import traceback

import gradio as gr
import librosa
import numpy as np
import soundfile as sf
from huggingface_hub import snapshot_download
from pydub import AudioSegment
from pydub.effects import normalize

from tts.infer_cli import MegaTTS3DiTInfer, convert_to_wav, cut_wav

# Basic CPU optimization flags. os.cpu_count() returns None when the count
# cannot be determined; fall back to one thread instead of crashing
# torch.set_num_threads(None) / exporting OMP_NUM_THREADS="None".
CPU_COUNT = os.cpu_count() or 1
os.environ["OMP_NUM_THREADS"] = str(CPU_COUNT)
torch.set_num_threads(CPU_COUNT)


def download_weights():
    """Download model weights from HuggingFace if not already present.

    The download is skipped whenever the ``checkpoints`` directory exists,
    regardless of its contents.

    Returns:
        The local weights directory (``"checkpoints"``).
    """
    repo_id = "mrfakename/MegaTTS3-VoiceCloning"
    weights_dir = "checkpoints"

    if not os.path.exists(weights_dir):
        print("Downloading model weights from HuggingFace...")
        # NOTE(review): `local_dir_use_symlinks` and `resume_download` are
        # deprecated in recent huggingface_hub releases; kept for older ones.
        snapshot_download(
            repo_id=repo_id,
            local_dir=weights_dir,
            local_dir_use_symlinks=False,
            resume_download=True,
        )
        print("Model weights downloaded successfully!")
    else:
        print("Model weights already exist.")

    return weights_dir


# Download weights and initialize the model at import time.
download_weights()
print("Initializing MegaTTS3 model...")
# Force the model to use the CPU.
infer_pipe = MegaTTS3DiTInfer(device="cpu")
print(f"Model loaded successfully on CPU with {CPU_COUNT} threads!")


def reset_model():
    """Rebuild the global inference pipeline on the CPU.

    Returns:
        True if the pipeline was re-created, False if construction raised
        (the error is printed and the previous pipeline is left in place).
    """
    global infer_pipe
    try:
        print("Reinitializing MegaTTS3 model...")
        infer_pipe = MegaTTS3DiTInfer(device="cpu")
        print("Model reinitialized successfully on CPU!")
        return True
    except Exception as e:
        print(f"Failed to reinitialize model: {e}")
        return False


def _remove_quietly(path):
    """Delete the temporary file *path* if it exists; never raises."""
    if not path:
        return
    try:
        os.remove(path)
    except FileNotFoundError:
        pass
    except OSError as e:
        print(f"Could not remove temporary file {path}: {e}")


def generate_speech(inp_audio, inp_text, infer_timestep, p_w, t_w, speed_factor):
    """Clone the voice in *inp_audio* and synthesize *inp_text*.

    Args:
        inp_audio: Filepath of the reference recording (``gr.Audio`` with
            ``type="filepath"``).
        inp_text: Text to synthesize.
        infer_timestep: Number of inference timesteps, forwarded as
            ``time_step`` to the pipeline.
        p_w: "Intelligibility weight", forwarded to the pipeline.
        t_w: "Similarity weight", forwarded to the pipeline.
        speed_factor: Playback-speed multiplier applied after synthesis
            (1.0 = unchanged).

    Returns:
        The WAV bytes produced by the pipeline (possibly time-stretched), or
        None on any failure; failures are reported to the user via
        ``gr.Warning``.
    """
    if not inp_audio or not inp_text or not inp_text.strip():
        gr.Warning("Please provide both reference audio and text to generate.")
        return None

    processed_audio_path = None
    try:
        print(f"Generating speech with: {inp_text}...")
        print(f"Running on CPU with {CPU_COUNT} threads...")

        # Robustly preprocess the reference audio.
        try:
            processed_audio_path = preprocess_audio_robust(inp_audio)
            # Use the existing cut_wav for final trimming; presumably it trims
            # the file in place to 28 s, since the same path is read below.
            cut_wav(processed_audio_path, max_len=28)
        except Exception as audio_error:
            gr.Warning(f"Audio preprocessing failed: {str(audio_error)}")
            return None

        with open(processed_audio_path, "rb") as file:
            file_content = file.read()

        try:
            with torch.no_grad():  # inference only, no autograd bookkeeping
                resource_context = infer_pipe.preprocess(file_content)
                wav_bytes = infer_pipe.forward(
                    resource_context,
                    inp_text,
                    time_step=infer_timestep,
                    p_w=p_w,
                    t_w=t_w,
                )

            if speed_factor != 1.0:
                wav_bytes = adjust_speed(wav_bytes, speed_factor)

            return wav_bytes

        except RuntimeError as e:
            print(f"Error during inference: {e}")
            # A RuntimeError may leave the pipeline in a bad state; rebuild it.
            if reset_model():
                gr.Warning("Error occurred. Model has been reset. Please try again.")
            else:
                gr.Warning("Error occurred and model reset failed. Please restart the application.")
            return None

    except Exception as e:
        traceback.print_exc()
        gr.Warning(f"Speech generation failed: {str(e)}")
        return None

    finally:
        # Runs on every exit path: release memory and delete the intermediate
        # "_processed.wav" file, which is no longer needed once read.
        cleanup_memory()
        _remove_quietly(processed_audio_path)


def adjust_speed(wav_bytes, speed_factor):
    """Change the speed of WAV audio without changing its pitch.

    Uses a phase-vocoder time stretch (``librosa.effects.time_stretch``)
    entirely in memory, so concurrent requests never share temp files.

    Args:
        wav_bytes: A complete WAV file as bytes.
        speed_factor: >1.0 speeds up (shorter output), <1.0 slows down.

    Returns:
        16-bit PCM WAV bytes at the original sample rate, or the original
        *wav_bytes* unchanged if *speed_factor* is 1.0 or processing fails.
    """
    if speed_factor == 1.0:
        return wav_bytes
    try:
        samples, sample_rate = sf.read(io.BytesIO(wav_bytes), dtype="float32")
        # soundfile returns (frames, channels) for multi-channel audio while
        # librosa expects time on the last axis; .T is a no-op for mono.
        stretched = librosa.effects.time_stretch(samples.T, rate=speed_factor).T
        # The phase vocoder can overshoot slightly; clip before integer PCM
        # conversion so samples do not wrap around.
        stretched = np.clip(stretched, -1.0, 1.0)

        buffer = io.BytesIO()
        sf.write(buffer, stretched, sample_rate, format="WAV", subtype="PCM_16")
        return buffer.getvalue()
    except Exception as e:
        print(f"Speed adjustment failed: {e}")
        return wav_bytes  # Return original if adjustment fails


def cleanup_memory():
    """Run the garbage collector and, if CUDA is available, empty its cache."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def preprocess_audio_robust(audio_path, target_sr=22050, max_duration=30):
    """Convert a reference recording to a clean mono 16-bit WAV file.

    Steps: downmix to mono, truncate to *max_duration* seconds, peak-normalize,
    resample to *target_sr*, export as PCM16, then reload with librosa to
    reject NaN/inf or near-silent signals.

    Args:
        audio_path: Path to the input audio in any format pydub/ffmpeg reads.
        target_sr: Output sample rate in Hz.
        max_duration: Maximum kept length in seconds.

    Returns:
        Path of the written file: *audio_path* with its extension replaced by
        ``_processed.wav``.

    Raises:
        ValueError: if any step fails or the audio is invalid/too quiet.
    """
    try:
        # pydub handles many container/codec formats via ffmpeg.
        audio = AudioSegment.from_file(audio_path)

        if audio.channels > 1:
            audio = audio.set_channels(1)

        # Limit duration to prevent memory issues (pydub uses milliseconds).
        max_ms = max_duration * 1000
        if len(audio) > max_ms:
            audio = audio[:max_ms]

        # Peak-normalize to prevent clipping, then resample.
        audio = normalize(audio)
        audio = audio.set_frame_rate(target_sr)

        # Derive the output name from the stem only. str.replace on the
        # extension broke for extension-less paths ('' matches everywhere)
        # and for paths containing the extension text elsewhere.
        temp_path = os.path.splitext(audio_path)[0] + "_processed.wav"
        audio.export(
            temp_path,
            format="wav",
            parameters=["-acodec", "pcm_s16le", "-ac", "1", "-ar", str(target_sr)],
        )

        # Validate the exported audio with librosa.
        wav, sr = librosa.load(temp_path, sr=target_sr, mono=True)

        if np.any(np.isnan(wav)) or np.any(np.isinf(wav)):
            raise ValueError("Audio contains NaN or infinite values")

        if np.max(np.abs(wav)) < 1e-6:
            raise ValueError("Audio signal is too quiet")

        # Re-save the validated audio.
        sf.write(temp_path, wav, sr)

        return temp_path

    except Exception as e:
        print(f"Audio preprocessing failed: {e}")
        raise ValueError(f"Failed to process audio: {str(e)}") from e


with gr.Blocks(title="MegaTTS3 Voice Cloning") as demo:
    with gr.Row():
        with gr.Column():
            reference_audio = gr.Audio(
                label="Reference Audio",
                type="filepath",
                sources=["upload", "microphone"],
            )
            text_input = gr.Textbox(
                label="Text to Generate",
                placeholder="Enter the text you want to synthesize...",
                lines=3,
            )

            with gr.Accordion("Advanced Options", open=False):
                infer_timestep = gr.Number(
                    label="Inference Timesteps",
                    value=32,
                    minimum=1,
                    maximum=100,
                    step=1,
                )
                p_w = gr.Number(
                    label="Intelligibility Weight",
                    value=1.4,
                    minimum=0.1,
                    maximum=5.0,
                    step=0.1,
                )
                t_w = gr.Number(
                    label="Similarity Weight",
                    value=3.0,
                    minimum=0.1,
                    maximum=10.0,
                    step=0.1,
                )
                speed_factor = gr.Slider(
                    label="Speed Adjustment",
                    value=1.0,
                    minimum=0.5,
                    maximum=2.0,
                    step=0.1,
                    info="1.0 = normal speed, <1.0 = slower, >1.0 = faster",
                )

            generate_btn = gr.Button("Generate Speech", variant="primary")

        with gr.Column():
            output_audio = gr.Audio(label="Generated Audio")

    generate_btn.click(
        fn=generate_speech,
        inputs=[reference_audio, text_input, infer_timestep, p_w, t_w, speed_factor],
        outputs=[output_audio],
    )

if __name__ == '__main__':
    demo.launch(server_name='0.0.0.0', server_port=7860, share=True)