Fill-Mask
Transformers
PyTorch
esm
Inference Endpoints
File size: 12,135 Bytes
ffaff91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
import pandas as pd
from sklearn.model_selection import train_test_split
from fuson_plm.utils.logging import log_update

def split_clusters_train_test(X, y, benchmark_cluster_reps=None, random_state = 1, test_size = 0.20):
    """
    Randomly split cluster representatives into train and test sets.

    Args:
        X (list-like): cluster representatives to split. Should not already contain
            any of ``benchmark_cluster_reps`` (those are appended to test afterwards).
        y (list-like): values aligned with ``X``; only passed through to
            ``train_test_split`` and not returned.
        benchmark_cluster_reps (list-like, optional): representatives that are
            appended to the test set after the random split. Defaults to none.
        random_state (int): seed for ``train_test_split`` so the split is reproducible.
        test_size (float): fraction of ``X`` placed in the test set by the random split.

    Returns:
        dict: ``{'X_train': list, 'X_test': list}``.

    Raises:
        AssertionError: if the train or test list contains duplicate entries.
    """
    if benchmark_cluster_reps is None:
        benchmark_cluster_reps = []

    # cluster with random state fixed for reproducible results
    log_update(f"\tPerforming split: all clusters -> train clusters ({round(1-test_size,3)}) and test clusters ({test_size})")
    X_train, X_test, _, _ = train_test_split(X, y, test_size=test_size, random_state=random_state)

    # add benchmark representatives back to X_test.
    # Work on plain lists: `+=` on a NumPy array / pandas Series would broadcast or add
    # element-wise instead of appending.
    log_update(f"\tManually adding {len(benchmark_cluster_reps)} clusters containing benchmark seqs into X_test")
    X_train = list(X_train)
    X_test = list(X_test) + list(benchmark_cluster_reps)

    # assert no duplicates within the train or test sets (there shouldn't be, if the input data was clean)
    assert len(X_train) == len(set(X_train))
    assert len(X_test) == len(set(X_test))

    return {
        'X_train': X_train,
        'X_test': X_test
    }
    
def split_clusters_train_val_test(X, y, benchmark_cluster_reps=None, random_state_1 = 1, random_state_2 = 1, test_size_1 = 0.20, test_size_2 = 0.50):
    """
    Randomly split cluster representatives into train, val and test sets via two successive splits.

    The first split sends a ``test_size_1`` fraction of ``X`` to a holdout pool; the
    second split divides that pool into val (``1 - test_size_2``) and test
    (``test_size_2``). Benchmark representatives are then appended to test.

    Args:
        X (list-like): cluster representatives to split. Should not already contain
            any of ``benchmark_cluster_reps``.
        y (list-like): values aligned with ``X``; only passed through to
            ``train_test_split`` and not returned.
        benchmark_cluster_reps (list-like, optional): representatives appended to the
            test set after both random splits. Defaults to none.
        random_state_1 (int): seed for the first (train vs. holdout) split.
        random_state_2 (int): seed for the second (val vs. test) split.
        test_size_1 (float): fraction of ``X`` sent to the holdout pool.
        test_size_2 (float): fraction of the holdout pool that becomes test.

    Returns:
        dict: ``{'X_train': list, 'X_val': list, 'X_test': list}``.

    Raises:
        AssertionError: if any of the three lists contains duplicate entries.
    """
    if benchmark_cluster_reps is None:
        benchmark_cluster_reps = []

    # cluster with random state fixed for reproducible results
    log_update(f"\tPerforming first split: all clusters -> train clusters ({round(1-test_size_1,3)}) and other ({test_size_1})")
    X_train, X_other, _, y_other = train_test_split(X, y, test_size=test_size_1, random_state=random_state_1)
    log_update(f"\tPerforming second split: other -> val clusters ({round(1-test_size_2,3)}) and test clusters ({test_size_2})")
    X_val, X_test, _, _ = train_test_split(X_other, y_other, test_size=test_size_2, random_state=random_state_2)

    # add benchmark representatives back to X_test.
    # Work on plain lists: `+=` on a NumPy array / pandas Series would broadcast or add
    # element-wise instead of appending.
    log_update(f"\tManually adding {len(benchmark_cluster_reps)} clusters containing benchmark seqs into X_test")
    X_train = list(X_train)
    X_val = list(X_val)
    X_test = list(X_test) + list(benchmark_cluster_reps)

    # assert no duplicates within the train, test, or val sets (there shouldn't be, if the input data was clean)
    assert len(X_train) == len(set(X_train))
    assert len(X_val) == len(set(X_val))
    assert len(X_test) == len(set(X_test))

    return {
        'X_train': X_train,
        'X_val': X_val,
        'X_test': X_test
    }
    
