# Hosting-page residue (not part of the module), kept as comments:
#   yxc97's picture
#   Upload folder using huggingface_hub
#   62a2f1c verified
import argparse
import os
from os.path import dirname
import numpy as np
import torch
import yaml
from pytorch_lightning.utilities import rank_zero_warn
def train_val_test_split(dset_len, train_size, val_size, test_size, seed):
    """Randomly partition ``range(dset_len)`` into train/val/test index arrays.

    Each size may be an ``int`` (absolute number of samples), a ``float``
    (multiplied by ``dset_len`` and rounded) or ``None``. At most one size
    may be ``None``; it is then set to what the other two leave over.

    Args:
        dset_len: Total number of samples in the dataset.
        train_size: Size of the training split.
        val_size: Size of the validation split.
        test_size: Size of the test split.
        seed: Seed handed to ``np.random.default_rng`` for the shuffle.

    Returns:
        Tuple ``(idx_train, idx_val, idx_test)`` of ``int64`` NumPy arrays
        holding disjoint slices of one shuffled index permutation.

    Raises:
        AssertionError: If more than one size is ``None``, a split ends up
            negative, or the splits together exceed ``dset_len``.
    """
    requested = (train_size, val_size, test_size)
    assert sum(size is None for size in requested) <= 1, "Only one of train_size, val_size, test_size is allowed to be None."

    # Remember which sizes were fractions before they are turned into counts.
    from_fraction = tuple(isinstance(size, float) for size in requested)
    train_size, val_size, test_size = (
        round(dset_len * size) if is_fraction else size
        for size, is_fraction in zip(requested, from_fraction)
    )

    # Derive the single unspecified split (if any) from the remainder.
    if train_size is None:
        train_size = dset_len - val_size - test_size
    elif val_size is None:
        val_size = dset_len - train_size - test_size
    elif test_size is None:
        test_size = dset_len - train_size - val_size

    # Rounding may overshoot the dataset; take one sample off the last split
    # that was given as a fraction (test first, then val, then train).
    overshoots = train_size + val_size + test_size > dset_len
    if overshoots and from_fraction[2]:
        test_size -= 1
    elif overshoots and from_fraction[1]:
        val_size -= 1
    elif overshoots and from_fraction[0]:
        train_size -= 1

    assert all(size >= 0 for size in (train_size, val_size, test_size)), (
        f"One of training ({train_size}), validation ({val_size}) or "
        f"testing ({test_size}) splits ended up with a negative size."
    )

    total = train_size + val_size + test_size
    assert total <= dset_len, f"The dataset ({dset_len}) is smaller than the combined split sizes ({total})."
    if dset_len > total:
        rank_zero_warn(f"{dset_len - total} samples were excluded from the dataset")

    shuffled = np.random.default_rng(seed).permutation(np.arange(dset_len, dtype=np.int64))
    val_start = train_size
    test_start = train_size + val_size
    return (
        np.array(shuffled[:val_start]),
        np.array(shuffled[val_start:test_start]),
        np.array(shuffled[test_start:total]),
    )
def make_splits(dataset_len, train_size, val_size, test_size, seed, filename=None, splits=None):
    """Return train/val/test index tensors, loaded from disk or freshly drawn.

    Args:
        dataset_len: Number of samples in the dataset.
        train_size: Training split size, forwarded to ``train_val_test_split``.
        val_size: Validation split size, forwarded likewise.
        test_size: Test split size, forwarded likewise.
        seed: Shuffle seed, forwarded likewise.
        filename: If not ``None``, the three index arrays are written to
            this location with ``np.savez``.
        splits: If not ``None``, an ``.npz`` archive containing the arrays
            ``idx_train``, ``idx_val`` and ``idx_test``. The indices are
            read from it and ``dataset_len``, the sizes and ``seed`` are
            not used.

    Returns:
        Tuple ``(idx_train, idx_val, idx_test)`` of tensors created with
        ``torch.from_numpy``.
    """
    if splits is not None:
        # np.load on an .npz returns a lazily-read archive that keeps the
        # underlying file open; the context manager closes it as soon as the
        # three arrays have been materialised.
        with np.load(splits) as archive:
            idx_train = archive["idx_train"]
            idx_val = archive["idx_val"]
            idx_test = archive["idx_test"]
    else:
        idx_train, idx_val, idx_test = train_val_test_split(dataset_len, train_size, val_size, test_size, seed)

    if filename is not None:
        np.savez(filename, idx_train=idx_train, idx_val=idx_val, idx_test=idx_test)

    return torch.from_numpy(idx_train), torch.from_numpy(idx_val), torch.from_numpy(idx_test)
