Spaces:
Sleeping
Sleeping
from __future__ import annotations | |
from pathlib import Path | |
from random import Random | |
from typing import Sequence | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torch.utils.data import Dataset | |
from .hparams import HParams | |
class TextAudioDataset(Dataset):
    """Map-style dataset over preprocessed ``*.data.pt`` feature files.

    The file list (``hps.data.training_files`` or
    ``hps.data.validation_files``) is read as UTF-8 text with one path per
    line. For each path ``p`` the dataset loads ``<dir of p>/<name of p>.data.pt``.
    Items whose ``mel_spec`` is longer than ``max_spec_len`` along axis 1 are
    randomly cropped when fetched.
    """

    def __init__(
        self,
        hps: HParams,
        is_validation: bool = False,
        max_spec_len: int = 800,
    ):
        """Read the file list and shuffle it with a seeded RNG.

        Args:
            hps: hyper-parameters; this class reads ``hps.data.training_files``
                / ``hps.data.validation_files``, ``hps.data.hop_length`` and
                ``hps.train.seed``.
            is_validation: read the validation file list instead of the
                training one.
            max_spec_len: items whose ``mel_spec`` axis-1 length exceeds this
                are randomly cropped in ``__getitem__`` (default 800).
        """
        filelist = Path(
            hps.data.validation_files if is_validation else hps.data.training_files
        )
        # Skip blank / whitespace-only lines: they would otherwise map to a
        # bogus ".data.pt" path that only fails later, when the item is loaded.
        self.datapaths = [
            Path(line).parent / (Path(line).name + ".data.pt")
            for line in filelist.read_text("utf-8").splitlines()
            if line.strip()
        ]
        self.hps = hps
        # A single seeded RNG drives both the shuffle below and the crop
        # offsets chosen in __getitem__.
        # NOTE(review): if this dataset is copied into several DataLoader
        # worker processes, each copy starts from the same RNG state, so crop
        # offsets may repeat across workers — confirm whether that matters.
        self.random = Random(hps.train.seed)
        self.random.shuffle(self.datapaths)
        self.max_spec_len = max_spec_len

    def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
        """Load item ``index`` onto the CPU, randomly cropping it if too long.

        When ``data["mel_spec"].shape[1]`` exceeds ``self.max_spec_len``, a
        window of ``self.max_spec_len - 10`` steps starting at a random
        offset is kept:

        * ``"audio"`` is sliced on axis 1 in units of ``hop_length``.
        * ``"spk"`` is left untouched.
        * Every other entry is sliced on its last axis.
        """
        with Path(self.datapaths[index]).open("rb") as f:
            data = torch.load(f, weights_only=True, map_location="cpu")

        # cut long data randomly
        spec_len = data["mel_spec"].shape[1]
        hop_len = self.hps.data.hop_length
        if spec_len > self.max_spec_len:
            start = self.random.randint(0, spec_len - self.max_spec_len)
            # NOTE(review): the kept window is 10 steps shorter than
            # max_spec_len; the reason for this margin is not visible here.
            end = start + self.max_spec_len - 10
            for key in data.keys():
                if key == "audio":
                    # audio is indexed in samples: frame index * hop length
                    data[key] = data[key][:, start * hop_len : end * hop_len]
                elif key == "spk":
                    continue
                else:
                    data[key] = data[key][..., start:end]
        # Everything above is loaded with map_location="cpu", so there is no
        # GPU memory to release here; a per-item torch.cuda.empty_cache() would
        # only flush the CUDA allocator cache on every sample.
        return data

    def __len__(self) -> int:
        """Return the number of entries read from the file list."""
        return len(self.datapaths)
def _pad_stack(array: Sequence[torch.Tensor]) -> torch.Tensor: | |
max_idx = torch.argmax(torch.tensor([x_.shape[-1] for x_ in array])) | |
max_x = array[max_idx] | |
x_padded = [ | |
F.pad(x_, (0, max_x.shape[-1] - x_.shape[-1]), mode="constant", value=0) | |
for x_ in array | |
] | |
return torch.stack(x_padded) | |
class TextAudioCollate(nn.Module):
    """Batch collator that zero-pads dataset items into stacked tensors."""

    def forward(
        self, batch: Sequence[dict[str, torch.Tensor]]
    ) -> tuple[torch.Tensor, ...]:
        """Collate ``batch`` into padded tensors, longest item first.

        ``None`` entries are dropped. The remaining items are ordered by
        ``mel_spec.shape[1]`` in descending order. Every key except ``"spk"``
        is padded along its last axis with ``_pad_stack``. ``"spk"`` values
        become a tensor of shape ``(batch, 1)``.

        Returns:
            ``(content, f0, spec, mel_spec, audio, spk, lengths, uv)``, where
            ``lengths`` is an int64 tensor of each sorted item's
            ``mel_spec.shape[1]``.
        """
        items = sorted(
            (item for item in batch if item is not None),
            key=lambda item: item["mel_spec"].shape[1],
            reverse=True,
        )
        lengths = torch.tensor([item["mel_spec"].shape[1] for item in items]).long()

        stacked = {}
        for name in items[0]:
            column = [item[name] for item in items]
            if name in ["spk"]:
                # scalar speaker ids -> one row per item
                stacked[name] = torch.tensor([[value] for value in column]).cpu()
            else:
                stacked[name] = _pad_stack(column).cpu()

        leading = ("content", "f0", "spec", "mel_spec", "audio", "spk")
        return tuple(stacked[name] for name in leading) + (lengths, stacked["uv"])