tts2jk / tests /data_tests /test_samplers.py
juliankoe's picture
Upload folder using huggingface_hub
a9384d7
raw
history blame
7.32 kB
import functools
import random
import unittest
import torch
from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.utils.data import get_length_balancer_weights
from TTS.tts.utils.languages import get_language_balancer_weights
from TTS.tts.utils.speakers import get_speaker_balancer_weights
from TTS.utils.samplers import BucketBatchSampler, PerfectBatchSampler
# Fix every RNG this module relies on so the statistical assertions below are reproducible:
# torch's RNG drives the torch.utils.data samplers, and Python's ``random`` generates the
# synthetic audio lengths in the length-balancer test.
torch.manual_seed(0)
random.seed(0)

dataset_config_en = BaseDatasetConfig(
    formatter="ljspeech",
    meta_file_train="metadata.csv",
    meta_file_val="metadata.csv",
    path="tests/data/ljspeech",
    language="en",
)
dataset_config_pt = BaseDatasetConfig(
    formatter="ljspeech",
    meta_file_train="metadata.csv",
    meta_file_val="metadata.csv",
    path="tests/data/ljspeech",
    language="pt-br",
)

# The EN config is listed twice so English samples outnumber Portuguese ones,
# giving a language-unbalanced dataset.
train_samples, eval_samples = load_tts_samples(
    [dataset_config_en, dataset_config_en, dataset_config_pt], eval_split=True
)

# Generate a speaker-unbalanced dataset: the first five training samples are labelled
# "ljspeech-0" and every remaining sample "ljspeech-1".
for i, sample in enumerate(train_samples):
    sample["speaker_name"] = "ljspeech-0" if i < 5 else "ljspeech-1"
def is_balanced(lang_1, lang_2, *, low=0.85, high=1.2):
    """Return True when the ratio ``lang_1 / lang_2`` lies strictly between ``low`` and ``high``.

    Despite the parameter names, the two arguments are simply a pair of counts; the tests use it
    for languages, speakers and length buckets alike.

    Args:
        lang_1: Count for the first group.
        lang_2: Count for the second group.
        low: Exclusive lower bound on the ratio (default 0.85).
        high: Exclusive upper bound on the ratio (default 1.2).

    Returns:
        bool: whether the two counts are considered balanced. A zero ``lang_2`` is never
        balanced (returns False) rather than raising ``ZeroDivisionError``.
    """
    if lang_2 == 0:
        # An empty second group cannot be balanced; returning False lets the caller's assert
        # report the imbalance instead of crashing the test with ZeroDivisionError.
        return False
    return low < lang_1 / lang_2 < high
