from collections import Counter

import numpy as np
from sklearn.model_selection import train_test_split
from torch.utils.data import Subset


def _subset_targets(subset: Subset) -> np.ndarray:
    """Return the labels of the samples selected by *subset* as a NumPy array.

    The labels are read from ``subset.dataset.targets`` at the positions
    listed in ``subset.indices``, in the same order as the indices.
    """
    positions = np.asarray(subset.indices, dtype=np.intp)
    return np.asarray(subset.dataset.targets)[positions]


def __balance_val_split(dataset, val_split=0., random_state=None):
    """Split *dataset* into stratified train and validation subsets.

    Args:
        dataset: Dataset exposing a ``targets`` attribute with one label per
            sample.
        val_split: Size of the validation part, forwarded to
            ``train_test_split`` as ``test_size``. ``0`` (the default) means
            "no validation data".
        random_state: Optional seed forwarded to ``train_test_split`` to make
            the split reproducible. ``None`` keeps the unseeded behaviour.

    Returns:
        A ``(train_subset, val_subset)`` tuple of ``Subset`` objects over
        *dataset*. With ``val_split == 0`` the train subset covers every
        sample and the validation subset is empty.
    """
    targets = np.asarray(dataset.targets)
    all_indices = np.arange(targets.shape[0])

    # train_test_split rejects test_size=0, so the default argument used to
    # raise ValueError. Treat it as "keep everything for training" instead.
    if val_split == 0:
        return (
            Subset(dataset, indices=all_indices),
            Subset(dataset, indices=all_indices[:0]),
        )

    train_indices, val_indices = train_test_split(
        all_indices,
        test_size=val_split,
        stratify=targets,
        random_state=random_state,
    )
    train_dataset = Subset(dataset, indices=train_indices)
    val_dataset = Subset(dataset, indices=val_indices)
    return train_dataset, val_dataset


def __split_of_train_sequence(subset: Subset, train_split=1.0, random_state=None):
    """Keep a stratified fraction of the samples in *subset*.

    Args:
        subset: ``Subset`` whose underlying dataset exposes ``targets``.
        train_split: Fraction of *subset* to keep. ``1`` returns *subset*
            itself, untouched.
        random_state: Optional seed forwarded to ``train_test_split`` to make
            the selection reproducible. ``None`` keeps the unseeded behaviour.

    Returns:
        *subset* itself when ``train_split == 1``; otherwise a new ``Subset``
        over the same underlying dataset holding the selected indices.
    """
    if train_split == 1:
        return subset

    targets = _subset_targets(subset)
    # Positions are relative to `subset`; they are mapped back to indices of
    # the underlying dataset below.
    kept_positions, _ = train_test_split(
        np.arange(targets.shape[0]),
        test_size=1 - train_split,
        stratify=targets,
        random_state=random_state,
    )
    train_dataset = Subset(
        subset.dataset,
        indices=[subset.indices[i] for i in kept_positions],
    )
    return train_dataset


def __log_class_statistics(subset: Subset):
    """Print a ``{label: count}`` dict for the samples selected by *subset*."""
    # .tolist() yields plain Python scalars. Counting tensor elements directly
    # would key on object identity (tensors hash by id), giving one entry per
    # sample instead of one per class.
    train_classes = _subset_targets(subset).tolist()
    print(dict(Counter(train_classes)))