Spaces:
Running
on
Zero
Running
on
Zero
import copy | |
import random | |
import tempfile | |
from pathlib import Path | |
from functools import partial | |
import gradio as gr | |
import torch | |
import numpy as np | |
import pandas as pd | |
from Bio.PDB.Polypeptide import protein_letters_3to1 | |
from biopandas.pdb import PandasPdb | |
from colour import Color | |
from colour import RGB_TO_COLOR_NAMES | |
from mutils.proteins import AMINO_ACID_CODES_1 | |
from mutils.pdb import download_pdb | |
from mutils.mutations import Mutation | |
from ppiref.extraction import PPIExtractor | |
from ppiref.utils.ppi import PPIPath | |
from ppiref.utils.residue import Residue | |
from ppiformer.tasks.node import DDGPPIformer | |
from ppiformer.utils.api import download_weights, predict_ddg | |
from ppiformer.utils.torch import fill_diagonal | |
from ppiformer.definitions import PPIFORMER_WEIGHTS_DIR | |
# Seed Python's global `random` generator once at import time for reproducible runs.
random.seed(a=0)
def process_inputs(inputs, temp_dir):
    """Validate the raw web-form inputs and materialise them as files on disk.

    Args:
        inputs: Tuple ``(pdb_code, pdb_path, partners, muts, muts_path)`` holding
            the values of the corresponding Gradio components, in that order.
        temp_dir: ``pathlib.Path`` of a writable scratch directory. Downloaded PDB
            files go to ``temp_dir / 'pdb'`` and extracted PPIs to ``temp_dir / 'ppi'``.

    Returns:
        ``(pdb_path, ppi_path, muts)``: path of the full PDB structure, path of the
        extracted PPI (built by ``PPIPath.construct``), and a non-empty list of
        mutation strings, one per (possibly multi-point) mutation.

    Raises:
        gr.Error: if the structure, partners or mutations are missing (also after
            discarding empty entries), or if the PDB download or the PPI
            extraction fails.
    """
    pdb_code, pdb_path, partners, muts, muts_path = inputs

    # Check inputs
    if not pdb_code and not pdb_path:
        raise gr.Error("PPI structure not specified.")
    if pdb_code and pdb_path:
        gr.Warning("Both PDB code and PDB file specified. Using PDB file.")
    if not partners:
        raise gr.Error("Partners not specified.")
    if not muts and not muts_path:
        raise gr.Error("Mutations not specified.")
    if muts and muts_path:
        gr.Warning("Both mutations and mutations file specified. Using mutations file.")

    # Prepare mutations input first so malformed input fails before the costly
    # download / extraction. A mutations file takes precedence over the text box.
    if muts_path:
        muts = Path(muts_path).read_text()
    # Drop empty entries: FoldX-style mutant files end every line with ';', so a
    # plain split would leave a trailing "" mutation behind.
    muts = [m.strip() for m in muts.split(';') if m.strip()]
    if not muts:
        raise gr.Error("Mutations not specified.")

    # Partners "A, B,C" -> ['A', 'B', 'C']; stray commas do not produce empty chains.
    partners = [p.strip() for p in partners.split(',') if p.strip()]
    if not partners:
        raise gr.Error("Partners not specified.")

    # Prepare PDB input (an uploaded file takes precedence over a PDB code)
    if pdb_path:
        pdb_path = Path(pdb_path)
    else:
        pdb_code = pdb_code.strip().lower()
        pdb_path = temp_dir / f'pdb/{pdb_code}.pdb'
        try:
            download_pdb(pdb_code, path=pdb_path)
        except Exception as e:
            raise gr.Error("PDB download failed.") from e

    # Extract PPI into temp dir
    ppi_dir = temp_dir / 'ppi'
    try:
        extractor = PPIExtractor(out_dir=ppi_dir, nest_out_dir=True, join=True, radius=10.0)
        extractor.extract(pdb_path, partners=partners)
        ppi_path = PPIPath.construct(ppi_dir, pdb_path.stem, partners)
    except Exception as e:
        raise gr.Error("PPI extraction failed.") from e

    return pdb_path, ppi_path, muts