def split_clusters(cluster_representatives: list, val_set = True, benchmark_cluster_reps=None, random_state_1 = 1, random_state_2 = 1, test_size_1 = 0.20, test_size_2 = 0.50):
    """
    Cluster-splitting method amenable to either train-test or train-val-test.

    Benchmark cluster representatives are removed before the random split and then
    appended to the test set, so every benchmark cluster ends up in test.
    For train-val-test there are two successive splits; with the defaults this is
    approximately 80/10/10 (before the benchmark clusters are added to test).

    Args:
        cluster_representatives (list): all cluster representatives to split.
        val_set (bool): if True, produce train/val/test; otherwise train/test.
        benchmark_cluster_reps (list-like, optional): representatives of clusters that
            contain benchmark sequences; forced into the test set. Defaults to none.
        random_state_1 (int): seed for the first split.
        random_state_2 (int): seed for the second split (only used when ``val_set``).
        test_size_1 (float): fraction held out by the first split.
        test_size_2 (float): fraction of the holdout that becomes test (only used when ``val_set``).

    Returns:
        dict: the dict produced by ``split_clusters_train_val_test`` (keys 'X_train',
        'X_val', 'X_test') or ``split_clusters_train_test`` (keys 'X_train', 'X_test').
    """
    log_update("\nPerforming splits...")
    benchmark_cluster_reps = [] if benchmark_cluster_reps is None else list(benchmark_cluster_reps)

    # X, for splitting, does NOT include benchmark reps; those clusters are added to test.
    # A set makes each membership test O(1) instead of scanning the benchmark list.
    benchmark_lookup = set(benchmark_cluster_reps)
    X = [x for x in cluster_representatives if x not in benchmark_lookup]
    y = [0] * len(X)      # y is a dummy array here; there are no values.

    if val_set:
        return split_clusters_train_val_test(X, y, benchmark_cluster_reps=benchmark_cluster_reps,
                                             random_state_1=random_state_1, random_state_2=random_state_2,
                                             test_size_1=test_size_1, test_size_2=test_size_2)
    return split_clusters_train_test(X, y, benchmark_cluster_reps=benchmark_cluster_reps,
                                     random_state=random_state_1,
                                     test_size=test_size_1)
        
def _cluster_member_counts(clusters):
    """Collapse cluster membership rows to one row per 'representative seq_id' with a 'member count' column."""
    return (
        clusters.groupby('representative seq_id')['member seq_id']
        .count()
        .reset_index()
        .rename(columns={'member seq_id': 'member count'})
    )


def _percent(part, total):
    """Return ``part`` as a percentage of ``total`` rounded to 2 decimals, or 0.0 when ``total`` is 0."""
    return round(100 * part / total, 2) if total else 0.0


