# muPPIt / skempi_distance.py
# Author: AlienChen — "Upload 139 files" (commit 65bd8af, verified)
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
import seaborn as sns
import pickle
from collections import defaultdict
from predict import *
def compute_mean(tuple_list, ndigits=2):
    """Average ``(key, value)`` pairs per key, print the result, and return it.

    Args:
        tuple_list: Iterable of ``(key, value)`` pairs. All values that share a
            key are averaged together. Keys must be mutually comparable, since
            the output is ordered by key.
        ndigits: Decimal places each mean is rounded to. The default of 2
            matches the precision this function has always used.

    Returns:
        A dict mapping each key to its rounded mean, in ascending key order.
        The same dict is printed, so callers that only relied on the printed
        output see identical behavior.
    """
    sum_count = defaultdict(lambda: [0, 0])  # key -> [running sum, count]
    for key, value in tuple_list:
        sum_count[key][0] += value
        sum_count[key][1] += 1
    # Every key that appears has count >= 1, so the division is always safe.
    mean_dict = {
        key: round(total / count, ndigits)
        for key, (total, count) in sorted(sum_count.items())
    }
    print(mean_dict)
    return mean_dict
def _plot_distance_kde(results, plot_path):
    """Save a filled 2-D KDE of affinity-difference bin vs. model distance to ``plot_path``."""
    x_values = [affinity_bin for affinity_bin, _ in results]
    y_values = [distance for _, distance in results]
    plt.figure()
    sns.kdeplot(x=x_values, y=y_values, fill=True, cmap='viridis')
    plt.xlim(0, None)
    plt.ylim(0, None)
    plt.xlabel('Affinity difference')
    plt.ylabel('Distance')
    plt.savefig(plot_path)
    plt.close()


def main(
    csv_path='/home/tc415/muPPIt_embedding/dataset/correct_skempi.csv',
    ckpt_path='/home/tc415/muPPIt_embedding/checkpoints/new_train_1/model-epoch=15-val_acc=0.57.ckpt',
    device='cuda:1',
    max_distance=None,
    pickle_path=None,
    plot_path=None,
):
    """Score SKEMPI wild-type/mutant pairs with muPPIt and summarize distance per affinity bin.

    For each CSV row, the binder, wild-type and mutant sequences are tokenized,
    and ``model(binder, wt, mut)`` is evaluated under ``torch.no_grad()``. The
    output is reduced with ``.item()``, so the model is expected to return a
    single-element tensor. Each result is keyed by
    ``int(abs(log10(wt_affinity) - log10(mut_affinity)))``. The per-key mean
    distance is printed via ``compute_mean``.

    Args:
        csv_path: CSV read with pandas. It must contain the columns ``binder``,
            ``wt``, ``mut``, ``wt_affinity`` and ``mut_affinity``.
        ckpt_path: Checkpoint passed to ``muPPIt.load_weights``.
        device: Torch device string. If a CUDA device is requested but CUDA is
            unavailable, the script falls back to ``'cpu'`` instead of crashing.
        max_distance: If not None, rows whose distance exceeds this value are
            skipped.
        pickle_path: If not None, the raw ``results`` list is pickled to this path.
        plot_path: If not None, a KDE plot of the results is saved to this path.

    Returns:
        List of ``(affinity_bin, distance)`` tuples, one per kept row.
    """
    if str(device).startswith('cuda') and not torch.cuda.is_available():
        device = 'cpu'

    df = pd.read_csv(csv_path)
    tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
    model = muPPIt(1288, 8, 0.1, 10, 1e-4)
    model.load_weights(ckpt_path)
    model.to(device)
    model.eval()

    def _encode(sequence):
        # Drop the first and last token ids (presumably ESM-2's <cls>/<eos>
        # special tokens) and add a batch dimension of 1.
        ids = tokenizer(sequence)['input_ids'][1:-1]
        return torch.tensor(ids).unsqueeze(0).to(device)

    results = []
    for _, row in tqdm(df.iterrows(), total=len(df)):
        mut_aff = np.log10(row['mut_affinity'])
        wt_aff = np.log10(row['wt_affinity'])

        binder_tokens = _encode(row['binder'])
        wt_tokens = _encode(row['wt'])
        mut_tokens = _encode(row['mut'])

        with torch.no_grad():
            distance = model(binder_tokens, wt_tokens, mut_tokens).item()

        if max_distance is not None and distance > max_distance:
            continue
        # Bin by the truncated absolute log10-affinity difference.
        results.append((int(abs(wt_aff - mut_aff)), distance))

    compute_mean(results)

    if pickle_path is not None:
        with open(pickle_path, 'wb') as f:
            pickle.dump(results, f)
    if plot_path is not None:
        _plot_distance_kde(results, plot_path)

    return results
# Run the SKEMPI distance evaluation only when executed as a script, not on import.
if __name__ == '__main__':
    main()