# My-TTS-Streamlit / model / dataset.py
# (Upload 24 files — Mohit0708, be29b5b verified)
import torch
import torchaudio
import pandas as pd
import os
import soundfile as sf
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
# --- CONFIGURATION ---
# Character vocabulary for text encoding.
# IDs 0 and 1 are reserved (0 = padding, 1 = 'unknown'), so real
# characters are numbered starting at 2.
vocab = "_ abcdefghijklmnopqrstuvwxyz'.?"
char_to_id = {ch: idx for idx, ch in enumerate(vocab, start=2)}
id_to_char = {idx: ch for ch, idx in char_to_id.items()}
class TextProcessor:
    """Converts raw text into integer ID sequences using the module vocabulary."""

    @staticmethod
    def text_to_sequence(text):
        """Lowercase *text* and encode each in-vocabulary character as its ID.

        Out-of-vocabulary characters are dropped (same behavior as before).
        Note the reserved 'unknown' ID 1 is therefore never emitted: the
        original `.get(c, 1)` fallback was unreachable because the
        membership filter already excluded unknown characters.

        Returns a 1-D ``torch.LongTensor`` of character IDs.
        """
        text = text.lower()
        # char_to_id's keys are exactly the vocab characters, so a single
        # dict membership test both filters OOV characters and avoids the
        # per-character substring scan over the vocab string.
        sequence = [char_to_id[c] for c in text if c in char_to_id]
        return torch.tensor(sequence, dtype=torch.long)
class LJSpeechDataset(Dataset):
    """LJSpeech dataset yielding (text_ids, mel_spectrogram) pairs.

    Each item is:
      - text: 1-D LongTensor of character IDs
      - mel:  FloatTensor of shape [time, n_mels]
    """

    def __init__(self, metadata_path, wavs_dir, limit=100):
        """
        metadata_path: Path to metadata.csv
        wavs_dir: Path to the folder containing .wav files
        limit: keep only the first `limit` rows. Defaults to 100 to match
               the original hard-coded debug truncation; pass None to load
               the full dataset.
        """
        self.wavs_dir = wavs_dir
        # Load CSV (Format: ID | Transcription | Normalized Transcription).
        # quoting=3 (QUOTE_NONE) because transcripts contain bare quotes.
        self.metadata = pd.read_csv(metadata_path, sep='|', header=None, quoting=3)
        if limit is not None:
            self.metadata = self.metadata.iloc[:limit]
        # Audio Processing Setup (Mel Spectrogram)
        self.mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=22050,
            n_fft=1024,
            # NOTE(review): win_length < n_fft is unusual for TTS (typically
            # win_length == n_fft == 1024); kept as-is — confirm intent.
            win_length=256,
            hop_length=256,
            n_mels=80  # Standard for TTS (Match this with your network.py!)
        )

    def __len__(self):
        """Number of (possibly truncated) metadata rows."""
        return len(self.metadata)

    def __getitem__(self, idx):
        # 1. Text: column 0 is the file ID, column 2 the normalized transcription.
        #    Use .iloc for positional access — plain row[0] is a deprecated
        #    label lookup on a pandas Series.
        row = self.metadata.iloc[idx]
        file_id = row.iloc[0]
        text = row.iloc[2]
        text_tensor = TextProcessor.text_to_sequence(str(text))

        # 2. Audio — read with soundfile directly (bypasses torchaudio's loader).
        #    sf.read returns: audio_array (numpy), sample_rate (int)
        wav_path = os.path.join(self.wavs_dir, f"{file_id}.wav")
        audio_np, sample_rate = sf.read(wav_path)

        # soundfile gives [time] or [time, channels]; PyTorch wants [channels, time]
        waveform = torch.from_numpy(audio_np).float()
        if waveform.dim() == 1:
            waveform = waveform.unsqueeze(0)
        else:
            waveform = waveform.transpose(0, 1)

        # Downmix multi-channel audio to mono. The original code kept all
        # channels, which produced a malformed [n_mels, channels, time]
        # spectrogram after the squeeze/transpose below. Mono files
        # (all of LJSpeech) are unaffected.
        if waveform.size(0) > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        # Resample to the 22050 Hz the mel transform expects.
        if sample_rate != 22050:
            resampler = torchaudio.transforms.Resample(sample_rate, 22050)
            waveform = resampler(waveform)

        # Mel: [1, n_mels, time] -> [n_mels, time] -> [time, n_mels]
        mel_spec = self.mel_transform(waveform).squeeze(0)
        mel_spec = mel_spec.transpose(0, 1)
        return text_tensor, mel_spec
# --- BATCHING MAGIC (Collate Function) ---
# Sentences (and clips) differ in length, so each batch is padded up to
# the longest element it contains.
def collate_fn_tts(batch):
    """Pad a list of (text, mel) pairs to uniform length.

    Returns (text_padded, mel_padded), both batch-first:
    text_padded is [batch, max_text_len] padded with 0 (the PAD id),
    mel_padded is [batch, max_mel_len, n_mels] padded with 0.0.
    """
    # batch: [(text1, mel1), (text2, mel2), ...] -> two parallel tuples
    texts, mels = zip(*batch)
    padded_texts = pad_sequence(texts, batch_first=True, padding_value=0)
    padded_mels = pad_sequence(mels, batch_first=True, padding_value=0.0)
    return padded_texts, padded_mels
# --- SANITY CHECK ---
if __name__ == "__main__":
    # UPDATE THESE PATHS TO MATCH YOUR FOLDER
    BASE_PATH = "LJSpeech-1.1"
    csv_path = os.path.join(BASE_PATH, "metadata.csv")
    wav_path = os.path.join(BASE_PATH, "wavs")

    if not os.path.exists(csv_path):
        print("Dataset not found. Please download LJSpeech-1.1 to run this test.")
    else:
        print("Loading Dataset...")
        dataset = LJSpeechDataset(csv_path, wav_path)
        loader = DataLoader(dataset, batch_size=2, collate_fn=collate_fn_tts)

        # Pull a single batch through the pipeline to validate shapes.
        text_batch, mel_batch = next(iter(loader))
        print(f"Text Batch Shape: {text_batch.shape} (Batch, Max Text Len)")
        print(f"Mel Batch Shape: {mel_batch.shape} (Batch, Max Audio Len, 80)")
        print("\nSUCCESS: Data pipeline is working!")