Fill-Mask
Transformers
PyTorch
esm
Inference Endpoints
FusOn-pLM / fuson_plm /utils /splitting.py
svincoff's picture
adding utility files used throughout FusOn-pLM training and benchmarking
ffaff91
raw
history blame
12.1 kB
import pandas as pd
from sklearn.model_selection import train_test_split
from fuson_plm.utils.logging import log_update
def split_clusters_train_test(X, y, benchmark_cluster_reps=None, random_state = 1, test_size = 0.20):
    """
    Split cluster representatives into train and test sets.

    Args:
        X (list): cluster representative IDs to split; must NOT already contain
            the benchmark cluster representatives (those are appended to test).
        y (list): dummy labels the same length as X (train_test_split requires a y).
        benchmark_cluster_reps (list): representatives of clusters holding benchmark
            sequences; forced into the test set so benchmark data never leaks into train.
        random_state (int): seed for a reproducible split.
        test_size (float): fraction of clusters assigned to the test set.

    Returns:
        dict: {'X_train': train cluster reps, 'X_test': test cluster reps (incl. benchmark reps)}
    """
    # Use a None sentinel instead of a mutable default argument ([]) — a shared
    # default list would persist across calls.
    if benchmark_cluster_reps is None:
        benchmark_cluster_reps = []
    # cluster with random state fixed for reproducible results
    log_update(f"\tPerforming split: all clusters -> train clusters ({round(1-test_size,3)}) and test clusters ({test_size})")
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=random_state)
    # add benchmark representatives back to X_test
    log_update(f"\tManually adding {len(benchmark_cluster_reps)} clusters containing benchmark seqs into X_test")
    X_test += benchmark_cluster_reps
    # assert no duplicates within the train or test sets (there shouldn't be, if the input data was clean)
    assert len(X_train)==len(set(X_train))
    assert len(X_test)==len(set(X_test))
    return {
        'X_train': X_train,
        'X_test': X_test
    }
def split_clusters_train_val_test(X, y, benchmark_cluster_reps=None, random_state_1 = 1, random_state_2 = 1, test_size_1 = 0.20, test_size_2 = 0.50):
    """
    Split cluster representatives into train, val, and test sets via two splits.

    First split: all -> train vs. other (test_size_1 fraction to "other").
    Second split: other -> val vs. test (test_size_2 fraction of "other" to test).

    Args:
        X (list): cluster representative IDs to split; must NOT already contain
            the benchmark cluster representatives (those are appended to test).
        y (list): dummy labels the same length as X (train_test_split requires a y).
        benchmark_cluster_reps (list): representatives of clusters holding benchmark
            sequences; forced into the test set so benchmark data never leaks.
        random_state_1 (int): seed for the first (train vs. other) split.
        random_state_2 (int): seed for the second (val vs. test) split.
        test_size_1 (float): fraction of all clusters held out of train.
        test_size_2 (float): fraction of the held-out clusters assigned to test.

    Returns:
        dict: {'X_train': ..., 'X_val': ..., 'X_test': ... (incl. benchmark reps)}
    """
    # Use a None sentinel instead of a mutable default argument ([]) — a shared
    # default list would persist across calls.
    if benchmark_cluster_reps is None:
        benchmark_cluster_reps = []
    # cluster with random state fixed for reproducible results
    log_update(f"\tPerforming first split: all clusters -> train clusters ({round(1-test_size_1,3)}) and other ({test_size_1})")
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size_1, random_state=random_state_1)
    log_update(f"\tPerforming second split: other -> val clusters ({round(1-test_size_2,3)}) and test clusters ({test_size_2})")
    X_val, X_test, y_val, y_test = train_test_split(X_test, y_test, test_size=test_size_2, random_state=random_state_2)
    # add benchmark representatives back to X_test
    log_update(f"\tManually adding {len(benchmark_cluster_reps)} clusters containing benchmark seqs into X_test")
    X_test += benchmark_cluster_reps
    # assert no duplicates within the train, test, or val sets (there shouldn't be, if the input data was clean)
    assert len(X_train)==len(set(X_train))
    assert len(X_val)==len(set(X_val))
    assert len(X_test)==len(set(X_test))
    return {
        'X_train': X_train,
        'X_val': X_val,
        'X_test': X_test
    }