def check_split_validity(train_clusters, val_clusters, test_clusters, benchmark_sequences=None):
    """
    Log size statistics for a cluster split and assert that it has no sequence leakage.

    Args:
        train_clusters (pd.DataFrame): membership rows of the train clusters, with
            'representative seq_id', 'member seq_id' and 'member seq' columns.
        val_clusters (pd.DataFrame): same layout for validation (optional - can pass None if there is no validation set)
        test_clusters (pd.DataFrame): same layout for test (optional - can pass None)
        benchmark_sequences (list-like of str, optional): sequences that must appear in
            test and nowhere else. When None, this check is skipped.

    Raises:
        AssertionError: if a split's cluster/member counts are inconsistent (e.g. missing
            IDs dropped by the groupby), if any 'member seq' is shared between two
            splits, or if benchmark sequences occur in train/val or are not all in test.
    """
    # Train is mandatory; val and test are included only when provided. Order fixes log order.
    splits = [('Train', train_clusters)]
    if val_clusters is not None:
        splits.append(('Val', val_clusters))
    if test_clusters is not None:
        splits.append(('Test', test_clusters))

    # Calculate stats - clusters and proteins, with per-split consistency checks
    stats = []
    for name, df in splits:
        grouped = _cluster_member_counts(df)
        # groupby drops missing keys and count() skips missing members, so these catch missing IDs
        assert len(df['representative seq_id'].unique()) == len(grouped)
        assert len(df) == sum(grouped['member count'])
        stats.append((name, len(grouped), len(df)))
    n_clusters = sum(n_c for _, n_c, _ in stats)
    n_proteins = sum(n_p for _, _, n_p in stats)

    # Print results
    log_update("\nCluster breakdown...")
    log_update(f"Total clusters = {n_clusters}, total proteins = {n_proteins}")
    for name, n_c, n_p in stats:
        log_update(f"\t{name} set:\n\t\tTotal Clusters = {n_c} ({_percent(n_c, n_clusters)}%)\n\t\tTotal Proteins = {n_p} ({_percent(n_p, n_proteins)}%)")

    # Check for overlap in both sequence ID and sequence actual, for every pair of splits
    ids = {name: set(df['member seq_id']) for name, df in splits}
    seqs = {name: set(df['member seq']) for name, df in splits}
    names = [name for name, _ in splits]
    pairs = [(a, b) for i, a in enumerate(names) for b in names[i + 1:]]

    log_update("\nChecking for overlap...")
    if pairs:
        log_update("\tSequence IDs..." + "".join(f"\n\t\t{a}-{b} Overlap: {len(ids[a] & ids[b])}" for a, b in pairs))
        log_update("\tSequences..." + "".join(f"\n\t\t{a}-{b} Overlap: {len(seqs[a] & seqs[b])}" for a, b in pairs))

    # Assert no sequence overlap (ID overlap is only reported)
    for a, b in pairs:
        assert len(seqs[a] & seqs[b]) == 0

    # Finally, check that there are only benchmark sequences in test - if there are benchmark sequences
    if benchmark_sequences is not None:
        bench_counts = {
            name: len(df.loc[df['member seq'].isin(benchmark_sequences), 'member seq'].unique())
            for name, df in splits
        }

        log_update("\nChecking for benchmark sequence presence in test, and absence from train and val...")
        log_update(f"\tTotal benchmark sequences: {len(benchmark_sequences)}")
        for name, _ in splits:
            log_update(f"\tBenchmark sequences in {name.lower()}: {bench_counts[name]}")
        assert bench_counts['Train'] == bench_counts.get('Val', 0) == 0
        # The per-split counts are de-duplicated, so compare against the number of
        # distinct benchmark sequences (duplicates in the input must not fail the check).
        assert bench_counts.get('Test', 0) == len(set(benchmark_sequences))

def check_class_distributions(train_df, val_df, test_df, class_col='class'):
    """
    Log a side-by-side comparison of class frequencies in the train, val, and test sets.

    Rows are aligned on the class label (one row per class seen in any split). Each
    split contributes ``<split>_count`` and ``<split>_pct`` columns; classes absent
    from a split get a count and percentage of 0. Absolute percentage-point
    differences are added as 'train-test diff' and, when a validation set is given,
    'train-val diff' and 'val-test diff'.

    Args:
        train_df (pd.DataFrame): training rows; must contain ``class_col``.
        val_df (pd.DataFrame): validation rows, or None if there is no validation set.
        test_df (pd.DataFrame): test rows; must contain ``class_col``.
        class_col (str): name of the column holding the class labels.
    """
    def _split_counts(df, prefix):
        # rename_axis + reset_index(name=...) yields [class_col, <prefix>_count] on both
        # pandas 1.x and 2.x (2.x renamed the value_counts output to 'count').
        counts = df[class_col].value_counts().rename_axis(class_col).reset_index(name=f'{prefix}_count')
        counts[f'{prefix}_pct'] = (100 * counts[f'{prefix}_count'] / counts[f'{prefix}_count'].sum()).round(1)
        return counts

    prefixes = ['train', 'val', 'test'] if val_df is not None else ['train', 'test']
    frames = {'train': train_df, 'val': val_df, 'test': test_df}

    # Merge on the class label (not by row position) so each row compares the same class.
    compare = None
    for prefix in prefixes:
        table = _split_counts(frames[prefix], prefix)
        compare = table if compare is None else compare.merge(table, on=class_col, how='outer')

    # A class missing from one split shows up as NaN after the outer merge.
    for prefix in prefixes:
        compare[f'{prefix}_count'] = compare[f'{prefix}_count'].fillna(0).astype(int)
        compare[f'{prefix}_pct'] = compare[f'{prefix}_pct'].fillna(0.0)

    if val_df is not None:
        compare['train-val diff'] = (compare['train_pct'] - compare['val_pct']).abs().round(1)
        compare['val-test diff'] = (compare['val_pct'] - compare['test_pct']).abs().round(1)
    compare['train-test diff'] = (compare['train_pct'] - compare['test_pct']).abs().round(1)

    compare = compare.sort_values('train_count', ascending=False, kind='stable').reset_index(drop=True)

    compare_str = compare.to_string(index=False)
    compare_str = "\t" + compare_str.replace("\n","\n\t")
    log_update(f"\nClass distribution:\n{compare_str}")