# PPIformer / app.py
# Last commit: "Add .csv download, add inputs info" (5676685) by Anton Bushuiev
# (repository file-viewer residue: raw | history | blame | 13.7 kB)
import copy
import html
import json
import random
import tempfile
from pathlib import Path
from functools import partial

import gradio as gr
import torch
import numpy as np
import pandas as pd
from Bio.PDB.Polypeptide import protein_letters_3to1
from biopandas.pdb import PandasPdb
from colour import Color
from colour import RGB_TO_COLOR_NAMES

from mutils.proteins import AMINO_ACID_CODES_1
from mutils.pdb import download_pdb
from mutils.mutations import Mutation
from ppiref.extraction import PPIExtractor
from ppiref.utils.ppi import PPIPath
from ppiref.utils.residue import Residue
from ppiformer.tasks.node import DDGPPIformer
from ppiformer.utils.api import download_weights, predict_ddg
from ppiformer.utils.torch import fill_diagonal
from ppiformer.definitions import PPIFORMER_WEIGHTS_DIR
# Seed Python's RNG once at import time so that the sequence of random.shuffle
# results used for the fallback chain colors in plot_3dmol is repeatable across
# app restarts.
random.seed(0)
def process_inputs(inputs, temp_dir):
    """Validate the raw GUI inputs and resolve them to files and a mutation list.

    Args:
        inputs: 5-item sequence ``(pdb_code, pdb_path, partners, muts, muts_path)``.
            ``pdb_path`` / ``muts_path`` take precedence over ``pdb_code`` /
            ``muts`` when both members of a pair are given.
        temp_dir: ``pathlib.Path`` of a scratch directory; the downloaded PDB is
            stored under ``pdb/`` and the extracted PPI under ``ppi/``.

    Returns:
        Tuple ``(pdb_path, ppi_path, muts)``: path to the full PDB file, path to
        the extracted PPI file and the list of non-empty mutation strings.

    Raises:
        gr.Error: if a required input is missing or empty, or if the PDB
            download or the PPI extraction fails.
    """
    pdb_code, pdb_path, partners, muts, muts_path = inputs

    # Check inputs
    if not pdb_code and not pdb_path:
        raise gr.Error("PPI structure not specified.")
    if pdb_code and pdb_path:
        gr.Warning("Both PDB code and PDB file specified. Using PDB file.")
    if not partners:
        raise gr.Error("Partners not specified.")
    if not muts and not muts_path:
        raise gr.Error("Mutations not specified.")
    if muts and muts_path:
        gr.Warning("Both mutations and mutations file specified. Using mutations file.")

    # Prepare PDB input
    if pdb_path:
        pdb_path = Path(pdb_path)
    else:
        try:
            pdb_code = pdb_code.strip().lower()
            pdb_path = temp_dir / f'pdb/{pdb_code}.pdb'
            download_pdb(pdb_code, path=pdb_path)
        except Exception as e:
            # `Exception` (not a bare except) so Ctrl-C / SystemExit still propagate;
            # chain the cause so the server log shows what actually went wrong.
            raise gr.Error("PDB download failed.") from e

    # Drop empty tokens so inputs such as "A,B," do not yield a bogus '' chain
    partners = [p.strip() for p in partners.split(',') if p.strip()]
    if not partners:
        raise gr.Error("Partners not specified.")

    # Extract PPI into temp dir
    try:
        ppi_dir = temp_dir / 'ppi'
        extractor = PPIExtractor(out_dir=ppi_dir, nest_out_dir=True, join=True, radius=10.0)
        extractor.extract(pdb_path, partners=partners)
        ppi_path = PPIPath.construct(ppi_dir, pdb_path.stem, partners)
    except Exception as e:
        raise gr.Error("PPI extraction failed.") from e

    # Prepare mutations input
    if muts_path:
        muts_path = Path(muts_path)
        muts = muts_path.read_text()
    # Mutation files terminate every entry with ';' (often followed by a newline),
    # so splitting leaves an empty trailing token that must not reach the model.
    muts = [m.strip() for m in muts.split(';') if m.strip()]
    if not muts:
        raise gr.Error("Mutations not specified.")

    return pdb_path, ppi_path, muts