def _read_pdb_for_3dmol(pdb_path):
    """Return the text of the PDB file at ``pdb_path`` prepared for 3Dmol.js.

    Atom names ``OT1``/``OT2`` are renamed to ``O``/``OXT``. Both replacements are
    exactly three characters wide so the fixed-width PDB columns stay aligned.
    """
    with open(pdb_path, "r") as fp:
        mol = fp.read()
    return mol.replace("OT1", "O  ").replace("OT2", "OXT")


def _normalized_residue_attention(attn, mut_id, idx_mutated):
    """Return per-token attention coming from the mutated tokens, min-max scaled.

    The result is a 1-D CPU tensor with values in ``[1e-10, 1 + 1e-10]``. When all
    tokens receive the same attention (e.g. no mutated residue was located), every
    value is ``1e-10`` instead of the NaN a plain ``0 / 0`` would produce.
    """
    attn = torch.nan_to_num(attn, nan=1e-10)
    attn_sub = attn[:, mut_id, 0, :, 0, :, :, :]  # models, layers, heads, tokens, tokens
    # Presumably overwrites the token-to-itself entries (last two axes) -- see ppiformer
    attn_sub = fill_diagonal(attn_sub, 1e-10)
    attn_mutated = attn_sub[..., idx_mutated, :]  # rows of the mutated tokens only
    per_token = torch.sum(attn_mutated, dim=(0, 1, 2, 3))
    lo, hi = per_token.min(), per_token.max()
    span = hi - lo
    if span > 0:
        per_token = (per_token - lo) / span
    else:
        per_token = torch.zeros_like(per_token)
    # Keep values strictly positive: they are used as stick radius and opacity
    per_token = per_token + 1e-10
    return per_token.detach().cpu()


def _chain_palette():
    """Return the ordered list of ``colour.Color`` objects assigned to chains.

    The three preferred colours come first, followed by every other named colour
    (except black and white) in a fixed pseudo-random order.
    """
    preferred_colors = ['LimeGreen', 'HotPink', 'RoyalBlue']
    other_colors = [names[0] for names in RGB_TO_COLOR_NAMES.values()]
    other_colors = [c for c in other_colors if c not in preferred_colors + ['Black', 'White']]
    # A private fixed-seed RNG yields the same palette on every call, independent
    # of (and without consuming) the global `random` state.
    random.Random(0).shuffle(other_colors)
    return [Color(c) for c in preferred_colors + other_colors]


def _js_call(method, *args):
    """Return the JS statement ``viewer.<method>(arg1, arg2, ...);`` with JSON-encoded args."""
    import json
    return f"viewer.{method}(" + ", ".join(json.dumps(arg) for arg in args) + ");\n"