def split_clusters(cluster_representatives: list, val_set = True, benchmark_cluster_reps=None, random_state_1 = 1, random_state_2 = 1, test_size_1 = 0.20, test_size_2 = 0.50):
    """
    Cluster-splitting method amenable to either train-test or train-val-test.
    For train-val-test, there are two splits.

    Args:
        cluster_representatives (list): all cluster representative IDs, possibly
            including benchmark cluster reps (filtered out before splitting and
            added back to the test set by the callee).
        val_set (bool): if True, produce train/val/test; otherwise train/test.
        benchmark_cluster_reps (list): reps of clusters containing benchmark seqs.
        random_state_1 (int): seed for the first split.
        random_state_2 (int): seed for the second split (val/test; ignored if val_set=False).
        test_size_1 (float): fraction held out of train in the first split.
        test_size_2 (float): fraction of the held-out set assigned to test (ignored if val_set=False).

    Returns:
        dict: keys 'X_train', 'X_test' (and 'X_val' when val_set is True).
    """
    # Use a None sentinel instead of a mutable default argument ([]).
    if benchmark_cluster_reps is None:
        benchmark_cluster_reps = []
    log_update("\nPerforming splits...")
    # Approx. 80/10/10 split
    # X, for splitting, does NOT include benchmark reps. We'll add these clusters to test.
    # Use a set for O(1) membership tests instead of scanning the list per element.
    benchmark_rep_set = set(benchmark_cluster_reps)
    X = [x for x in cluster_representatives if x not in benchmark_rep_set]
    y = [0]*len(X) # y is a dummy array here; there are no values.
    if val_set:
        return split_clusters_train_val_test(X, y, benchmark_cluster_reps=benchmark_cluster_reps,
                                             random_state_1 = random_state_1, random_state_2 = random_state_2,
                                             test_size_1 = test_size_1, test_size_2 = test_size_2)
    return split_clusters_train_test(X, y, benchmark_cluster_reps=benchmark_cluster_reps,
                                     random_state = random_state_1,
                                     test_size = test_size_1)