class LoadFromFile(argparse.Action):
    """argparse action that fills the namespace from a YAML configuration file.

    ``values`` must be an already opened file object: its ``name`` is checked
    for a ``yaml``/``yml`` ending and it is used as a context manager, e.g.
    ``parser.add_argument("--conf", type=open, action=LoadFromFile)``.
    Every top-level key of the file must already exist in the namespace.

    Raises:
        ValueError: If the file name does not end with ``yaml``/``yml``, the
            document is not a mapping, or it contains an unknown key.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        if not values.name.endswith(("yaml", "yml")):
            raise ValueError("Configuration file must end with yaml or yml")

        with values as f:
            # NOTE(review): FullLoader is kept for compatibility; consider
            # yaml.safe_load if config files can come from untrusted sources.
            config = yaml.load(f, Loader=yaml.FullLoader)

        if config is None:
            # An empty YAML document parses to None: nothing to override.
            return
        if not isinstance(config, dict):
            raise ValueError("Configuration file must contain a mapping of argument names to values")

        for key in config:
            if key not in namespace:
                raise ValueError(f"Unknown argument in config file: {key}")
        namespace.__dict__.update(config)
class LoadFromCheckpoint(argparse.Action):
    """argparse action that restores arguments from a saved checkpoint.

    The checkpoint given as ``values`` is read with ``torch.load`` onto the
    CPU and must contain a ``"hyper_parameters"`` entry. Each of its keys has
    to exist in the namespace already, otherwise ``ValueError`` is raised.
    The checkpoint location itself is stored under ``load_model``.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        hparams = torch.load(values, map_location="cpu")["hyper_parameters"]
        for name in hparams.keys():
            if name in namespace:
                continue
            raise ValueError(f"Unknown argument in the model checkpoint: {name}")
        # Hyper-parameters are applied first, so ``load_model`` always ends up
        # pointing at the checkpoint that was just read.
        vars(namespace).update(hparams, load_model=values)
def save_argparse(args, filename, exclude=None):
    """Write the attributes of ``args`` to a YAML file.

    Args:
        args: Object whose ``__dict__`` holds the arguments to save (for
            example an ``argparse.Namespace``).
        filename: Destination path; must end with ``yaml`` or ``yml``.
            Missing parent directories are created.
        exclude: Attribute name, or iterable of attribute names, to leave
            out of the file. ``None`` excludes nothing.

    Raises:
        ValueError: If ``filename`` does not end with ``yaml`` or ``yml``.
        KeyError: If an excluded name is not an attribute of ``args``.
    """
    # Validate first so that a bad file name has no side effects on disk.
    if not filename.endswith(("yaml", "yml")):
        raise ValueError("Configuration file should end with yaml or yml")

    # A bare file name has an empty dirname, which os.makedirs rejects.
    parent = dirname(filename)
    if parent:
        os.makedirs(parent, exist_ok=True)

    if exclude is None:
        exclude = ()
    elif isinstance(exclude, str):
        exclude = [exclude]

    # Work on a copy so the caller's namespace keeps the excluded entries.
    to_save = args.__dict__.copy()
    for name in exclude:
        del to_save[name]

    # The context manager guarantees the file is flushed and closed.
    with open(filename, "w") as f:
        yaml.dump(to_save, f)
def number(text):
    """Parse ``text`` into ``None``, an ``int`` or a ``float``.

    Args:
        text: ``None``, the string ``"None"``, or anything accepted by
            ``float()`` (typically a command-line string).

    Returns:
        ``None`` for ``None``/``"None"``; an ``int`` when ``text`` is an
        integer literal or converts to an integer without loss; otherwise
        a ``float``.

    Raises:
        ValueError: If ``text`` cannot be parsed as a number.
    """
    if text is None or text == "None":
        return None

    try:
        as_int = int(text)
    except (ValueError, OverflowError):
        # ValueError: not an integer literal (e.g. "0.5", NaN).
        # OverflowError: an infinite float was passed in.
        as_int = None
    else:
        if isinstance(text, str):
            # An integer literal is returned exactly; comparing it with its
            # float conversion would reject integers above 2**53.
            return as_int

    as_float = float(text)
    # Non-string input (e.g. 5.0): keep the int only if nothing was truncated.
    if as_int == as_float:
        return as_int
    return as_float
class MissingLabelException(Exception):
    """Plain ``Exception`` subclass that adds no behaviour of its own.

    NOTE(review): judging by the name it signals a missing label; the places
    that raise it are not visible in this part of the file.
    """