def plot_3dmol(pdb_path, ppi_path, muts, attn, mut_id=0):
    """Render the PPI and the model attention as a 3Dmol.js viewer inside an iframe.

    Args:
        pdb_path: Path of the full PDB structure shown as the background model.
        ppi_path: Path of the extracted PPI (PDB format); its residues are drawn as
            sticks whose colour saturation, radius and opacity follow the attention.
        muts: List of mutation strings as returned by ``process_inputs``.
        attn: Attention tensor from ``predict_ddg(..., return_attn=True)``. It is
            indexed as ``attn[:, mut_id, 0, :, 0, :, :, :]`` giving
            (models, layers, heads, tokens, tokens); the token axes are assumed to
            follow the row order of the per-residue table built below -- TODO
            confirm against ppiformer.
        mut_id: Index into ``muts`` of the mutation to visualise.

    Returns:
        str: an ``<iframe>`` element whose ``srcdoc`` is the self-contained viewer page.
    """
    # 3DMol.js adapted from https://huggingface.co/spaces/huhlim/cg2all/blob/main/app.py
    import html as html_lib
    import json

    mol = _read_pdb_for_3dmol(pdb_path)

    # Read PPI to customize 3Dmol.js visualization: one row (the CA atom) per residue
    ppi_df = PandasPdb().read_pdb(ppi_path).df['ATOM']
    ppi_df = ppi_df.groupby(list(Residue._fields)).apply(
        lambda df: df[df['atom_name'] == 'CA'].iloc[0]
    ).reset_index(drop=True)
    # Residue ids "RES:CHAIN:NUM:INS", with the trailing ':' dropped when INS is empty
    ppi_df['id'] = ppi_df.apply(
        lambda row: ':'.join([row['residue_name'], row['chain_id'], str(row['residue_number']), row['insertion']]),
        axis=1,
    )
    ppi_df['id'] = ppi_df['id'].apply(lambda x: x[:-1] if x[-1] == ':' else x)
    muts_id = Mutation(muts[mut_id]).wt_to_graphein()  # flatten ids of all sp muts
    ppi_df['mutated'] = ppi_df.apply(lambda row: row['id'] in muts_id, axis=1)

    # Attention coefficients per residue (normalized sum of direct attention from mutated residues)
    idx_mutated = torch.from_numpy(ppi_df.index[ppi_df['mutated']].to_numpy())
    ppi_df['attn'] = _normalized_residue_attention(attn, mut_id, idx_mutated).numpy()

    # Customize 3Dmol.js visualization https://3dmol.csb.pitt.edu/doc/
    # Chains ordered by their most-attended residue, so that chain gets the first colour
    chains = ppi_df.sort_values('attn', ascending=False)['chain_id'].unique()
    chain_to_color = dict(zip(chains, _chain_palette()))

    # Cartoon chains
    styles = [
        _js_call('addStyle', {"chain": str(chain)},
                 {"cartoon": {"color": chain_to_color[chain].hex_l, "opacity": 0.6}})
        for chain in chains
    ]

    # Stick PPI residues; mutated ones are red, labelled, and define the zoom target
    # TODO Insertions: the selections below ignore insertion codes
    zoom_atoms = []
    labels = []
    for _, row in ppi_df.iterrows():
        sel = {'chain': str(row['chain_id']), 'resi': str(row['residue_number'])}
        attn_value = float(row['attn'])
        if row['mutated']:
            styles.append(_js_call('addStyle', sel, {'stick': {'color': 'red', 'radius': 0.2, 'opacity': 1.0}}))
            atom = int(row['atom_number'])
            zoom_atoms.append(atom)
            label = (protein_letters_3to1.get(row['residue_name'], 'X') + row['chain_id']
                     + str(row['residue_number']) + row['insertion'])
            labels.append(_js_call('addLabel', label,
                                   {"fontSize": 16, "fontColor": "red", "backgroundOpacity": 0.0},
                                   {"serial": atom}))
        else:
            color = copy.deepcopy(chain_to_color[row['chain_id']])
            color.saturation = attn_value
            styles.append(_js_call('addStyle', sel,
                                   {'stick': {'color': color.hex_l, 'radius': attn_value / 5, 'opacity': attn_value}}))

    zoom_animation_duration = 500
    zoom = _js_call('zoomTo', {"or": [{"serial": a} for a in zoom_atoms]}, zoom_animation_duration)
    script = ''.join(styles) + ''.join(labels) + zoom

    # json.dumps gives a valid JS string literal for any file content; escaping "</"
    # keeps a stray "</script>" in the file from closing the script element early.
    pdb_js = json.dumps(mol).replace("</", "<\\/")

    # Construct 3Dmol.js visualization script embedded in HTML
    html_doc = (
        """<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html; charset=UTF-8" />
<style>
body{
    font-family:sans-serif
}
.mol-container {
    width: 100%;
    height: 600px;
    position: relative;
}
.mol-container select{
    background-image:None;
}
</style>
<script src="https://cdnjs.cloudflare.com/ajax/libs/jquery/3.6.3/jquery.min.js" integrity="sha512-STof4xm1wgkfm7heWqFJVn58Hm3EtS31XFaagaa8VMReCXAkQnJZ+jEy8PCC/iT18dFy95WcExNHFTqLyp72eQ==" crossorigin="anonymous" referrerpolicy="no-referrer"></script>
<script src="https://3Dmol.csb.pitt.edu/build/3Dmol-min.js"></script>
</head>
<body>
<div id="container" class="mol-container"></div>
<script>
let pdb = """
        + pdb_js
        + """;
$(document).ready(function () {
    let element = $("#container");
    let config = { backgroundColor: "white" };
    let viewer = $3Dmol.createViewer(element, config);
    viewer.addModel(pdb, "pdb");
    viewer.setStyle({"model": 0}, {"ray_opaque_background": "off"}, {"stick": {"color": "lightgrey", "opacity": 0.5}});
"""
        + script
        + """
    viewer.render();
})
</script>
</body></html>"""
    )

    # The page must be attribute-escaped for srcdoc: PDB text can contain quotes
    # (e.g. apostrophes in AUTHOR/REMARK records) that would otherwise end the attribute.
    return f"""<iframe style="width: 100%; height: 600px" name="result" allow="midi; geolocation; microphone; camera;
display-capture; encrypted-media;" sandbox="allow-modals allow-forms
allow-scripts allow-same-origin allow-popups
allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
allowpaymentrequest="" frameborder="0" srcdoc="{html_lib.escape(html_doc, quote=True)}"></iframe>"""
def predict(models, temp_dir, *inputs):
    """Run PPIformer on the web-form inputs and build all outputs.

    Args:
        models: Sequence of loaded ``DDGPPIformer`` models, passed to ``predict_ddg``.
        temp_dir: ``pathlib.Path`` of the app's scratch directory.
        *inputs: Values of the Gradio input components in the order expected by
            ``process_inputs``: PDB code, PDB file, partners, mutations, mutations file.

    Returns:
        ``(rows, csv_path, plot_html)``:
        - ``rows`` is a list of ``(mutation, ddG)`` pairs with ddG rounded to 3 decimals.
        - ``csv_path`` is a per-request CSV file holding the same table.
        - ``plot_html`` is the 3Dmol.js visualisation of the first mutation.

    Raises:
        gr.Error: if input processing or the prediction fails.
    """
    # Process input
    pdb_path, ppi_path, muts = process_inputs(inputs, temp_dir)
    print(ppi_path, muts)

    # Predict
    try:
        ddg, attn = predict_ddg(models, ppi_path, muts, return_attn=True)
    except Exception as e:
        raise gr.Error("Prediction failed. Please double check your inputs.") from e

    # Round in float64 (after .tolist()) so the displayed values are exact to 3
    # decimals; atleast_1d keeps zip() working if a single value comes back 0-d.
    ddg = np.atleast_1d(np.round(ddg.detach().cpu().numpy().tolist(), 3))
    rows = list(zip(muts, ddg.tolist()))

    # Write the CSV into a fresh per-request directory so concurrent users never
    # overwrite (or download) each other's results; the file name stays the same.
    csv_path = Path(tempfile.mkdtemp(dir=temp_dir)) / 'ppiformer_ddg_predictions.csv'
    pd.DataFrame(rows, columns=["Mutation", "ddG [kcal/mol]"]).to_csv(csv_path, index=False)

    # Create 3DMol plot
    plot = plot_3dmol(pdb_path, ppi_path, muts, attn)
    return rows, str(csv_path), plot