def check_split_validity(train_clusters, val_clusters, test_clusters, benchmark_sequences=None):
    """
    Validate a cluster-based data split and log summary statistics.

    Each non-None DataFrame must have columns 'representative seq_id',
    'member seq_id', and 'member seq' (one row per cluster member).

    Checks performed (each failure raises AssertionError):
      - grouped cluster counts agree with the number of unique representatives;
      - per-set protein counts agree with the grouped member counts;
      - no sequence overlap between train/val, train/test, or val/test;
      - if benchmark_sequences is given: none appear in train or val, and ALL
        of them appear in test.

    Args:
        train_clusters (pd.DataFrame): train-set cluster membership table.
        val_clusters (pd.DataFrame): (optional - can pass None if there is no validation set)
        test_clusters (pd.DataFrame): test-set cluster membership table (may be None).
        benchmark_sequences (iterable): sequences that must be confined to test.
    """
    # Make grouped versions of these DataFrames for size analysis
    # (one row per cluster, with 'member count' = number of members in that cluster)
    train_clustersgb = train_clusters.groupby('representative seq_id')['member seq_id'].count().reset_index().rename(columns={'member seq_id':'member count'})
    if val_clusters is not None:
        val_clustersgb = val_clusters.groupby('representative seq_id')['member seq_id'].count().reset_index().rename(columns={'member seq_id':'member count'})
    if test_clusters is not None:
        test_clustersgb = test_clusters.groupby('representative seq_id')['member seq_id'].count().reset_index().rename(columns={'member seq_id':'member count'})
    # Calculate stats - clusters
    n_train_clusters = len(train_clustersgb)
    n_val_clusters, n_test_clusters = 0, 0  # default 0 so totals are correct when a set is absent
    if val_clusters is not None:
        n_val_clusters = len(val_clustersgb)
    if test_clusters is not None:
        n_test_clusters = len(test_clustersgb)
    n_clusters = n_train_clusters + n_val_clusters + n_test_clusters
    # Each unique representative should correspond to exactly one grouped row
    assert len(train_clusters['representative seq_id'].unique()) == len(train_clustersgb)
    if val_clusters is not None:
        assert len(val_clusters['representative seq_id'].unique()) == len(val_clustersgb)
    if test_clusters is not None:
        assert len(test_clusters['representative seq_id'].unique()) == len(test_clustersgb)
    train_cluster_pcnt = round(100*n_train_clusters/n_clusters,2)
    if val_clusters is not None:
        val_cluster_pcnt = round(100*n_val_clusters/n_clusters,2)
    if test_clusters is not None:
        test_cluster_pcnt = round(100*n_test_clusters/n_clusters,2)
    # Calculate stats - proteins (one row per cluster member = one protein)
    n_train_proteins = len(train_clusters)
    n_val_proteins, n_test_proteins = 0, 0
    if val_clusters is not None:
        n_val_proteins = len(val_clusters)
    if test_clusters is not None:
        n_test_proteins = len(test_clusters)
    n_proteins = n_train_proteins + n_val_proteins + n_test_proteins
    # Row totals must equal the sum of grouped member counts
    assert len(train_clusters) == sum(train_clustersgb['member count'])
    if val_clusters is not None:
        assert len(val_clusters) == sum(val_clustersgb['member count'])
    if test_clusters is not None:
        assert len(test_clusters) == sum(test_clustersgb['member count'])
    train_protein_pcnt = round(100*n_train_proteins/n_proteins,2)
    if val_clusters is not None:
        val_protein_pcnt = round(100*n_val_proteins/n_proteins,2)
    if test_clusters is not None:
        test_protein_pcnt = round(100*n_test_proteins/n_proteins,2)
    # Print results
    log_update("\nCluster breakdown...")
    log_update(f"Total clusters = {n_clusters}, total proteins = {n_proteins}")
    log_update(f"\tTrain set:\n\t\tTotal Clusters = {len(train_clustersgb)} ({train_cluster_pcnt}%)\n\t\tTotal Proteins = {len(train_clusters)} ({train_protein_pcnt}%)")
    if val_clusters is not None:
        log_update(f"\tVal set:\n\t\tTotal Clusters = {len(val_clustersgb)} ({val_cluster_pcnt}%)\n\t\tTotal Proteins = {len(val_clusters)} ({val_protein_pcnt}%)")
    if test_clusters is not None:
        log_update(f"\tTest set:\n\t\tTotal Clusters = {len(test_clustersgb)} ({test_cluster_pcnt}%)\n\t\tTotal Proteins = {len(test_clusters)} ({test_protein_pcnt}%)")
    # Check for overlap in both sequence ID and sequence actual
    train_protein_ids = set(train_clusters['member seq_id'])
    train_protein_seqs = set(train_clusters['member seq'])
    if val_clusters is not None:
        val_protein_ids = set(val_clusters['member seq_id'])
        val_protein_seqs = set(val_clusters['member seq'])
    if test_clusters is not None:
        test_protein_ids = set(test_clusters['member seq_id'])
        test_protein_seqs = set(test_clusters['member seq'])
    # Print results (which pairs to report depends on which sets exist)
    log_update("\nChecking for overlap...")
    if (val_clusters is not None) and (test_clusters is not None):
        log_update(f"\tSequence IDs...\n\t\tTrain-Val Overlap: {len(train_protein_ids.intersection(val_protein_ids))}\n\t\tTrain-Test Overlap: {len(train_protein_ids.intersection(test_protein_ids))}\n\t\tVal-Test Overlap: {len(val_protein_ids.intersection(test_protein_ids))}")
        log_update(f"\tSequences...\n\t\tTrain-Val Overlap: {len(train_protein_seqs.intersection(val_protein_seqs))}\n\t\tTrain-Test Overlap: {len(train_protein_seqs.intersection(test_protein_seqs))}\n\t\tVal-Test Overlap: {len(val_protein_seqs.intersection(test_protein_seqs))}")
    if (val_clusters is not None) and (test_clusters is None):
        log_update(f"\tSequence IDs...\n\t\tTrain-Val Overlap: {len(train_protein_ids.intersection(val_protein_ids))}")
        log_update(f"\tSequences...\n\t\tTrain-Val Overlap: {len(train_protein_seqs.intersection(val_protein_seqs))}")
    if (val_clusters is None) and (test_clusters is not None):
        log_update(f"\tSequence IDs...\n\t\tTrain-Test Overlap: {len(train_protein_ids.intersection(test_protein_ids))}")
        log_update(f"\tSequences...\n\t\tTrain-Test Overlap: {len(train_protein_seqs.intersection(test_protein_seqs))}")
    # Assert no sequence overlap (overlap by ID is only logged, not asserted)
    if val_clusters is not None:
        assert len(train_protein_seqs.intersection(val_protein_seqs))==0
    if test_clusters is not None:
        assert len(train_protein_seqs.intersection(test_protein_seqs))==0
    if (val_clusters is not None) and (test_clusters is not None):
        assert len(val_protein_seqs.intersection(test_protein_seqs))==0
    # Finally, check that there are only benchmark sequences in test - if there are benchmark sequences
    if not(benchmark_sequences is None):
        bench_in_train = len(train_clusters.loc[train_clusters['member seq'].isin(benchmark_sequences)]['member seq'].unique())
        bench_in_val, bench_in_test = 0, 0
        if val_clusters is not None:
            bench_in_val = len(val_clusters.loc[val_clusters['member seq'].isin(benchmark_sequences)]['member seq'].unique())
        if test_clusters is not None:
            bench_in_test = len(test_clusters.loc[test_clusters['member seq'].isin(benchmark_sequences)]['member seq'].unique())
        # Assert this
        log_update("\nChecking for benchmark sequence presence in test, and absence from train and val...")
        log_update(f"\tTotal benchmark sequences: {len(benchmark_sequences)}")
        log_update(f"\tBenchmark sequences in train: {bench_in_train}")
        if val_clusters is not None:
            log_update(f"\tBenchmark sequences in val: {bench_in_val}")
        if test_clusters is not None:
            log_update(f"\tBenchmark sequences in test: {bench_in_test}")
        # NOTE(review): if test_clusters is None while benchmark_sequences is non-empty,
        # bench_in_test stays 0 and the second assert below will fail — confirm callers
        # always pass a test set when benchmark_sequences is provided.
        assert bench_in_train == bench_in_val == 0
        assert bench_in_test == len(benchmark_sequences)
