from __future__ import annotations

from pathlib import Path
from random import Random
from typing import Sequence

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset

from .hparams import HParams


class TextAudioDataset(Dataset):
    def __init__(self, hps: HParams, is_validation: bool = False):
        # Each line of the filelist points at an audio file; the preprocessed
        # features are expected next to it as "<name>.data.pt".
        self.datapaths = [
            Path(x).parent / (Path(x).name + ".data.pt")
            for x in Path(
                hps.data.validation_files if is_validation else hps.data.training_files
            )
            .read_text("utf-8")
            .splitlines()
        ]
        self.hps = hps
        self.random = Random(hps.train.seed)
        self.random.shuffle(self.datapaths)
        self.max_spec_len = 800

    def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
        with Path(self.datapaths[index]).open("rb") as f:
            data = torch.load(f, weights_only=True, map_location="cpu")

        # Randomly crop long samples to `max_spec_len - 10` spectrogram frames.
        spec_len = data["mel_spec"].shape[1]
        hop_len = self.hps.data.hop_length
        if spec_len > self.max_spec_len:
            start = self.random.randint(0, spec_len - self.max_spec_len)
            end = start + self.max_spec_len - 10
            for key in data.keys():
                if key == "audio":
                    # The waveform is cropped in samples, hence the hop-length scaling.
                    data[key] = data[key][:, start * hop_len : end * hop_len]
                elif key == "spk":
                    # The speaker id is not a time series; leave it untouched.
                    continue
                else:
                    data[key] = data[key][..., start:end]
        torch.cuda.empty_cache()
        return data

    def __len__(self) -> int:
        return len(self.datapaths)


def _pad_stack(array: Sequence[torch.Tensor]) -> torch.Tensor:
    # Right-pad every tensor with zeros along its last dimension so they all
    # match the longest one, then stack them into a single batch tensor.
    max_idx = torch.argmax(torch.tensor([x_.shape[-1] for x_ in array]))
    max_x = array[max_idx]
    x_padded = [
        F.pad(x_, (0, max_x.shape[-1] - x_.shape[-1]), mode="constant", value=0)
        for x_ in array
    ]
    return torch.stack(x_padded)
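
# For reference, an illustrative example (shapes chosen arbitrarily, not taken
# from the preprocessing pipeline):
#     _pad_stack([torch.ones(1, 3), torch.ones(1, 5)]).shape == (2, 1, 5)
# The shorter tensor is right-padded with zeros along its last dimension.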


class TextAudioCollate(nn.Module):
    def forward(
        self, batch: Sequence[dict[str, torch.Tensor]]
    ) -> tuple[torch.Tensor, ...]:
        batch = [b for b in batch if b is not None]
        # Sort by mel-spectrogram length, longest first, before padding.
        batch = list(sorted(batch, key=lambda x: x["mel_spec"].shape[1], reverse=True))
        lengths = torch.tensor([b["mel_spec"].shape[1] for b in batch]).long()
        results = {}
        for key in batch[0].keys():
            if key not in ["spk"]:
                results[key] = _pad_stack([b[key] for b in batch]).cpu()
            else:
                results[key] = torch.tensor([[b[key]] for b in batch]).cpu()
        return (
            results["content"],
            results["f0"],
            results["spec"],
            results["mel_spec"],
            results["audio"],
            results["spk"],
            lengths,
            results["uv"],
        )
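

# Minimal smoke test for the collate path (a sketch, not part of training).
# The feature dimensions, the hop length of 512, and the integer speaker id
# below are illustrative assumptions, not the shapes the preprocessing
# pipeline actually writes into the *.data.pt files.
if __name__ == "__main__":

    def _fake_item(frames: int, hop: int = 512) -> dict[str, torch.Tensor]:
        # Build one synthetic sample with the keys the collate function expects.
        return {
            "content": torch.randn(1, 256, frames),
            "f0": torch.randn(1, frames),
            "spec": torch.randn(1, 513, frames),
            "mel_spec": torch.randn(1, 80, frames),
            "audio": torch.randn(1, frames * hop),
            "spk": 0,
            "uv": torch.ones(1, frames),
        }

    collate = TextAudioCollate()
    content, f0, spec, mel, audio, spk, lengths, uv = collate(
        [_fake_item(400), _fake_item(350)]
    )
    # The shorter sample is zero-padded up to 400 frames; lengths keeps the
    # original frame counts so padding can be masked out downstream.
    print(content.shape, mel.shape, audio.shape, spk.shape, lengths)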