| | import torch |
| | import pandas as pd |
| | import seaborn as sns |
| | import matplotlib.pyplot as plt |
| | from umap import UMAP |
| | from sklearn.manifold import TSNE |
| | from sklearn.decomposition import PCA |
| | from transformers import AutoModel, AutoTokenizer |
| |
|
# Root directory that holds the generated-sequence benchmark CSVs.
path = "/workspace/sg666/MDpLM/benchmarks/Generation"
# Hugging Face identifier of the ESM-2 checkpoint used to embed sequences.
esm_model_path = "facebook/esm2_t33_650M_UR50D"
# Prefer a CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
| |
|
| | |
def load_esm2_model(model_name):
    """Load the tokenizer and encoder for an ESM-2 checkpoint.

    The encoder is moved onto the module-level ``device`` before it is
    returned.

    Args:
        model_name: Hugging Face model identifier or local checkpoint path
            accepted by ``from_pretrained``.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tok = AutoTokenizer.from_pretrained(model_name)
    encoder = AutoModel.from_pretrained(model_name)
    encoder = encoder.to(device)
    return tok, encoder
| |
|
def get_latents(model, tokenizer, sequence):
    """Embed one sequence and return its mean-pooled last hidden layer.

    The sequence is tokenized as a single-item batch and moved to the
    module-level ``device``. The forward pass runs without gradient
    tracking. Pooling averages over every token position the tokenizer
    produced; no attention mask is applied, so any special tokens the
    tokenizer adds are included in the average.

    Args:
        model: Hugging Face encoder whose output exposes ``last_hidden_state``.
        tokenizer: Tokenizer matching ``model``.
        sequence: Amino-acid sequence string.

    Returns:
        A plain Python list of floats, one per hidden dimension.
        Assumes ``last_hidden_state`` is ``(batch, seq_len, hidden)``; TODO confirm.
    """
    encoded = tokenizer(sequence, return_tensors="pt").to(device)
    with torch.no_grad():
        hidden = model(**encoded).last_hidden_state
    # Average over the token axis (dim 1), then drop the singleton batch axis.
    pooled = hidden.mean(dim=1).squeeze(0)
    return pooled.cpu().numpy().tolist()
| |
|
| | |
def parse_fasta_file(file_path, n=100, random_state=None):
    """Parse a FASTA file and return a random sample of its sequences.

    Lines starting with ``>`` are record headers. Their text is discarded,
    and the sequence lines that follow are concatenated into one string.
    Blank lines are ignored, so a header with no sequence lines yields no
    record instead of an empty string. Every record is labelled with the
    source ``"UniProt"``.

    Args:
        file_path: Path to the FASTA file (read as UTF-8 text).
        n: Maximum number of sequences to return (default 100). If the file
            holds fewer records, all of them are returned in shuffled order
            instead of ``DataFrame.sample`` raising ``ValueError``.
        random_state: Optional seed forwarded to ``DataFrame.sample`` for a
            reproducible subset. ``None`` keeps the original unseeded
            behaviour.

    Returns:
        DataFrame with columns ``"Sequence"`` and ``"Sequence Source"`` and a
        fresh ``0..k-1`` index.
    """
    source = "UniProt"
    records = []
    chunks = []
    with open(file_path, "r", encoding="utf-8") as handle:
        for raw in handle:
            line = raw.strip()
            if not line:
                # Skip blank lines; appending "" would create empty records.
                continue
            if line.startswith(">"):
                # A new header closes the previous record, if it has residues.
                if chunks:
                    records.append(("".join(chunks), source))
                    chunks = []
            else:
                chunks.append(line)
    if chunks:
        records.append(("".join(chunks), source))

    df = pd.DataFrame(records, columns=["Sequence", "Sequence Source"])
    # Clamp the sample size so short files do not make sample() raise.
    sampled = df.sample(min(n, len(df)), random_state=random_state)
    return sampled.reset_index(drop=True)
| |
|
| |
|
| | |
# --- ProtGPT2 generations -------------------------------------------------
# Load the ProtGPT2 samples and strip generation artefacts from each sequence.
protgpt2_sequences = pd.read_csv(path + "/ProtGPT2/protgpt2_generated_sequences.csv")
protgpt2_sequences['Sequence'] = (
    protgpt2_sequences['Sequence']
    .str.replace('<|ENDOFTEXT|>', '', regex=False)
    .str.replace('""', '', regex=False)
    .str.replace('\n', '', regex=False)
    # Every 'X' residue is rewritten to 'G' (substitution kept from the original pipeline).
    .str.replace('X', 'G', regex=False)
)
protgpt2_sequences.drop(columns=['Perplexity'], inplace=True)
protgpt2_sequences['Sequence Source'] = "ProtGPT2"

# Flag sequences containing any of the residues B, U, Z or O in one vectorised
# pass. Missing sequences (NaN) are flagged too (na=True) so they are dropped
# rather than crashing later.
has_nonstandard = protgpt2_sequences['Sequence'].str.contains('[BUZO]', regex=True, na=True)
# Deduplicated list of the rejected sequences, kept for inspection.
bad_sequences = protgpt2_sequences.loc[has_nonstandard, 'Sequence'].tolist()
protgpt2_sequences = protgpt2_sequences[~has_nonstandard]
| |
|
| |
|
| | |
# --- MeMDLM generations ---------------------------------------------------
# Rename the sequence column to "Sequence", drop the perplexity scores, tag
# every row with its source and renumber the rows 0..n-1.
memdlm_sequences = (
    pd.read_csv(path + "/mdlm_de-novo_generation_results.csv")
    .rename(columns={"Generated Sequence": "Sequence"})
    .drop(columns=['Perplexity'])
    .assign(**{'Sequence Source': "MeMDLM"})
    .reset_index(drop=True)
)
| |
|
| | |
| | |
| | |
| |
|
| | |
# --- Held-out membrane test split -----------------------------------------
other_sequences = pd.read_csv("/workspace/sg666/MDpLM/data/membrane/test.csv")
other_sequences['Sequence Source'] = "Test Set"
# Randomly subsample up to 100 rows. Clamp the size so a smaller split does
# not make DataFrame.sample raise ValueError.
other_sequences = other_sequences.sample(min(100, len(other_sequences)))
| |
|
| | |
# Stack all sources into one frame. ignore_index assigns unique 0..n-1 labels;
# otherwise each source keeps its own, overlapping labels.
data = pd.concat([memdlm_sequences, protgpt2_sequences, other_sequences], ignore_index=True)
| |
|
| |
|
| | |
# load_esm2_model already moves the encoder onto `device`, so no second .to() is needed.
tokenizer, model = load_esm2_model(esm_model_path)
| |
|
| |
|
| | |
# Renumber rows first so every row has a unique label, then embed each
# sequence with one ESM-2 forward pass.
data = data.reset_index(drop=True)
data['Embeddings'] = data['Sequence'].map(lambda seq: get_latents(model, tokenizer, seq))

# One row per sequence, one column per embedding dimension. The index is the
# 'Sequence Source' Series, so the index keeps that name.
umap_df = pd.DataFrame(list(data['Embeddings']))
umap_df.index = data['Sequence Source']
| |
|
| |
|
| | |
# Project the embedding matrix to 2-D with UMAP. No random_state is passed,
# so the layout is not seeded.
umap = UMAP(n_components=2)
umap_features = umap.fit_transform(umap_df)
# Attach the coordinates only after fitting, so they are not part of the UMAP input.
umap_df['UMAP1'], umap_df['UMAP2'] = umap_features.T
| |
|
| | |
# Scatter the 2-D UMAP projection with one colour per sequence source.
# A dict palette pins each colour to its source label instead of relying on
# the order in which sources first appear in umap_df.
source_palette = {
    "MeMDLM": '#297272',
    "ProtGPT2": '#ff7477',
    "Test Set": "#9A77D0",
}
plt.figure(figsize=(8, 5), dpi=300)
sns.scatterplot(x='UMAP1', y='UMAP2', hue='Sequence Source', data=umap_df,
                palette=source_palette, s=100)
plt.xlabel('UMAP1')
plt.ylabel('UMAP2')
plt.title('ESM-650M Embeddings of Membrane Protein Sequences')
# Save before show(): some backends clear the figure once the window closes.
plt.savefig('esm_umap.png')
plt.show()