def check_class_distributions(train_df, val_df, test_df, class_col='class'):
    """
    Checks class distributions within train, val, and test sets.
    Expects input dataframes to have 'sequence' column and 'class' column.

    Logs a side-by-side table of per-class counts and percentages, plus the
    absolute percentage differences between adjacent sets.

    Args:
        train_df (pd.DataFrame): train set with a class_col column.
        val_df (pd.DataFrame): val set (optional - pass None if there is no validation set).
        test_df (pd.DataFrame): test set with a class_col column.
        class_col (str): name of the class label column.
    """
    # Use rename_axis(...).reset_index(name=...) so the class column is named
    # correctly on every pandas version. (The previous rename of an 'index'
    # column only worked on pandas < 2.0, where value_counts().reset_index()
    # produced columns ['index', class_col]; pandas >= 2.0 produces
    # [class_col, 'count'] and the old rename mislabeled the columns.)
    train_vc = train_df[class_col].value_counts().rename_axis(class_col).reset_index(name='train_count')
    train_vc['train_pct'] = (train_vc['train_count'] / train_vc['train_count'].sum()).round(3)*100
    if val_df is not None:
        val_vc = val_df[class_col].value_counts().rename_axis(class_col).reset_index(name='val_count')
        val_vc['val_pct'] = (val_vc['val_count'] / val_vc['val_count'].sum()).round(3)*100
    test_vc = test_df[class_col].value_counts().rename_axis(class_col).reset_index(name='test_count')
    test_vc['test_pct'] = (test_vc['test_count'] / test_vc['test_count'].sum()).round(3)*100
    # concatenate so I can see them next to each other
    if val_df is not None:
        compare = pd.concat([train_vc, val_vc, test_vc], axis=1)
        compare['train-val diff'] = (compare['train_pct'] - compare['val_pct']).abs()
        compare['val-test diff'] = (compare['val_pct'] - compare['test_pct']).abs()
    else:
        compare = pd.concat([train_vc, test_vc], axis=1)
        compare['train-test diff'] = (compare['train_pct'] - compare['test_pct']).abs()
    # indent every line of the table so it nests under the log header
    compare_str = compare.to_string(index=False)
    compare_str = "\t" + compare_str.replace("\n","\n\t")
    log_update(f"\nClass distribution:\n{compare_str}")