import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from torch.amp import GradScaler, autocast
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import Dataset, DataLoader
from torchaudio.transforms import MelSpectrogram, InverseMelScale, GriffinLim
from transformers import GPT2Tokenizer, GPT2Model
from sklearn.model_selection import train_test_split
from tqdm import tqdm


class TextToSpeechDataset(Dataset):
    def __init__(self, text_files, audio_files, tokenizer, mel_transform,
                 max_length=512, sample_rate=16000):
        self.text_files = text_files
        self.audio_files = audio_files
        self.tokenizer = tokenizer
        self.mel_transform = mel_transform
        self.max_length = max_length
        self.sample_rate = sample_rate

    def __len__(self):
        return len(self.text_files)

    def __getitem__(self, idx):
        # Load text
        with open(self.text_files[idx], 'r') as f:
            text = f.read().strip()

        # Tokenize text
        text_tokens = self.tokenizer.encode(
            text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors="pt"
        ).squeeze(0)

        # Load audio, mix down to mono and resample to the rate the mel
        # transform expects, then convert to a mel spectrogram
        waveform, sample_rate = torchaudio.load(self.audio_files[idx])
        if waveform.size(0) > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        if sample_rate != self.sample_rate:
            waveform = torchaudio.functional.resample(waveform, sample_rate, self.sample_rate)
        mel_spec = self.mel_transform(waveform)

        return text_tokens, mel_spec.squeeze(0)


def collate_fn(batch):
    text_tokens, mel_specs = zip(*batch)

    # Pad text tokens to the longest sequence in the batch
    max_text_len = max(tokens.size(0) for tokens in text_tokens)
    text_tokens_padded = torch.stack([
        torch.cat([tokens, torch.zeros(max_text_len - tokens.size(0), dtype=tokens.dtype)], dim=0)
        if tokens.size(0) < max_text_len else tokens[:max_text_len]
        for tokens in text_tokens
    ])

    # Pad mel spectrograms along the time axis
    max_mel_len = max(spec.size(1) for spec in mel_specs)
    mel_specs_padded = torch.stack([
        F.pad(spec, (0, max_mel_len - spec.size(1)))
        if spec.size(1) < max_mel_len else spec[:, :max_mel_len]
        for spec in mel_specs
    ])

    return text_tokens_padded, mel_specs_padded


class VAEDecoder(nn.Module):
    def __init__(self, latent_dim, mel_channels=80, mel_frames=80):
        super().__init__()

        # Probabilistic head: parameters of the latent distribution
        self.fc_mu = nn.Linear(latent_dim, latent_dim)
        self.fc_var = nn.Linear(latent_dim, latent_dim)

        # Decoder: map a latent sample to a fixed-size mel spectrogram
        self.decoder_layers = nn.Sequential(
            nn.Linear(latent_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, mel_channels * mel_frames),  # Output mel spectrogram
            nn.Unflatten(1, (mel_channels, mel_frames))
        )

    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, z):
        mu = self.fc_mu(z)
        log_var = self.fc_var(z)

        # Reparameterization trick
        z = self.reparameterize(mu, log_var)

        # Decode
        mel_spec = self.decoder_layers(z)
        return mel_spec, mu, log_var


class TextToSpeechModel(nn.Module):
    def __init__(self, text_encoder, vae_decoder, latent_dim=256):
        super().__init__()
        self.text_encoder = text_encoder
        self.vae_decoder = vae_decoder

        # Projection layer to map encoder output to latent space
        self.projection = nn.Linear(text_encoder.config.hidden_size, latent_dim)

    def forward(self, text_tokens):
        # Encode text
        encoder_output = self.text_encoder(text_tokens).last_hidden_state

        # Mean pooling of encoder output (padding tokens are included in the mean)
        text_embedding = encoder_output.mean(dim=1)

        # Project to latent space
        latent_z = self.projection(text_embedding)

        # Decode to mel spectrogram
        mel_spec, mu, log_var = self.vae_decoder(latent_z)
        return mel_spec, mu, log_var


def vae_loss(reconstruction, target, mu, log_var):
    # The decoder emits a fixed number of frames; stretch the reconstruction
    # along the time axis so it matches the (padded) target length
    if reconstruction.size(-1) != target.size(-1):
        reconstruction = F.interpolate(
            reconstruction, size=target.size(-1), mode='linear', align_corners=False
        )

    # Reconstruction loss (MSE)
    recon_loss = F.mse_loss(reconstruction, target, reduction='mean')

    # KL divergence loss
    kl_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())

    return recon_loss + 0.001 * kl_loss
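
# --- Optional shape sanity check (illustrative sketch, not called during training) ---
# Runs the untrained model and loss on random inputs to show the expected tensor
# shapes: tokens (batch, seq_len) in, mel spectrogram (batch, n_mels, frames) out.
# The function name and the batch size, sequence length, and frame count below
# are arbitrary choices for illustration.
def smoke_test_shapes():
    encoder = GPT2Model.from_pretrained('gpt2')
    model = TextToSpeechModel(encoder, VAEDecoder(latent_dim=256))
    dummy_tokens = torch.randint(0, encoder.config.vocab_size, (2, 32))
    dummy_mels = torch.randn(2, 80, 120).abs()  # stand-in mel targets
    with torch.no_grad():
        mel, mu, log_var = model(dummy_tokens)
        loss = vae_loss(mel, dummy_mels, mu, log_var)
    print(f"mel: {tuple(mel.shape)}, loss: {loss.item():.4f}")
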
def train_model(num_epochs=10, accumulation_steps=16):
    # Tokenizer and mel spectrogram transform
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    tokenizer.pad_token = tokenizer.eos_token

    # Mel spectrogram configuration
    mel_transform = MelSpectrogram(
        sample_rate=16000,
        n_mels=80,
        n_fft=1024,
        hop_length=256
    )

    # Data preparation
    text_folder = './texts'
    audio_folder = './audio'

    # Load text and audio files (sorted so text/audio pairs stay aligned)
    text_files = sorted(os.path.join(text_folder, f) for f in os.listdir(text_folder) if f.endswith('.txt'))
    audio_files = sorted(os.path.join(audio_folder, f) for f in os.listdir(audio_folder) if f.endswith('.wav'))

    # Split dataset
    train_texts, val_texts, train_audios, val_audios = train_test_split(
        text_files, audio_files, test_size=0.1, random_state=42
    )

    # Create datasets and dataloaders
    train_dataset = TextToSpeechDataset(train_texts, train_audios, tokenizer, mel_transform)
    val_dataset = TextToSpeechDataset(val_texts, val_audios, tokenizer, mel_transform)

    train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, collate_fn=collate_fn)

    # Model components
    text_encoder = GPT2Model.from_pretrained('gpt2')
    vae_decoder = VAEDecoder(latent_dim=256)

    # Combine into full model
    model = TextToSpeechModel(text_encoder, vae_decoder)

    # Device setup
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    model = model.to(device)

    # Optimizer and scheduler
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=1e-6)

    # Gradient scaler (mixed precision is only enabled on CUDA)
    scaler = GradScaler(enabled=use_cuda)

    best_val_loss = float('inf')

    # Training loop
    for epoch in range(num_epochs):
        model.train()
        train_loss = 0
        optimizer.zero_grad()

        for batch_idx, (text_tokens, mel_specs) in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}")):
            text_tokens = text_tokens.to(device)
            mel_specs = mel_specs.to(device)

            with autocast(device_type='cuda', dtype=torch.float16, enabled=use_cuda):
                # Forward pass
                reconstructed_mel, mu, log_var = model(text_tokens)

                # Compute loss
                loss = vae_loss(reconstructed_mel, mel_specs, mu, log_var)

            # Scale the loss for gradient accumulation and backpropagate
            scaled_loss = loss / accumulation_steps
            scaler.scale(scaled_loss).backward()

            if (batch_idx + 1) % accumulation_steps == 0:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

            # Track the unscaled loss for reporting
            train_loss += loss.item()

        # Validation
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for text_tokens, mel_specs in tqdm(val_loader, desc=f"Validation {epoch+1}"):
                text_tokens = text_tokens.to(device)
                mel_specs = mel_specs.to(device)

                reconstructed_mel, mu, log_var = model(text_tokens)
                loss = vae_loss(reconstructed_mel, mel_specs, mu, log_var)
                val_loss += loss.item()

        # Scheduler step
        scheduler.step()

        # Print epoch summary
        print(f'Epoch {epoch+1}: Train Loss: {train_loss/len(train_loader):.4f}, '
              f'Val Loss: {val_loss/len(val_loader):.4f}')

        # Save the best model by validation loss
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), 'best_tts_model.pth')

    return model


# Run training
trained_model = train_model()


# Optional: Inference function for generating mel spectrograms
def generate_mel_spectrogram(text, model, tokenizer, device):
    model.eval()
    with torch.no_grad():
        # Tokenize input text
        text_tokens = tokenizer.encode(
            text,
            return_tensors="pt",
            truncation=True,
            padding='max_length',
            max_length=512
        ).to(device)

        # Generate mel spectrogram
        mel_spec, _, _ = model(text_tokens)

    return mel_spec


# Optional: Convert mel spectrogram back to audio
def mel_to_audio(mel_spec, n_fft=1024, hop_length=256, sample_rate=16000):
    # Use Griffin-Lim for mel spectrogram inversion; n_fft and hop_length
    # must match the MelSpectrogram used during training
    inverse_mel = InverseMelScale(n_stft=n_fft // 2 + 1, n_mels=mel_spec.size(-2), sample_rate=sample_rate)
    griffin_lim = GriffinLim(n_fft=n_fft, hop_length=hop_length)

    # Convert mel spectrogram back to waveform (clamp so the estimated power
    # spectrogram stays non-negative before Griffin-Lim)
    waveform = griffin_lim(inverse_mel(mel_spec).clamp(min=0))
    return waveform
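
# --- Optional usage sketch ---
# A minimal, illustrative example of chaining the helpers above once training has
# finished: synthesize a mel spectrogram from text, invert it to a waveform, and
# save the result. The input sentence and the output path 'sample_output.wav' are
# placeholders, not part of the original pipeline.
inference_tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
inference_tokenizer.pad_token = inference_tokenizer.eos_token
inference_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

sample_mel = generate_mel_spectrogram(
    "Hello world, this is a test sentence.",
    trained_model, inference_tokenizer, inference_device
)

# InverseMelScale/GriffinLim run on CPU here, so move the mel spectrogram off the GPU
sample_waveform = mel_to_audio(sample_mel.cpu())
torchaudio.save('sample_output.wav', sample_waveform, 16000)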