app = gr.Blocks(theme=gr.themes.Default(primary_hue="green", secondary_hue="pink"))
with app:
    # Input GUI
    gr.Markdown(value="# PPIformer Web")
    gr.Image("assets/readme-dimer-close-up.png")
    gr.Markdown(value="""
    [PPIformer](https://github.com/anton-bushuiev/PPIformer/tree/main) is a state-of-the-art predictor of the effects of mutations on protein-protein interactions (PPIs),
    as quantified by the binding energy changes (ddG). The model was pre-trained on the [PPIRef](https://github.com/anton-bushuiev/PPIRef)
    dataset via coarse-grained structural masked modeling and fine-tuned on [SKEMPI v2.0](https://life.bsc.es/pid/skempi2) via log odds.
    PPIformer was shown to successfully identify known favorable mutations of the [staphylokinase thrombolytic](https://pubmed.ncbi.nlm.nih.gov/10942387/)
    and a [human antibody](https://www.pnas.org/doi/10.1073/pnas.2122954119) against the SARS-CoV-2 spike protein. Please see more details in [our paper](https://arxiv.org/abs/2310.18515).
    To use PPIformer on your data, please specify the PPI structure (PDB code or file), interacting proteins of interest (chain codes in the file) and mutations
    (semicolon-separated list or file with mutations in the [standard format](https://foldxsuite.crg.eu/parameter/mutant-file)). For inspiration, you can use one of the examples below:
    click on one of the rows to pre-fill the inputs. After specifying the inputs, press the button to predict the effects of mutations on the PPI. Currently the model runs on CPU, so the prediction may take a few minutes.
    After making a prediction with the model, you will see binding free energy changes (ddG values) for each mutation and a 3D visualization of the PPI with mutated residues highlighted in red. The visualization additionally shows
    the attention coefficients of the model for the nearest neighboring residues, which quantify the contribution of the residues to the predicted ddG value. The brighter and thicker a residue is, the more attention the model paid to it.
    Currently, the web app only visualizes the first mutation in the list.
    """)
    with gr.Row():
        with gr.Column():
            gr.Markdown("## PPI structure")
            with gr.Row():
                pdb_code = gr.Textbox(placeholder="1BUI", label="PDB code", info="Protein Data Bank code (https://www.rcsb.org/)")
                partners = gr.Textbox(placeholder="A,B,C", label="Partners", info="Protein chains in the PDB file forming the PPI interface")
            pdb_path = gr.File(file_count="single", label="Or PDB file instead of PDB code")
        with gr.Column():
            gr.Markdown("## Mutations")
            muts = gr.Textbox(placeholder="SC16A;FC47A;SC16A,FC47A", label="List of (multi-point) mutations", info="SC16A,FC47A;SC16A;FC47A for three mutations: serine to alanine at position 16 in chain C, phenylalanine to alanine at position 47 in chain C, and their double-point combination")
            muts_path = gr.File(file_count="single", label="Or file with mutations")
    examples = gr.Examples(
        examples=[
            ["1BUI", "A,B,C", "SC16A,FC47A;SC16A;FC47A"],
            ["3QIB", "A,B,P,C,D", "YP7F,TP12S;YP7F;TP12S"],
            # Threonine at position 6 of chain P mutated to every amino acid
            ["1KNE", "A,P", ';'.join([f"TP6{a}" for a in AMINO_ACID_CODES_1])],
        ],
        inputs=[pdb_code, partners, muts],
        label="Examples (click on a line to pre-fill inputs)",
    )

    # Predict GUI
    predict_button = gr.Button(value="Predict effects of mutations on PPI", variant="primary")

    # Output GUI
    gr.Markdown("## Predictions")
    df_file = gr.File(label="Download predictions as .csv", interactive=False, visible=True)
    df = gr.Dataframe(
        headers=["Mutation", "ddG [kcal/mol]"],
        datatype=["str", "number"],
        col_count=(2, "fixed"),
    )
    plot = gr.HTML()

    # Download weights from Zenodo
    download_weights()

    # Load the three ddG-regression checkpoints on CPU in inference mode
    models = [
        DDGPPIformer.load_from_checkpoint(
            PPIFORMER_WEIGHTS_DIR / f'ddg_regression/{i}.ckpt',
            map_location=torch.device('cpu'),
        ).eval()
        for i in range(3)
    ]

    # Create temporary directory for storing downloaded PDBs and extracted PPIs.
    # Keep a reference to the TemporaryDirectory object: the directory is deleted
    # once that object is garbage-collected.
    temp_dir_obj = tempfile.TemporaryDirectory()
    temp_dir = Path(temp_dir_obj.name)

    # Main logic: bind models and scratch directory under a new name so that
    # `predict` keeps referring to the plain function.
    inputs = [pdb_code, pdb_path, partners, muts, muts_path]
    outputs = [df, df_file, plot]
    predict_fn = partial(predict, models, temp_dir)
    predict_button.click(predict_fn, inputs=inputs, outputs=outputs)

app.launch(allowed_paths=['./assets'])