import argparse
import warnings

import numpy as np
import torch
import yaml


class LoadFromFile(argparse.Action):
    def __call__(self, parser, namespace, values, option_string=None):
        if values.name.endswith("yaml") or values.name.endswith("yml"):
            with values as f:
                config = yaml.load(f, Loader=yaml.FullLoader)
            for key in config.keys():
                if key not in namespace:
                    raise ValueError(f"Unknown argument in config file: {key}")
            if (
                "load_model" in config
                and namespace.load_model is not None
                and config["load_model"] != namespace.load_model
            ):
                warnings.warn(
                    f"The load model argument was specified as a command line argument "
                    f"({namespace.load_model}) and in the config file ({config['load_model']}). "
                    f"Ignoring 'load_model' from the config file and loading {namespace.load_model}."
                )
                del config["load_model"]
            namespace.__dict__.update(config)
        else:
            raise ValueError("Configuration file must end with yaml or yml")


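# Usage sketch for LoadFromFile (illustrative, not part of this module): the
# action reads `values.name` and treats `values` as an open file, so the
# argument should be declared with type=open. The option names and the
# "train.yaml" path below are assumptions.
#
#     parser = argparse.ArgumentParser()
#     parser.add_argument("--load-model", dest="load_model", default=None)
#     parser.add_argument("--lr", type=float, default=1e-4)
#     parser.add_argument("--conf", type=open, action=LoadFromFile,
#                         help="YAML config whose keys must match existing arguments")
#     args = parser.parse_args(["--conf", "train.yaml"])

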
def train_val_test_split(dset_len, train_size, val_size, test_size, seed, order=None):
    assert (train_size is None) + (val_size is None) + (
        test_size is None
    ) <= 1, "Only one of train_size, val_size, test_size is allowed to be None."
    is_float = (
        isinstance(train_size, float),
        isinstance(val_size, float),
        isinstance(test_size, float),
    )

    train_size = round(dset_len * train_size) if is_float[0] else train_size
    val_size = round(dset_len * val_size) if is_float[1] else val_size
    test_size = round(dset_len * test_size) if is_float[2] else test_size

    if train_size is None:
        train_size = dset_len - val_size - test_size
    elif val_size is None:
        val_size = dset_len - train_size - test_size
    elif test_size is None:
        test_size = dset_len - train_size - val_size

    # Rounding float fractions can overshoot the dataset length by one;
    # shrink one of the float-specified splits to compensate.
    if train_size + val_size + test_size > dset_len:
        if is_float[2]:
            test_size -= 1
        elif is_float[1]:
            val_size -= 1
        elif is_float[0]:
            train_size -= 1

    assert train_size >= 0 and val_size >= 0 and test_size >= 0, (
        f"One of training ({train_size}), validation ({val_size}) or "
        f"testing ({test_size}) splits ended up with a negative size."
    )

    total = train_size + val_size + test_size
    assert dset_len >= total, (
        f"The dataset ({dset_len}) is smaller than the "
        f"combined split sizes ({total})."
    )
    if total < dset_len:
        warnings.warn(f"{dset_len - total} samples were excluded from the dataset")

    idxs = np.arange(dset_len, dtype=int)
    if order is None:
        idxs = np.random.default_rng(seed).permutation(idxs)

    idx_train = idxs[:train_size]
    idx_val = idxs[train_size : train_size + val_size]
    idx_test = idxs[train_size + val_size : total]

    if order is not None:
        idx_train = [order[i] for i in idx_train]
        idx_val = [order[i] for i in idx_val]
        idx_test = [order[i] for i in idx_test]

    return np.array(idx_train), np.array(idx_val), np.array(idx_test)


def make_splits_train_val_test(
    dataset_len,
    train_size,
    val_size,
    test_size,
    seed,
    filename=None,
    splits=None,
    order=None,
):
    if splits is not None:
        splits = np.load(splits)
        idx_train = splits["idx_train"]
        idx_val = splits["idx_val"]
        idx_test = splits["idx_test"]
    else:
        idx_train, idx_val, idx_test = train_val_test_split(
            dataset_len, train_size, val_size, test_size, seed, order
        )

    if filename is not None:
        np.savez(filename, idx_train=idx_train, idx_val=idx_val, idx_test=idx_test)

    return (
        torch.from_numpy(idx_train),
        torch.from_numpy(idx_val),
        torch.from_numpy(idx_test),
    )


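# Usage sketch (the numbers and the "splits.npz" path are hypothetical): sizes
# may be given as floats (fractions of the dataset) or ints (sample counts),
# and at most one of them may be None to absorb the remainder. Saving to an
# .npz file lets a later run reload the identical partition via `splits=`, in
# which case the size arguments are ignored.
#
#     idx_train, idx_val, idx_test = make_splits_train_val_test(
#         1000, 0.8, 0.1, None, seed=0, filename="splits.npz"
#     )
#     idx_train, idx_val, idx_test = make_splits_train_val_test(
#         1000, None, None, None, seed=0, splits="splits.npz"
#     )

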
def train_val_split(dset_len, train_size, val_size, seed, order=None):
    assert (train_size is None) + (
        val_size is None
    ) <= 1, "Only one of train_size and val_size is allowed to be None."
    is_float = (
        isinstance(train_size, float),
        isinstance(val_size, float),
    )

    train_size = round(dset_len * train_size) if is_float[0] else train_size
    val_size = round(dset_len * val_size) if is_float[1] else val_size

    if train_size is None:
        train_size = dset_len - val_size
    elif val_size is None:
        val_size = dset_len - train_size

    if train_size + val_size > dset_len:
        if is_float[1]:
            val_size -= 1
        elif is_float[0]:
            train_size -= 1

    assert train_size >= 0 and val_size >= 0, (
        f"One of training ({train_size}) or validation ({val_size}) "
        f"splits ended up with a negative size."
    )

    total = train_size + val_size
    assert dset_len >= total, (
        f"The dataset ({dset_len}) is smaller than the "
        f"combined split sizes ({total})."
    )
    if total < dset_len:
        warnings.warn(f"{dset_len - total} samples were excluded from the dataset")

    idxs = np.arange(dset_len, dtype=int)
    if order is None:
        idxs = np.random.default_rng(seed).permutation(idxs)

    idx_train = idxs[:train_size]
    idx_val = idxs[train_size:total]

    if order is not None:
        idx_train = [order[i] for i in idx_train]
        idx_val = [order[i] for i in idx_val]

    return np.array(idx_train), np.array(idx_val)


def make_splits_train_val(
    dset,
    train_size,
    val_size,
    seed,
    batch_size=48,
    filename=None,
    splits=None,
    order=None,
):
    dset_len = len(dset)
    if splits is not None:
        splits = np.load(splits)
        idx_train = splits["idx_train"]
        idx_val = splits["idx_val"]
    else:
        idx_train, idx_val = train_val_split(dset_len, train_size, val_size, seed, order)

    if filename is not None:
        np.savez(filename, idx_train=idx_train, idx_val=idx_val)

    return (
        torch.from_numpy(idx_train),
        torch.from_numpy(idx_val),
    )


def make_splits_train_val_by_anno(
    dset,
    train_size,
    val_size,
    seed,
    batch_size=48,
    filename=None,
    splits=None,
    order=None,
):
    dset_len = len(dset)
    if splits is not None:
        splits = np.load(splits)
        idx_train = splits["idx_train"]
        idx_val = splits["idx_val"]
    else:
        # Use the split annotation stored in the dataset itself.
        idx_train = np.where(dset.data["split"] == "train")[0]
        idx_val = np.where(dset.data["split"] == "val")[0]

    if filename is not None:
        np.savez(filename, idx_train=idx_train, idx_val=idx_val)

    return (
        torch.from_numpy(idx_train),
        torch.from_numpy(idx_val),
    )


def train_val_split_by_uniprot_id(dset, train_size, val_size, seed, batch_size=48, order=None):
    assert (train_size is None) + (
        val_size is None
    ) <= 1, "Only one of train_size and val_size is allowed to be None."
    is_float = (
        isinstance(train_size, float),
        isinstance(val_size, float),
    )
    dset_len: int = len(dset)

    train_size = round(dset_len * train_size) if is_float[0] else train_size
    val_size = round(dset_len * val_size) if is_float[1] else val_size

    if train_size is None:
        train_size = dset_len - val_size
    elif val_size is None:
        val_size = dset_len - train_size

    if train_size + val_size > dset_len:
        if is_float[1]:
            val_size -= 1
        elif is_float[0]:
            train_size -= 1

    assert train_size >= 0 and val_size >= 0, (
        f"One of training ({train_size}) or validation ({val_size}) "
        f"splits ended up with a negative size."
    )

    total = train_size + val_size
    assert dset_len >= total, (
        f"The dataset ({dset_len}) is smaller than the "
        f"combined split sizes ({total})."
    )
    if total < dset_len:
        warnings.warn(f"{dset_len - total} samples were excluded from the dataset")

    # Assign whole proteins to the validation split so that no uniprotID
    # appears in both training and validation.
    uniprot_freq_table = dset.data.uniprotID.value_counts()
    selected_val_uniprotIDs, _ = select_by_uniprot(uniprot_freq_table, val_size)

    idxs = np.arange(dset_len, dtype=int)
    idx_train = idxs[np.isin(dset.data.uniprotID, selected_val_uniprotIDs, invert=True)]
    idx_val = idxs[np.isin(dset.data.uniprotID, selected_val_uniprotIDs)]

    if order is None:
        idx_train = np.random.default_rng(seed).permutation(idx_train)
        idx_val = np.random.default_rng(seed).permutation(idx_val)
        idx_train = guarantee_no_same_protein_in_one_batch(
            idx_train, batch_size, np.array(dset.data.uniprotID.iloc[idx_train] + dset.data.ENST.iloc[idx_train])
        )
        idx_val = guarantee_no_same_protein_in_one_batch(
            idx_val, batch_size, np.array(dset.data.uniprotID.iloc[idx_val] + dset.data.ENST.iloc[idx_val])
        )
    else:
        idx_train = [order[i] for i in idx_train]
        idx_val = [order[i] for i in idx_val]

    return np.array(idx_train), np.array(idx_val)


def make_splits_train_val_by_uniprot_id(
    dset,
    train_size,
    val_size,
    seed,
    batch_size=48,
    filename=None,
    splits=None,
    order=None,
):
    if splits is not None:
        splits = np.load(splits)
        idx_train = splits["idx_train"]
        idx_val = splits["idx_val"]
    else:
        idx_train, idx_val = train_val_split_by_uniprot_id(
            dset, train_size, val_size, seed, batch_size, order
        )

    if filename is not None:
        np.savez(filename, idx_train=idx_train, idx_val=idx_val)

    return (
        torch.from_numpy(idx_train),
        torch.from_numpy(idx_val),
    )


def train_val_split_by_good_batch(dset, train_size, val_size, seed, batch_size=48, order=None):
    assert (train_size is None) + (
        val_size is None
    ) <= 1, "Only one of train_size and val_size is allowed to be None."
    is_float = (
        isinstance(train_size, float),
        isinstance(val_size, float),
    )
    dset_len: int = len(dset)

    train_size = round(dset_len * train_size) if is_float[0] else train_size
    val_size = round(dset_len * val_size) if is_float[1] else val_size

    if train_size is None:
        train_size = dset_len - val_size
    elif val_size is None:
        val_size = dset_len - train_size

    if train_size + val_size > dset_len:
        if is_float[1]:
            val_size -= 1
        elif is_float[0]:
            train_size -= 1

    assert train_size >= 0 and val_size >= 0, (
        f"One of training ({train_size}) or validation ({val_size}) "
        f"splits ended up with a negative size."
    )

    total = train_size + val_size
    assert dset_len >= total, (
        f"The dataset ({dset_len}) is smaller than the "
        f"combined split sizes ({total})."
    )
    if total < dset_len:
        warnings.warn(f"{dset_len - total} samples were excluded from the dataset")

    idxs = np.arange(dset_len, dtype=int)
    if order is None:
        idxs = np.random.default_rng(seed).permutation(idxs)

    idx_train = idxs[:train_size]
    idx_val = idxs[train_size:total]

    if order is None:
        idx_train = np.random.default_rng(seed).permutation(idx_train)
        idx_val = np.random.default_rng(seed).permutation(idx_val)
        idx_train = guarantee_good_batch(idx_train, batch_size, dset)
    else:
        idx_train = [order[i] for i in idx_train]
        idx_val = [order[i] for i in idx_val]

    return np.array(idx_train), np.array(idx_val)


def make_splits_train_val_by_good_batch(
    dset,
    train_size,
    val_size,
    seed,
    batch_size=48,
    filename=None,
    splits=None,
    order=None,
):
    if splits is not None:
        splits = np.load(splits)
        idx_train = splits["idx_train"]
        idx_val = splits["idx_val"]
    else:
        idx_train, idx_val = train_val_split_by_good_batch(
            dset, train_size, val_size, seed, batch_size, order
        )

    if filename is not None:
        np.savez(filename, idx_train=idx_train, idx_val=idx_val)

    return (
        torch.from_numpy(idx_train),
        torch.from_numpy(idx_val),
    )


def train_val_split_by_uniprot_id_good_batch(dset, train_size, val_size, seed, batch_size=48, order=None):
    assert (train_size is None) + (
        val_size is None
    ) <= 1, "Only one of train_size and val_size is allowed to be None."
    is_float = (
        isinstance(train_size, float),
        isinstance(val_size, float),
    )
    dset_len: int = len(dset)

    train_size = round(dset_len * train_size) if is_float[0] else train_size
    val_size = round(dset_len * val_size) if is_float[1] else val_size

    if train_size is None:
        train_size = dset_len - val_size
    elif val_size is None:
        val_size = dset_len - train_size

    if train_size + val_size > dset_len:
        if is_float[1]:
            val_size -= 1
        elif is_float[0]:
            train_size -= 1

    assert train_size >= 0 and val_size >= 0, (
        f"One of training ({train_size}) or validation ({val_size}) "
        f"splits ended up with a negative size."
    )

    total = train_size + val_size
    assert dset_len >= total, (
        f"The dataset ({dset_len}) is smaller than the "
        f"combined split sizes ({total})."
    )
    if total < dset_len:
        warnings.warn(f"{dset_len - total} samples were excluded from the dataset")

    # Keep whole proteins together in the validation split.
    uniprot_freq_table = dset.data.uniprotID.value_counts()
    selected_val_uniprotIDs, _ = select_by_uniprot(uniprot_freq_table, val_size)

    idxs = np.arange(dset_len, dtype=int)
    idx_train = idxs[np.isin(dset.data.uniprotID, selected_val_uniprotIDs, invert=True)]
    idx_val = idxs[np.isin(dset.data.uniprotID, selected_val_uniprotIDs)]

    if order is None:
        idx_train = np.random.default_rng(seed).permutation(idx_train)
        idx_val = np.random.default_rng(seed).permutation(idx_val)
        idx_train = guarantee_good_batch(idx_train, batch_size, dset)
    else:
        idx_train = [order[i] for i in idx_train]
        idx_val = [order[i] for i in idx_val]

    return np.array(idx_train), np.array(idx_val)


def make_splits_train_val_by_uniprot_id_good_batch(
    dset,
    train_size,
    val_size,
    seed,
    batch_size=48,
    filename=None,
    splits=None,
    order=None,
):
    if splits is not None:
        splits = np.load(splits)
        idx_train = splits["idx_train"]
        idx_val = splits["idx_val"]
    else:
        idx_train, idx_val = train_val_split_by_uniprot_id_good_batch(
            dset, train_size, val_size, seed, batch_size, order
        )

    if filename is not None:
        np.savez(filename, idx_train=idx_train, idx_val=idx_val)

    return (
        torch.from_numpy(idx_train),
        torch.from_numpy(idx_val),
    )


def train_val_test_split_by_uniprot_id(dset, train_size, val_size, test_size, seed, batch_size=48, order=None):
    assert (train_size is None) + (val_size is None) + (
        test_size is None
    ) <= 1, "Only one of train_size, val_size, test_size is allowed to be None."
    is_float = (
        isinstance(train_size, float),
        isinstance(val_size, float),
        isinstance(test_size, float),
    )
    dset_len: int = len(dset)
    train_size = round(dset_len * train_size) if is_float[0] else train_size
    val_size = round(dset_len * val_size) if is_float[1] else val_size
    test_size = round(dset_len * test_size) if is_float[2] else test_size

    if train_size is None:
        train_size = dset_len - val_size - test_size
    elif val_size is None:
        val_size = dset_len - train_size - test_size
    elif test_size is None:
        test_size = dset_len - train_size - val_size

    if train_size + val_size + test_size > dset_len:
        if is_float[2]:
            test_size -= 1
        elif is_float[1]:
            val_size -= 1
        elif is_float[0]:
            train_size -= 1

    assert train_size >= 0 and val_size >= 0 and test_size >= 0, (
        f"One of training ({train_size}), validation ({val_size}) or "
        f"testing ({test_size}) splits ended up with a negative size."
    )

    total = train_size + val_size + test_size
    assert dset_len >= total, (
        f"The dataset ({dset_len}) is smaller than the "
        f"combined split sizes ({total})."
    )
    if total < dset_len:
        warnings.warn(f"{dset_len - total} samples were excluded from the dataset")

    # Assign whole proteins to the test and validation splits so that no
    # uniprotID is shared between any two splits.
    uniprot_freq_table = dset.data.uniprotID.value_counts()
    selected_test_uniprotIDs, uniprot_freq_table = select_by_uniprot(uniprot_freq_table, test_size)
    selected_val_uniprotIDs, uniprot_freq_table = select_by_uniprot(uniprot_freq_table, val_size)

    idxs = np.arange(dset_len, dtype=int)

    idx_test = idxs[np.isin(dset.data.uniprotID, selected_test_uniprotIDs)]
    idx_val = idxs[np.isin(dset.data.uniprotID, selected_val_uniprotIDs)]
    idx_train = idxs[np.isin(dset.data.uniprotID, selected_test_uniprotIDs + selected_val_uniprotIDs, invert=True)]

    if order is None:
        idx_train = np.random.default_rng(seed).permutation(idx_train)
        idx_val = np.random.default_rng(seed).permutation(idx_val)
        idx_test = np.random.default_rng(seed).permutation(idx_test)
        idx_train = guarantee_no_same_protein_in_one_batch(
            idx_train, batch_size, np.array(dset.data.uniprotID.iloc[idx_train] + dset.data.ENST.iloc[idx_train])
        )
        idx_val = guarantee_no_same_protein_in_one_batch(
            idx_val, batch_size, np.array(dset.data.uniprotID.iloc[idx_val] + dset.data.ENST.iloc[idx_val])
        )
        idx_test = guarantee_no_same_protein_in_one_batch(
            idx_test, batch_size, np.array(dset.data.uniprotID.iloc[idx_test] + dset.data.ENST.iloc[idx_test])
        )
    else:
        idx_train = [order[i] for i in idx_train]
        idx_val = [order[i] for i in idx_val]
        idx_test = [order[i] for i in idx_test]

    return np.array(idx_train), np.array(idx_val), np.array(idx_test)


def make_splits_train_val_test_by_uniprot_id(
    dset,
    train_size,
    val_size,
    test_size,
    seed,
    batch_size=48,
    filename=None,
    splits=None,
    order=None,
):
    if splits is not None:
        splits = np.load(splits)
        idx_train = splits["idx_train"]
        idx_val = splits["idx_val"]
        idx_test = splits["idx_test"]
    else:
        idx_train, idx_val, idx_test = train_val_test_split_by_uniprot_id(
            dset, train_size, val_size, test_size, seed, batch_size, order
        )

    if filename is not None:
        np.savez(filename, idx_train=idx_train, idx_val=idx_val, idx_test=idx_test)

    return (
        torch.from_numpy(idx_train),
        torch.from_numpy(idx_val),
        torch.from_numpy(idx_test),
    )


def reshuffle_train_by_uniprot_id(idx_train, batch_size, dset, seed=None):
    idx_train = guarantee_no_same_protein_in_one_batch(
        idx_train, batch_size,
        np.array(dset.data.uniprotID.iloc[idx_train] + dset.data.ENST.iloc[idx_train]),
        seed=seed,
    )
    return idx_train


def reshuffle_train_by_good_batch(idx_train, batch_size, dset, seed=None):
    idx_train = guarantee_good_batch(idx_train, batch_size, dset, seed=seed)
    return idx_train


def reshuffle_train_by_uniprot_id_good_batch(idx_train, batch_size, dset, seed=None):
    return reshuffle_train_by_good_batch(idx_train, batch_size, dset, seed)


def reshuffle_train_by_anno(idx_train, batch_size, dset, seed=None):
    if seed is not None:
        idx_train = np.random.default_rng(seed).permutation(idx_train)
    else:
        idx_train = np.random.permutation(idx_train)
    return idx_train


def reshuffle_train(idx_train, batch_size, dset, seed=None):
    if seed is not None:
        idx_train = np.random.default_rng(seed).permutation(idx_train)
    else:
        idx_train = np.random.permutation(idx_train)
    return idx_train


def select_by_uniprot(freq_table, number_to_select):
    """Greedily pick uniprotIDs whose summed counts approach number_to_select.

    Only proteins whose entire group still fits in the remaining budget are
    candidates, so a protein is never split across the selection boundary.
    Returns the selected IDs and the frequency table with them removed.
    """
    selected = 0
    selected_uniprotIDs = []
    candidates = freq_table[freq_table <= number_to_select - selected]
    while selected < number_to_select and len(candidates) > 0:
        selected_uniprotID = np.random.choice(candidates.index)
        selected_uniprotIDs.append(selected_uniprotID)
        selected += freq_table[selected_uniprotID]

        freq_table = freq_table.drop(selected_uniprotID)
        candidates = freq_table[freq_table <= number_to_select - selected]
    return selected_uniprotIDs, freq_table


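# Worked example for select_by_uniprot (hypothetical counts): with a frequency
# table of {"P1": 5, "P2": 3, "P3": 2} and number_to_select=4, only proteins
# whose whole group fits the remaining budget are candidates. Picking "P3"
# first leaves a budget of 2 with no candidate small enough, so selection
# stops at 2 samples; picking "P2" first stops at 3. The target is therefore
# an upper bound that may be undershot, but a protein is never split.
#
#     import pandas as pd
#     freq = pd.Series({"P1": 5, "P2": 3, "P3": 2})
#     chosen, remaining = select_by_uniprot(freq, 4)
#     # chosen sums to at most 4, e.g. ["P2"] or ["P3"]; those IDs are
#     # dropped from `remaining`.

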
def guarantee_no_same_protein_in_one_batch(idxs, batch_size, protein_identifiers, seed=0):
    """Reorder idxs so that each consecutive chunk of batch_size indices
    contains batch_size distinct protein identifiers. Leftover indices that
    cannot form such a chunk are appended at the end unchanged.
    """
    assert len(idxs) == len(protein_identifiers)
    if seed is not None:
        np.random.seed(seed)

    result = []
    while len(protein_identifiers) >= batch_size:
        unique_protein_identifiers, first_idx, counts = np.unique(
            protein_identifiers, return_index=True, return_counts=True
        )
        if len(unique_protein_identifiers) < batch_size:
            break
        # Sample batch_size distinct proteins, weighted by how many rows they
        # still have, and take the first remaining row of each.
        unique_random_selected = np.random.choice(
            unique_protein_identifiers, size=batch_size, replace=False, p=counts / sum(counts)
        )
        unique_random_selected_idx = first_idx[
            [np.argwhere(unique_protein_identifiers == i)[0][0] for i in unique_random_selected]
        ]
        result += list(idxs[unique_random_selected_idx])

        idxs = np.delete(idxs, unique_random_selected_idx)
        protein_identifiers = np.delete(protein_identifiers, unique_random_selected_idx)

    if len(idxs) > 0:
        result += list(idxs)
    return np.array(result)


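# The reordered indices are only useful if the data loader preserves this
# order, so each consecutive chunk of `batch_size` indices becomes one batch.
# A minimal consumption sketch under that assumption (Subset/DataLoader usage
# is illustrative, not something this module enforces):
#
#     from torch.utils.data import DataLoader, Subset
#     loader = DataLoader(Subset(dset, idx_train.tolist()),
#                         batch_size=batch_size, shuffle=False)

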
def guarantee_good_batch_not_same_len(idxs, batch_size, dset, seed=0):
    """Build a training order where every block of batch_size indices starts
    with a positive example, a negative from the same protein when one exists,
    plus a positive and a negative from other proteins; the remaining slots of
    the block are filled with randomly chosen indices. Because companions are
    sampled without removal, indices may appear more than once, so the result
    is generally longer than idxs. Any index never used is appended at the end.
    """
    assert batch_size >= 4
    if seed is not None:
        np.random.seed(seed)

    pos_idxs = idxs[dset.data[dset._y_columns[0]].iloc[idxs] == 1]
    pos_ids = dset.data["uniprotID"].iloc[pos_idxs].to_numpy()
    neg_idxs = idxs[dset.data[dset._y_columns[0]].iloc[idxs] != 1]
    neg_ids = dset.data["uniprotID"].iloc[neg_idxs].to_numpy()

    result = []

    for i in range(len(pos_idxs)):
        uid = pos_ids[i]
        result.append(pos_idxs[i])

        # A negative variant of the same protein, if the protein has one.
        neg_idx = neg_idxs[neg_ids == uid]
        if len(neg_idx) > 0:
            result.append(np.random.choice(neg_idx))
        else:
            result.append(np.random.choice(neg_idxs))

        # A positive variant of a different protein.
        pos_idx = pos_idxs[pos_ids != uid]
        if len(pos_idx) == 0:
            pos_idx = pos_idxs[pos_idxs != pos_idxs[i]]
        result.append(np.random.choice(pos_idx))

        # A negative variant of a different protein.
        neg_idx = neg_idxs[neg_ids != uid]
        if len(neg_idx) == 0:
            neg_idx = neg_idxs[neg_idxs != neg_idxs[i]]
        result.append(np.random.choice(neg_idx))

        # Fill the rest of the block with random indices.
        if batch_size > 4:
            result += list(np.random.choice(idxs, size=batch_size - 4, replace=False))

    unused_idxs = np.setdiff1d(idxs, result)
    if len(unused_idxs) > 0:
        result += list(unused_idxs)
    return np.array(result)


def guarantee_good_batch(idxs, batch_size, dset, seed=0):
    """Like guarantee_good_batch_not_same_len, but removes every selected
    index from the pools so the result is a reordering of idxs without
    duplicates. Each block of batch_size indices is seeded with a positive
    example, a negative from the same protein when available, a positive and
    a negative from other proteins, and then padded with random remaining
    indices. The positive label is 1 when present, otherwise 3.
    """
    assert batch_size >= 4
    if seed is not None:
        np.random.seed(seed)
    if not isinstance(idxs, np.ndarray):
        idxs = np.array(idxs)

    pos_label = 3
    if sum(dset.data[dset._y_columns[0]].iloc[idxs] == 1) > 0:
        pos_label = 1
    pos_idxs = idxs[dset.data[dset._y_columns[0]].iloc[idxs] == pos_label]
    pos_ids = dset.data["uniprotID"].iloc[pos_idxs].to_numpy()
    neg_idxs = idxs[dset.data[dset._y_columns[0]].iloc[idxs] != pos_label]
    neg_ids = dset.data["uniprotID"].iloc[neg_idxs].to_numpy()

    result = []

    while len(pos_idxs) > 0:
        this_batch_added = 0

        # Anchor the block with a random positive example.
        pos_idx = np.random.choice(pos_idxs)
        uid = pos_ids[pos_idxs == pos_idx][0]
        result.append(pos_idx)
        this_batch_added += 1

        pos_ids = np.delete(pos_ids, np.argwhere(pos_idxs == pos_idx))
        idxs = np.delete(idxs, np.argwhere(idxs == pos_idx))
        pos_idxs = np.delete(pos_idxs, np.argwhere(pos_idxs == pos_idx))

        # A negative from the same protein if possible, otherwise any negative.
        neg_idx = neg_idxs[neg_ids == uid]
        if len(neg_idx) > 0:
            neg_idx = np.random.choice(neg_idx)
        elif len(neg_idxs) > 0:
            neg_idx = np.random.choice(neg_idxs)
        else:
            neg_idx = None
        if neg_idx is not None:
            result.append(neg_idx)
            this_batch_added += 1

            neg_ids = np.delete(neg_ids, np.argwhere(neg_idxs == neg_idx))
            idxs = np.delete(idxs, np.argwhere(idxs == neg_idx))
            neg_idxs = np.delete(neg_idxs, np.argwhere(neg_idxs == neg_idx))

        # A positive from a different protein.
        pos_idx_candidate = pos_idxs[pos_ids != uid]
        if len(pos_idx_candidate) == 0:
            pos_idx_candidate = pos_idxs[pos_idxs != pos_idx]
        if len(pos_idx_candidate) > 0:
            pos_idx = np.random.choice(pos_idx_candidate)
            result.append(pos_idx)
            this_batch_added += 1

            pos_ids = np.delete(pos_ids, np.argwhere(pos_idxs == pos_idx))
            idxs = np.delete(idxs, np.argwhere(idxs == pos_idx))
            pos_idxs = np.delete(pos_idxs, np.argwhere(pos_idxs == pos_idx))

        # A negative from a different protein.
        neg_idx_candidate = neg_idxs[neg_ids != uid]
        if len(neg_idx_candidate) == 0:
            neg_idx_candidate = neg_idxs[neg_idxs != neg_idx]
        if len(neg_idx_candidate) > 0:
            neg_idx = np.random.choice(neg_idx_candidate)
            result.append(neg_idx)
            this_batch_added += 1

            neg_ids = np.delete(neg_ids, np.argwhere(neg_idxs == neg_idx))
            idxs = np.delete(idxs, np.argwhere(idxs == neg_idx))
            neg_idxs = np.delete(neg_idxs, np.argwhere(neg_idxs == neg_idx))

        # Pad the block with random remaining indices and drop them from the pools.
        if batch_size > this_batch_added and len(idxs) >= batch_size - this_batch_added:
            to_add = np.random.choice(idxs, size=batch_size - this_batch_added, replace=False)
            result += list(to_add)

            for idx in to_add:
                pos_ids = np.delete(pos_ids, np.argwhere(pos_idxs == idx))
                pos_idxs = np.delete(pos_idxs, np.argwhere(pos_idxs == idx))
                neg_ids = np.delete(neg_ids, np.argwhere(neg_idxs == idx))
                neg_idxs = np.delete(neg_idxs, np.argwhere(neg_idxs == idx))
                idxs = np.delete(idxs, np.argwhere(idxs == idx))

    unused_idxs = np.setdiff1d(idxs, result)
    if len(unused_idxs) > 0:
        result += list(unused_idxs)
    return np.array(result)


def save_argparse(args, filename, exclude=None):
    import json

    if filename.endswith("yaml") or filename.endswith("yml"):
        if isinstance(exclude, str):
            exclude = [exclude]
        args = args.__dict__.copy()
        for exl in exclude or []:
            args.pop(exl, None)

        ds_arg = args.get("dataset_arg")
        if ds_arg is not None and isinstance(ds_arg, str):
            args["dataset_arg"] = json.loads(ds_arg)
        with open(filename, "w") as f:
            yaml.dump(args, f)
    else:
        raise ValueError("Configuration file must end with yaml or yml")
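

# Round-trip sketch (the "input.yaml" path and the "conf" option name are
# hypothetical): save_argparse writes the parsed arguments back to YAML, and a
# later run can feed that file to the parser through LoadFromFile. Excluding
# the config option itself avoids trying to serialize an open file handle.
#
#     args = parser.parse_args()
#     save_argparse(args, "input.yaml", exclude=["conf"])
#     # later: python train.py --conf input.yaml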