Spaces:
Build error
Build error
import numpy as np | |
from collections import Counter | |
from torch.utils.data import Subset | |
from sklearn.model_selection import train_test_split | |
def __balance_val_split(dataset, val_split=0., random_state=None):
    """Split ``dataset`` into stratified train and validation subsets.

    Class proportions of ``dataset.targets`` are preserved in both splits by
    passing the labels to scikit-learn's ``train_test_split(stratify=...)``.

    Args:
        dataset: Object exposing a ``targets`` sequence (one label per sample)
            that can serve as the base of a ``torch.utils.data.Subset``.
        val_split: Validation size, forwarded to ``train_test_split`` as
            ``test_size`` (float fraction or int sample count). ``0`` (the
            default) puts every sample in the training subset and returns an
            empty validation subset.
        random_state: Optional seed forwarded to ``train_test_split`` so the
            split is reproducible; ``None`` leaves it unseeded.

    Returns:
        A ``(train_dataset, val_dataset)`` tuple of ``Subset`` objects over
        ``dataset``.
    """
    targets = np.array(dataset.targets)
    all_indices = np.arange(targets.shape[0])

    # train_test_split rejects test_size=0, so the default argument used to
    # raise; treat "no validation data" as a valid request instead.
    if val_split == 0:
        train_dataset = Subset(dataset, indices=all_indices)
        val_dataset = Subset(dataset, indices=all_indices[:0])
        return train_dataset, val_dataset

    train_indices, val_indices = train_test_split(
        all_indices,
        test_size=val_split,
        stratify=targets,
        random_state=random_state,
    )
    train_dataset = Subset(dataset, indices=train_indices)
    val_dataset = Subset(dataset, indices=val_indices)
    return train_dataset, val_dataset
def __split_of_train_sequence(subset: Subset, train_split=1.0, random_state=None):
    """Keep a stratified fraction of an existing ``Subset``.

    Labels are read as ``subset.dataset.targets[i]`` for each ``i`` in
    ``subset.indices`` and used to stratify the selection, so class
    proportions are preserved in the kept portion.

    Args:
        subset: A ``Subset`` whose ``dataset`` exposes an indexable
            ``targets``. Assumes ``subset.dataset`` is the base dataset
            (not itself a nested ``Subset``) — TODO confirm with callers.
        train_split: Fraction of ``subset`` to keep, in ``(0, 1]``. ``1``
            returns ``subset`` itself unchanged.
        random_state: Optional seed forwarded to ``train_test_split`` so the
            selection is reproducible; ``None`` leaves it unseeded.

    Returns:
        ``subset`` when ``train_split == 1``; otherwise a new ``Subset`` over
        ``subset.dataset`` containing the selected base-dataset indices.

    Raises:
        ValueError: If ``train_split`` is outside ``(0, 1]``.
    """
    if train_split == 1:
        return subset
    if not 0 < train_split < 1:
        raise ValueError(f"train_split must be in (0, 1], got {train_split!r}")

    targets = np.array([subset.dataset.targets[i] for i in subset.indices])
    # Pass the kept fraction as train_size instead of test_size=1 - train_split:
    # the subtraction is inexact (1 - 0.7 == 0.30000000000000004) and sklearn
    # rounds test_size up with ceil, so 10 samples at 0.7 used to keep only 6.
    kept_positions, _ = train_test_split(
        np.arange(targets.shape[0]),
        train_size=train_split,
        stratify=targets,
        random_state=random_state,
    )
    # kept_positions index into subset.indices; map back to base-dataset indices.
    kept_indices = [subset.indices[pos] for pos in kept_positions]
    return Subset(subset.dataset, indices=kept_indices)
def __log_class_statistics(subset: Subset):
    """Print how many samples of each class ``subset`` contains.

    Labels are read as ``subset.dataset.targets[i]`` for each ``i`` in
    ``subset.indices`` and printed as a ``{label: count}`` dict in
    first-seen order.

    Args:
        subset: A ``Subset`` whose ``dataset`` exposes an indexable ``targets``.
    """

    def _as_key(label):
        # Indexing a tensor/ndarray of targets yields 0-d tensors or numpy
        # scalars. Tensors hash by identity, so without .item() Counter would
        # treat every sample as a distinct class. Non-scalar labels are kept as-is.
        if getattr(label, "ndim", None) == 0:
            return label.item()
        return label

    train_classes = [_as_key(subset.dataset.targets[i]) for i in subset.indices]
    print(dict(Counter(train_classes)))