class TestSamplers(unittest.TestCase):
    """Statistical checks for the dataset samplers and the balancer-weight helpers.

    The tests read the module-level ``train_samples`` list. The EN config is loaded twice there, so
    English outnumbers Portuguese. The first five items are labelled ``ljspeech-0`` and the rest
    ``ljspeech-1``.
    """

    @staticmethod
    def _draw(sampler, times=100):
        """Iterate ``sampler`` ``times`` times and return everything it yielded as one flat list."""
        # A flat comprehension avoids the quadratic cost of concatenating lists with reduce(+).
        return [drawn for _ in range(times) for drawn in sampler]

    @staticmethod
    def _partition_counts(samples, indices, predicate):
        """Return ``(hits, misses)``: how many ``samples[i]`` (``i`` in ``indices``) satisfy ``predicate``."""
        indices = list(indices)
        hits = sum(1 for index in indices if predicate(samples[index]))
        return hits, len(indices) - hits

    def test_language_random_sampler(self):
        """A plain RandomSampler must reproduce the EN/PT imbalance of the data."""
        random_sampler = torch.utils.data.RandomSampler(train_samples)
        en, pt = self._partition_counts(
            train_samples, self._draw(random_sampler), lambda item: item["language"] == "en"
        )
        assert not is_balanced(en, pt), "Random sampler is supposed to be unbalanced"

    def test_language_weighted_random_sampler(self):
        """Sampling with the language balancer weights must even out EN vs. non-EN draws."""
        weighted_sampler = torch.utils.data.sampler.WeightedRandomSampler(
            get_language_balancer_weights(train_samples), len(train_samples)
        )
        en, pt = self._partition_counts(
            train_samples, self._draw(weighted_sampler), lambda item: item["language"] == "en"
        )
        assert is_balanced(en, pt), "Language Weighted sampler is supposed to be balanced"

    def test_speaker_weighted_random_sampler(self):
        """Sampling with the speaker balancer weights must even out the two speakers."""
        weighted_sampler = torch.utils.data.sampler.WeightedRandomSampler(
            get_speaker_balancer_weights(train_samples), len(train_samples)
        )
        spk1, spk2 = self._partition_counts(
            train_samples, self._draw(weighted_sampler), lambda item: item["speaker_name"] == "ljspeech-0"
        )
        assert is_balanced(spk1, spk2), "Speaker Weighted sampler is supposed to be balanced"

    def _assert_perfectly_balanced(self, shuffle, drop_last):
        """Check that every PerfectBatchSampler batch holds equally many items of each speaker."""
        classes = {item["speaker_name"] for item in train_samples}
        sampler = PerfectBatchSampler(
            train_samples,
            classes,
            batch_size=2 * 3,  # total batch size
            num_classes_in_batch=2,
            label_key="speaker_name",
            shuffle=shuffle,
            drop_last=drop_last,
        )
        for batch in self._draw(sampler):
            spk1, spk2 = self._partition_counts(
                train_samples, batch, lambda item: item["speaker_name"] == "ljspeech-0"
            )
            assert spk1 == spk2, "PerfectBatchSampler is supposed to be perfectly balanced"

    def test_perfect_sampler(self):
        """Unshuffled PerfectBatchSampler (drop_last=True) yields speaker-balanced batches."""
        self._assert_perfectly_balanced(shuffle=False, drop_last=True)

    def test_perfect_sampler_shuffle(self):
        """Shuffled PerfectBatchSampler (drop_last=False) yields speaker-balanced batches."""
        self._assert_perfectly_balanced(shuffle=True, drop_last=False)

    def test_length_weighted_random_sampler(self):
        """Sampling with length balancer weights must even out short vs. long clips.

        Each round builds a length-unbalanced dataset with random min/max audio lengths. The
        lengths are written into per-round copies of the samples so that the shared
        module-level ``train_samples`` is not mutated for other tests.
        """
        for _ in range(1000):
            min_audio = random.randrange(1, 22050)
            max_audio = random.randrange(44100, 220500)
            samples = [dict(item) for item in train_samples]
            for idx, item in enumerate(samples):
                # Add per-item jitter to increase the diversity of durations.
                random_increase = random.randrange(100, 1000)
                item["audio_length"] = (min_audio if idx < 5 else max_audio) + random_increase
            weighted_sampler = torch.utils.data.sampler.WeightedRandomSampler(
                get_length_balancer_weights(samples, num_buckets=2), len(samples)
            )
            # Short clips are always < max_audio (min_audio + jitter < 23050 <= 44100 <= max_audio),
            # and long clips are always > max_audio, so this threshold separates the two groups.
            len1, len2 = self._partition_counts(
                samples,
                self._draw(weighted_sampler),
                lambda item, limit=max_audio: item["audio_length"] < limit,
            )
            assert is_balanced(len1, len2), "Length Weighted sampler is supposed to be balanced"

    def test_bucket_batch_sampler(self):
        """Every batch must be sorted by text length, and len() must honour drop_last."""
        bucket_size_multiplier = 2
        batch_size = 7
        sampler = BucketBatchSampler(
            range(len(train_samples)),
            data=train_samples,
            batch_size=batch_size,
            drop_last=True,
            sort_key=lambda x: len(x["text"]),
            bucket_size_multiplier=bucket_size_multiplier,
        )
        # Check every batch. The previous version only checked the batch before each bucket boundary.
        # Batch order within a bucket is not asserted (it is presumably shuffled by the sampler).
        for batch in sampler:
            text_lengths = [len(train_samples[index]["text"]) for index in batch]
            self.assertEqual(text_lengths, sorted(text_lengths))
        # Check the sampler length.
        self.assertEqual(len(sampler), len(train_samples) // batch_size)