import pandas as pd
from sklearn.model_selection import train_test_split

from fuson_plm.utils.logging import log_update

def split_clusters_train_test(X, y, benchmark_cluster_reps=[], random_state=1, test_size=0.20):
    """
    Splits cluster representatives into train and test sets, then adds the
    benchmark-containing clusters (withheld from X upstream) to the test set.
    """
    log_update(f"\tPerforming split: all clusters -> train clusters ({round(1-test_size,3)}) and test clusters ({test_size})")
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=random_state)

    # Force the benchmark clusters into the test set
    log_update(f"\tManually adding {len(benchmark_cluster_reps)} clusters containing benchmark seqs into X_test")
    X_test += benchmark_cluster_reps

    # Sanity check: no duplicate cluster representatives within either set
    assert len(X_train) == len(set(X_train))
    assert len(X_test) == len(set(X_test))

    return {
        'X_train': X_train,
        'X_test': X_test
    }

def split_clusters_train_val_test(X, y, benchmark_cluster_reps=[], random_state_1=1, random_state_2=1, test_size_1=0.20, test_size_2=0.50):
    """
    Splits cluster representatives into train, val, and test sets via two successive
    splits, then adds the benchmark-containing clusters (withheld from X upstream)
    to the test set.
    """
    log_update(f"\tPerforming first split: all clusters -> train clusters ({round(1-test_size_1,3)}) and other ({test_size_1})")
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size_1, random_state=random_state_1)
    log_update(f"\tPerforming second split: other -> val clusters ({round(1-test_size_2,3)}) and test clusters ({test_size_2})")
    X_val, X_test, y_val, y_test = train_test_split(X_test, y_test, test_size=test_size_2, random_state=random_state_2)

    # Force the benchmark clusters into the test set
    log_update(f"\tManually adding {len(benchmark_cluster_reps)} clusters containing benchmark seqs into X_test")
    X_test += benchmark_cluster_reps

    # Sanity check: no duplicate cluster representatives within any set
    assert len(X_train) == len(set(X_train))
    assert len(X_val) == len(set(X_val))
    assert len(X_test) == len(set(X_test))

    return {
        'X_train': X_train,
        'X_val': X_val,
        'X_test': X_test
    }

def split_clusters(cluster_representatives: list, val_set=True, benchmark_cluster_reps=[], random_state_1=1, random_state_2=1, test_size_1=0.20, test_size_2=0.50):
    """
    Cluster-splitting method amenable to either train-test or train-val-test.
    For train-val-test, there are two splits.
    """
    log_update("\nPerforming splits...")

    # Withhold benchmark clusters from the random split; they are added back to the test set downstream
    X = [x for x in cluster_representatives if x not in benchmark_cluster_reps]
    y = [0]*len(X)  # dummy labels; train_test_split expects a y argument

    if val_set:
        split_dict = split_clusters_train_val_test(X, y, benchmark_cluster_reps=benchmark_cluster_reps,
                                                   random_state_1=random_state_1, random_state_2=random_state_2,
                                                   test_size_1=test_size_1, test_size_2=test_size_2)
    else:
        split_dict = split_clusters_train_test(X, y, benchmark_cluster_reps=benchmark_cluster_reps,
                                               random_state=random_state_1,
                                               test_size=test_size_1)

    return split_dict
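
# Example usage (a minimal sketch; the representative IDs below are hypothetical):
#
#   reps = [f"rep_{i}" for i in range(10)] + ["rep_bench"]
#   splits = split_clusters(reps, val_set=True, benchmark_cluster_reps=["rep_bench"])
#   train_reps, val_reps, test_reps = splits['X_train'], splits['X_val'], splits['X_test']
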
def check_split_validity(train_clusters, val_clusters, test_clusters, benchmark_sequences=None):
    """
    Checks that a cluster split is valid: no sequence overlap between sets, and,
    if benchmark sequences are provided, all of them in test and none in train or val.

    Args:
        train_clusters (pd.DataFrame): train-set clusters with columns
            'representative seq_id', 'member seq_id', and 'member seq'
        val_clusters (pd.DataFrame): val-set clusters, same columns
            (optional - can pass None if there is no validation set)
        test_clusters (pd.DataFrame): test-set clusters, same columns
            (optional - can pass None if there is no test set)
        benchmark_sequences (list): benchmark sequences expected to appear only in
            test (optional - can pass None to skip the benchmark checks)
    """
    # Count members per cluster in each set
    train_clustersgb = train_clusters.groupby('representative seq_id')['member seq_id'].count().reset_index().rename(columns={'member seq_id': 'member count'})
    if val_clusters is not None:
        val_clustersgb = val_clusters.groupby('representative seq_id')['member seq_id'].count().reset_index().rename(columns={'member seq_id': 'member count'})
    if test_clusters is not None:
        test_clustersgb = test_clusters.groupby('representative seq_id')['member seq_id'].count().reset_index().rename(columns={'member seq_id': 'member count'})

    # Total clusters across all sets
    n_train_clusters = len(train_clustersgb)
    n_val_clusters, n_test_clusters = 0, 0
    if val_clusters is not None:
        n_val_clusters = len(val_clustersgb)
    if test_clusters is not None:
        n_test_clusters = len(test_clustersgb)
    n_clusters = n_train_clusters + n_val_clusters + n_test_clusters

    # Each representative seq_id should correspond to exactly one cluster per set
    assert len(train_clusters['representative seq_id'].unique()) == len(train_clustersgb)
    if val_clusters is not None:
        assert len(val_clusters['representative seq_id'].unique()) == len(val_clustersgb)
    if test_clusters is not None:
        assert len(test_clusters['representative seq_id'].unique()) == len(test_clustersgb)

    # Percentage of clusters in each set
    train_cluster_pcnt = round(100*n_train_clusters/n_clusters, 2)
    if val_clusters is not None:
        val_cluster_pcnt = round(100*n_val_clusters/n_clusters, 2)
    if test_clusters is not None:
        test_cluster_pcnt = round(100*n_test_clusters/n_clusters, 2)

    # Total proteins (cluster members) across all sets
    n_train_proteins = len(train_clusters)
    n_val_proteins, n_test_proteins = 0, 0
    if val_clusters is not None:
        n_val_proteins = len(val_clusters)
    if test_clusters is not None:
        n_test_proteins = len(test_clusters)
    n_proteins = n_train_proteins + n_val_proteins + n_test_proteins

    # Per-cluster member counts should sum to the total rows in each set
    assert len(train_clusters) == sum(train_clustersgb['member count'])
    if val_clusters is not None:
        assert len(val_clusters) == sum(val_clustersgb['member count'])
    if test_clusters is not None:
        assert len(test_clusters) == sum(test_clustersgb['member count'])

    # Percentage of proteins in each set
    train_protein_pcnt = round(100*n_train_proteins/n_proteins, 2)
    if val_clusters is not None:
        val_protein_pcnt = round(100*n_val_proteins/n_proteins, 2)
    if test_clusters is not None:
        test_protein_pcnt = round(100*n_test_proteins/n_proteins, 2)

    # Log the cluster and protein breakdown per set
    log_update("\nCluster breakdown...")
    log_update(f"Total clusters = {n_clusters}, total proteins = {n_proteins}")
    log_update(f"\tTrain set:\n\t\tTotal Clusters = {len(train_clustersgb)} ({train_cluster_pcnt}%)\n\t\tTotal Proteins = {len(train_clusters)} ({train_protein_pcnt}%)")
    if val_clusters is not None:
        log_update(f"\tVal set:\n\t\tTotal Clusters = {len(val_clustersgb)} ({val_cluster_pcnt}%)\n\t\tTotal Proteins = {len(val_clusters)} ({val_protein_pcnt}%)")
    if test_clusters is not None:
        log_update(f"\tTest set:\n\t\tTotal Clusters = {len(test_clustersgb)} ({test_cluster_pcnt}%)\n\t\tTotal Proteins = {len(test_clusters)} ({test_protein_pcnt}%)")

    # Collect member IDs and sequences per set for overlap checks
    train_protein_ids = set(train_clusters['member seq_id'])
    train_protein_seqs = set(train_clusters['member seq'])
    if val_clusters is not None:
        val_protein_ids = set(val_clusters['member seq_id'])
        val_protein_seqs = set(val_clusters['member seq'])
    if test_clusters is not None:
        test_protein_ids = set(test_clusters['member seq_id'])
        test_protein_seqs = set(test_clusters['member seq'])

    # Log pairwise overlap between sets
    log_update("\nChecking for overlap...")
    if (val_clusters is not None) and (test_clusters is not None):
        log_update(f"\tSequence IDs...\n\t\tTrain-Val Overlap: {len(train_protein_ids.intersection(val_protein_ids))}\n\t\tTrain-Test Overlap: {len(train_protein_ids.intersection(test_protein_ids))}\n\t\tVal-Test Overlap: {len(val_protein_ids.intersection(test_protein_ids))}")
        log_update(f"\tSequences...\n\t\tTrain-Val Overlap: {len(train_protein_seqs.intersection(val_protein_seqs))}\n\t\tTrain-Test Overlap: {len(train_protein_seqs.intersection(test_protein_seqs))}\n\t\tVal-Test Overlap: {len(val_protein_seqs.intersection(test_protein_seqs))}")
    if (val_clusters is not None) and (test_clusters is None):
        log_update(f"\tSequence IDs...\n\t\tTrain-Val Overlap: {len(train_protein_ids.intersection(val_protein_ids))}")
        log_update(f"\tSequences...\n\t\tTrain-Val Overlap: {len(train_protein_seqs.intersection(val_protein_seqs))}")
    if (val_clusters is None) and (test_clusters is not None):
        log_update(f"\tSequence IDs...\n\t\tTrain-Test Overlap: {len(train_protein_ids.intersection(test_protein_ids))}")
        log_update(f"\tSequences...\n\t\tTrain-Test Overlap: {len(train_protein_seqs.intersection(test_protein_seqs))}")

    # No sequence may appear in more than one set
    if val_clusters is not None:
        assert len(train_protein_seqs.intersection(val_protein_seqs)) == 0
    if test_clusters is not None:
        assert len(train_protein_seqs.intersection(test_protein_seqs)) == 0
    if (val_clusters is not None) and (test_clusters is not None):
        assert len(val_protein_seqs.intersection(test_protein_seqs)) == 0

    # Benchmark sequences must all land in test, and none in train or val
    if benchmark_sequences is not None:
        bench_in_train = len(train_clusters.loc[train_clusters['member seq'].isin(benchmark_sequences)]['member seq'].unique())
        bench_in_val, bench_in_test = 0, 0
        if val_clusters is not None:
            bench_in_val = len(val_clusters.loc[val_clusters['member seq'].isin(benchmark_sequences)]['member seq'].unique())
        if test_clusters is not None:
            bench_in_test = len(test_clusters.loc[test_clusters['member seq'].isin(benchmark_sequences)]['member seq'].unique())

        log_update("\nChecking for benchmark sequence presence in test, and absence from train and val...")
        log_update(f"\tTotal benchmark sequences: {len(benchmark_sequences)}")
        log_update(f"\tBenchmark sequences in train: {bench_in_train}")
        if val_clusters is not None:
            log_update(f"\tBenchmark sequences in val: {bench_in_val}")
        if test_clusters is not None:
            log_update(f"\tBenchmark sequences in test: {bench_in_test}")
        assert bench_in_train == bench_in_val == 0
        assert bench_in_test == len(benchmark_sequences)
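
# Example usage (a minimal sketch; IDs and sequences below are hypothetical).
# Each clusters DataFrame maps a cluster representative to its member sequences:
#
#   train_clusters = pd.DataFrame({
#       'representative seq_id': ['repA', 'repA', 'repB'],
#       'member seq_id': ['seq1', 'seq2', 'seq3'],
#       'member seq': ['MKTA', 'MLSP', 'MAVG'],
#   })
#   check_split_validity(train_clusters, val_clusters, test_clusters,
#                        benchmark_sequences=benchmark_seqs)  # benchmark_seqs: hypothetical list
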
def check_class_distributions(train_df, val_df, test_df, class_col='class'):
    """
    Checks class distributions within train, val, and test sets.
    Expects input dataframes to have a 'sequence' column and a class column.
    """
    # rename_axis + reset_index(name=...) keeps the class labels in one column and the
    # counts in another, regardless of pandas version
    train_vc = train_df[class_col].value_counts().rename_axis(class_col).reset_index(name='train_count')
    train_vc['train_pct'] = (train_vc['train_count'] / train_vc['train_count'].sum()).round(3)*100
    if val_df is not None:
        val_vc = val_df[class_col].value_counts().rename_axis(class_col).reset_index(name='val_count')
        val_vc['val_pct'] = (val_vc['val_count'] / val_vc['val_count'].sum()).round(3)*100
    test_vc = test_df[class_col].value_counts().rename_axis(class_col).reset_index(name='test_count')
    test_vc['test_pct'] = (test_vc['test_count'] / test_vc['test_count'].sum()).round(3)*100

    # Merge on the class column so each row compares the same class across sets
    # (value_counts orders by frequency, so positional concatenation could misalign rows)
    if val_df is not None:
        compare = train_vc.merge(val_vc, on=class_col, how='outer').merge(test_vc, on=class_col, how='outer')
        compare['train-val diff'] = (compare['train_pct'] - compare['val_pct']).abs()
        compare['val-test diff'] = (compare['val_pct'] - compare['test_pct']).abs()
    else:
        compare = train_vc.merge(test_vc, on=class_col, how='outer')
        compare['train-test diff'] = (compare['train_pct'] - compare['test_pct']).abs()

    # Indent the table so it nests under the log header
    compare_str = compare.to_string(index=False)
    compare_str = "\t" + compare_str.replace("\n", "\n\t")
    log_update(f"\nClass distribution:\n{compare_str}")