def plot_3dmol(pdb_path, ppi_path, muts, attn, mut_id=0):
    """Build an embeddable 3Dmol.js view of the complex, highlighting one mutation.

    The residues of the mutation ``muts[mut_id]`` are drawn as red sticks and
    labelled; all other PPI residues are drawn as sticks whose color saturation,
    radius and opacity scale with the attention they receive from the mutated
    residues.

    Args:
        pdb_path: Path to the full PDB file shown in the viewer.
        ppi_path: Path to the extracted PPI PDB file; one row per residue (its
            CA atom) is used to drive the styling.
        muts: List of mutation strings as accepted by ``Mutation``.
        attn: Attention tensor indexed as ``attn[:, mut_id, 0, :, 0, :, :, :]``
            (rank 8); presumably the remaining axes are models, layers, heads,
            tokens, tokens with tokens ordered like the PPI residues
            -- TODO confirm against ``predict_ddg``.
        mut_id: Index into ``muts`` of the mutation to visualize.

    Returns:
        HTML string containing an ``<iframe>`` whose ``srcdoc`` is a
        self-contained page rendering the structure with 3Dmol.js.
    """
    # 3DMol.js adapted from https://huggingface.co/spaces/huhlim/cg2all/blob/main/app.py

    # Read PDB for 3Dmol.js
    with open(pdb_path, "r") as fp:
        mol = fp.read()
    # PDB is a fixed-column format: replacements must keep the 3-character width,
    # otherwise every field after the atom name is shifted on that line.
    mol = mol.replace("OT1", "O  ")
    mol = mol.replace("OT2", "OXT")

    # Read PPI to customize 3Dmol.js visualization
    ppi_df = PandasPdb().read_pdb(ppi_path).df['ATOM']
    ppi_df = ppi_df.groupby(list(Residue._fields)).apply(lambda df: df[df['atom_name'] == 'CA'].iloc[0]).reset_index(drop=True)
    ppi_df['id'] = ppi_df.apply(lambda row: ':'.join([row['residue_name'], row['chain_id'], str(row['residue_number']), row['insertion']]), axis=1)
    ppi_df['id'] = ppi_df['id'].apply(lambda x: x[:-1] if x[-1] == ':' else x)
    muts_id = Mutation(muts[mut_id]).wt_to_graphein()  # flatten ids of all sp muts
    ppi_df['mutated'] = ppi_df.apply(lambda row: row['id'] in muts_id, axis=1)

    # Prepare attention coefficients per residue (normalized sum of direct attention from mutated residues)
    attn = torch.nan_to_num(attn, nan=1e-10)
    attn_sub = attn[:, mut_id, 0, :, 0, :, :, :]  # models, layers, heads, tokens, tokens
    idx_mutated = torch.from_numpy(ppi_df.index[ppi_df['mutated']].to_numpy())
    attn_sub = fill_diagonal(attn_sub, 1e-10)
    attn_mutated = attn_sub[..., idx_mutated, :]
    attns_per_token = torch.sum(attn_mutated, dim=(0, 1, 2, 3))
    attn_min = attns_per_token.min()
    attn_range = attns_per_token.max() - attn_min
    # Guard the min-max normalization: a zero range (e.g. no mutated residue found
    # in the PPI, or a single residue) would otherwise produce NaNs.
    attn_range = torch.where(attn_range > 0, attn_range, torch.ones_like(attn_range))
    attns_per_token = (attns_per_token - attn_min) / attn_range
    attns_per_token += 1e-10
    ppi_df['attn'] = attns_per_token.numpy()
    chains = ppi_df.sort_values('attn', ascending=False)['chain_id'].unique()

    # Customize 3Dmol.js visualization https://3dmol.csb.pitt.edu/doc/
    styles = []
    zoom_atoms = []

    # Cartoon chains
    preferred_colors = ['LimeGreen', 'HotPink', 'RoyalBlue']
    all_colors = [c[0] for c in RGB_TO_COLOR_NAMES.values()]
    all_colors = [c for c in all_colors if c not in preferred_colors + ['Black', 'White']]
    random.shuffle(all_colors)
    all_colors = preferred_colors + all_colors
    all_colors = [Color(c) for c in all_colors]
    chain_to_color = dict(zip(chains, all_colors))
    for chain in chains:
        styles.append([{"chain": chain}, {"cartoon": {"color": chain_to_color[chain].hex_l, "opacity": 0.6}}])

    # Stick PPI and atoms for zoom
    # TODO Insertions
    for _, row in ppi_df.iterrows():
        residue_sel = {'chain': row['chain_id'], 'resi': str(row['residue_number'])}
        if row['mutated']:
            styles.append([
                residue_sel,
                {'stick': {'color': 'red', 'radius': 0.2, 'opacity': 1.0}}
            ])
            zoom_atoms.append(row['atom_number'])
        else:
            # Plain Python float so that JSON serialization below is well defined
            attn_value = float(row['attn'])
            # Copy before editing: the chain color object is shared by all residues
            color = copy.deepcopy(chain_to_color[row['chain_id']])
            color.saturation = attn_value
            styles.append([
                residue_sel,
                {'stick': {'color': color.hex_l, 'radius': attn_value / 5, 'opacity': attn_value}}
            ])

    # Convert style dicts to JS lines (JSON object literals are valid JS)
    styles = ''.join(
        'viewer.addStyle(' + ', '.join(json.dumps(spec) for spec in pair) + ');\n'
        for pair in styles
    )

    # Convert zoom atoms to 3DMol.js selection and add labels for mutated residues
    zoom_animation_duration = 500
    sel = '{\"or\": [' + ', '.join(["{\"serial\": " + str(a) + "}" for a in zoom_atoms]) + ']}'
    zoom = 'viewer.zoomTo(' + sel + ',' + f'{zoom_animation_duration});'
    for atom in zoom_atoms:
        sel = '{\"serial\": ' + str(atom) + '}'
        row = ppi_df[ppi_df['atom_number'] == atom].iloc[0]
        label = protein_letters_3to1[row['residue_name']] + row['chain_id'] + str(row['residue_number']) + row['insertion']
        styles += 'viewer.addLabel(' + f"\"{label}\"," + "{fontSize:16, fontColor:\"red\", backgroundOpacity: 0.0}," + sel + ');\n'

    # Embed the PDB text as a JSON string literal instead of a raw template
    # literal: backticks, "${" or backslashes in the file can then no longer
    # break the script. "</" is escaped so "</script>" cannot end the tag early.
    pdb_literal = json.dumps(mol).replace("</", "<\\/")

    # Construct 3Dmol.js visualization script embedded in HTML
    page = (
        """<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html; charset=UTF-8" />
<style>
body{
font-family:sans-serif
}
.mol-container {
width: 100%;
height: 600px;
position: relative;
}
.mol-container select{
background-image:None;
}
</style>
<script src="https://cdnjs.cloudflare.com/ajax/libs/jquery/3.6.3/jquery.min.js" integrity="sha512-STof4xm1wgkfm7heWqFJVn58Hm3EtS31XFaagaa8VMReCXAkQnJZ+jEy8PCC/iT18dFy95WcExNHFTqLyp72eQ==" crossorigin="anonymous" referrerpolicy="no-referrer"></script>
<script src="https://3Dmol.csb.pitt.edu/build/3Dmol-min.js"></script>
</head>
<body>
<div id="container" class="mol-container"></div>
<script>
let pdb = """
        + pdb_literal
        + """;
$(document).ready(function () {
let element = $("#container");
let config = { backgroundColor: "white" };
let viewer = $3Dmol.createViewer(element, config);
viewer.addModel(pdb, "pdb");
viewer.setStyle({"model": 0}, {"ray_opaque_background": "off"}, {"stick": {"color": "lightgrey", "opacity": 0.5}});
"""
        + styles
        + zoom
        + """
viewer.render();
})
</script>
</body></html>"""
    )
    print(page)

    # The page is placed inside an HTML attribute, so it must be attribute-escaped;
    # otherwise any single quote in the PDB text (e.g. atom names like O5')
    # terminates `srcdoc` early. The browser decodes the entities back.
    srcdoc = html.escape(page, quote=True)
    return f"""<iframe style="width: 100%; height: 600px" name="result" allow="midi; geolocation; microphone; camera;
display-capture; encrypted-media;" sandbox="allow-modals allow-forms
allow-scripts allow-same-origin allow-popups
allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
allowpaymentrequest="" frameborder="0" srcdoc='{srcdoc}'></iframe>"""
def predict(models, temp_dir, *inputs):
    """Run the ddG prediction for the GUI inputs and prepare all GUI outputs.

    Args:
        models: Models passed through to ``predict_ddg``.
        temp_dir: ``pathlib.Path`` of a scratch directory used for the downloaded
            PDB, the extracted PPI and the generated ``.csv`` file.
        *inputs: Raw GUI values in the order expected by ``process_inputs``.

    Returns:
        Tuple ``(df, path, plot)``: list of ``(mutation, ddG)`` pairs with ddG
        rounded to 3 decimals, path (``str``) to a ``.csv`` file with the same
        content, and the HTML string produced by ``plot_3dmol``.

    Raises:
        gr.Error: if the inputs are invalid (raised by ``process_inputs``) or
            the prediction itself fails.
    """
    # Process input
    pdb_path, ppi_path, muts = process_inputs(inputs, temp_dir)
    print(ppi_path, muts)

    # Predict
    try:
        ddg, attn = predict_ddg(models, ppi_path, muts, return_attn=True)
    except Exception as e:
        # `Exception` (not a bare except) so Ctrl-C / SystemExit still propagate;
        # chain the cause so the server log shows the real failure.
        raise gr.Error("Prediction failed. Please double check your inputs.") from e

    # Create dataframe
    ddg = ddg.detach().numpy().tolist()
    ddg = np.round(ddg, 3)
    df = list(zip(muts, ddg))

    # Create dataframe file
    # A fixed file name in the working directory is shared by all requests, so
    # concurrent users would overwrite each other's download. Give every call
    # its own directory while keeping the user-visible file name.
    out_dir = Path(tempfile.mkdtemp(prefix='predictions_', dir=temp_dir))
    path = str(out_dir / 'ppiformer_ddg_predictions.csv')
    pd.DataFrame(df).rename(columns={0: "Mutation", 1: "ddG [kcal/mol]"}).to_csv(path, index=False)

    # Create 3DMol plot
    plot = plot_3dmol(pdb_path, ppi_path, muts, attn)

    return df, path, plot
# Gradio application: declares the GUI, loads the models once at start-up and
# wires the "Predict" button to `predict`. Components appear on the page in the
# order in which they are created inside the `with` blocks.
app = gr.Blocks(theme=gr.themes.Default(primary_hue="green", secondary_hue="pink"))
with app:

    # Input GUI
    gr.Markdown(value="# PPIformer Web")
    gr.Image("assets/readme-dimer-close-up.png")
    gr.Markdown(value="""
[PPIformer](https://github.com/anton-bushuiev/PPIformer/tree/main) is a state-of-the-art predictor of the effects of mutations on protein-protein interactions (PPIs),
as quantified by the binding energy changes (ddG). The model was pre-trained on the [PPIRef](https://github.com/anton-bushuiev/PPIRef)
dataset via a coarse-grained structural masked modeling and fine-tuned on [SKEMPI v2.0](https://life.bsc.es/pid/skempi2) via log odds.
PPIformer was shown to successfully identify known favorable mutations of the [staphylokinase thrombolytic](https://pubmed.ncbi.nlm.nih.gov/10942387/)
and a [human antibody](https://www.pnas.org/doi/10.1073/pnas.2122954119) against the SARS-CoV-2 spike protein. Please see more details in [our paper](https://arxiv.org/abs/2310.18515).
To use PPIformer on your data, please specify the PPI structure (PDB code or file), interacting proteins of interest (chain codes in the file) and mutations
(semicolon-separated list or file with mutations in the [standard format](https://foldxsuite.crg.eu/parameter/mutant-file)). For inspiration, you can use one of the examples below:
click on one of the rows to pre-fill the inputs. After specifying the inputs, press the button to predict the effects of mutations on the PPI. Currently the model runs on CPU, so the prediction may take a few minutes.
After making a prediction with the model, you will see binding free energy changes (ddG values) for each mutation and a 3D visualization of the PPI with mutated residues highlighted in red. The visualization additionally shows
the attention coefficients of the model for the nearest neighboring residues, which quantifies the contribution of the residues to the predicted ddG value. The brighter and thicker a residue is, the more attention the model paid to it.
Currently, the web only visualizes the first mutation in the list.
""")

    with gr.Row():
        # Left column: structure selection (PDB code + partners, or an uploaded file)
        with gr.Column():
            gr.Markdown("## PPI structure")
            with gr.Row():
                pdb_code = gr.Textbox(placeholder="1BUI", label="PDB code", info="Protein Data Bank code (https://www.rcsb.org/)")
                partners = gr.Textbox(placeholder="A,B,C", label="Partners", info="Protein chains in the PDB file forming the PPI interface")
            pdb_path = gr.File(file_count="single", label="Or PDB file instead of PDB code")

        # Right column: mutations (typed list, or an uploaded file)
        with gr.Column():
            gr.Markdown("## Mutations")
            muts = gr.Textbox(placeholder="SC16A;FC47A;SC16A,FC47A", label="List of (multi-point) mutations", info="SC16A,FC47A;SC16A;FC47A for three mutations: serine to alanine at position 16 in chain C, phenylalanine to alanine at position 47 in chain C, and their double-point combination")
            muts_path = gr.File(file_count="single", label="Or file with mutations")

    # Clickable examples that pre-fill the three text inputs
    examples = gr.Examples(
        examples=[
            ["1BUI", "A,B,C", "SC16A,FC47A;SC16A;FC47A"],
            ["3QIB", "A,B,P,C,D", "YP7F,TP12S;YP7F;TP12S"],
            ["1KNE", "A,P", ';'.join([f"TP6{a}" for a in AMINO_ACID_CODES_1])]
        ],
        inputs=[pdb_code, partners, muts],
        label="Examples (click on a line to pre-fill inputs)"
    )

    # Predict GUI
    predict_button = gr.Button(value="Predict effects of mutations on PPI", variant="primary")

    # Output GUI
    gr.Markdown("## Predictions")
    df_file = gr.File(label="Download predictions as .csv", interactive=False, visible=True)
    df = gr.Dataframe(
        headers=["Mutation", "ddG [kcal/mol]"],
        datatype=["str", "number"],
        col_count=(2, "fixed"),
    )
    plot = gr.HTML()

    # Download weights from Zenodo
    download_weights()

    # Load models (three checkpoints, on CPU, in eval mode)
    models = [
        DDGPPIformer.load_from_checkpoint(
            PPIFORMER_WEIGHTS_DIR / f'ddg_regression/{i}.ckpt',
            map_location=torch.device('cpu')
        ).eval()
        for i in range(3)
    ]

    # Create temporary directory for storing downloaded PDBs and extracted PPIs.
    # Keep a reference to the TemporaryDirectory object for the lifetime of the app.
    temp_dir_obj = tempfile.TemporaryDirectory()
    temp_dir = Path(temp_dir_obj.name)

    # Main logic
    inputs = [pdb_code, pdb_path, partners, muts, muts_path]
    outputs = [df, df_file, plot]
    # Bind the models and scratch directory so the callback only receives GUI values
    predict = partial(predict, models, temp_dir)
    predict_button.click(predict, inputs=inputs, outputs=outputs)

app.launch(allowed_paths=['./assets'])