Spaces:
Running
Running
File size: 22,927 Bytes
29bd8b5 19cd7a4 29bd8b5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 |
# print("""
# __ __ _ ___ _ _ _____ _____ _ _ _ _ _ ____ _____
# | \/ | / \ |_ _| \ | |_ _| ____| \ | | / \ | \ | |/ ___| ____|
# | |\/| | / _ \ | || \| | | | | _| | \| | / _ \ | \| | | | _|
# | | | |/ ___ \ | || |\ | | | | |___| |\ |/ ___ \| |\ | |___| |___
# |_| |_/_/ \_\___|_| \_| |_| |_____|_| \_/_/ \_\_| \_|\____|_____|
# ____ ____ _____ _ _ __
# | __ )| _ \| ____| / \ | |/ /
# | _ \| |_) | _| / _ \ | ' /
# | |_) | _ <| |___ / ___ \| . \
# |____/|_| \_\_____/_/ \_\_|\_\
# """)
import os
# os.system("pip uninstall -y gradio")
# os.system("pip install gradio==3.50.2")
# os.system("pip uninstall -y spaces")
# os.system("pip install spaces==0.8")
# Force torch==2.0.1 into the running environment before `torch` is imported
# further down. pip is invoked through the current interpreter
# (`python -m pip`) so the (un)installation targets this Python rather than
# whichever `pip` executable happens to be first on PATH. As with the previous
# `os.system` calls, a failing pip command is not fatal (check=False).
import subprocess
import sys
subprocess.run([sys.executable, "-m", "pip", "uninstall", "-y", "torch"], check=False)
subprocess.run([sys.executable, "-m", "pip", "install", "torch==2.0.1"], check=False)
import sys
import copy
import random
import tempfile
import shutil
import logging
from pathlib import Path
from functools import partial
import spaces
import gradio as gr
import torch
import numpy as np
import pandas as pd
from Bio.PDB.Polypeptide import protein_letters_3to1
from biopandas.pdb import PandasPdb
from colour import Color
from colour import RGB_TO_COLOR_NAMES
from mutils.proteins import AMINO_ACID_CODES_1
from mutils.pdb import download_pdb
from mutils.mutations import Mutation
from ppiref.extraction import PPIExtractor
from ppiref.utils.ppi import PPIPath
from ppiref.utils.residue import Residue
from ppiformer.tasks.node import DDGPPIformer
from ppiformer.utils.api import download_from_zenodo
from ppiformer.utils.api import predict_ddg as predict_ddg_
from ppiformer.utils.torch import fill_diagonal
from ppiformer.definitions import PPIFORMER_WEIGHTS_DIR
import pkg_resources
import sys
def print_package_versions():
    """Print every installed distribution as ``key==version``, then the Python version.

    The package lines are sorted as plain strings and written to stdout one per
    line, followed by a "Python version:" section containing ``sys.version``.
    """
    pinned = [f"{dist.key}=={dist.version}" for dist in pkg_resources.working_set]
    pinned.sort()
    print("Installed packages and their versions:")
    for line in pinned:
        print(line)
    print("\nPython version:")
    print(sys.version)
# Dump the environment (installed packages and Python version) to stdout once at startup.
print_package_versions()
# Send INFO-and-above log records to stdout using a timestamped format.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)],
    level=logging.INFO,
)
# Seed the global `random` generator; plot_3dmol draws from it to shuffle chain colours.
random.seed(0)
@spaces.GPU
def predict_ddg(models, ppi, muts, return_attn):
    """Run the model ensemble on the best available device and return CPU results.

    Args:
        models: models to evaluate; each is moved with ``.to(device)`` first.
        ppi: structure path forwarded to ``ppiformer.utils.api.predict_ddg``.
        muts: mutations forwarded unchanged (callers in this file pass mutation strings).
        return_attn: when true, attention coefficients are returned as well.

    Returns:
        The detached CPU ddG tensor, or ``(ddg, attentions)`` as detached CPU
        tensors when ``return_attn`` is true.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"[INFO] Device on prediction: {device}")
    models = [m.to(device) for m in models]

    if not return_attn:
        ddg_pred = predict_ddg_(models, ppi, muts, return_attn=return_attn)
        return ddg_pred.detach().cpu()

    ddg_pred, attns = predict_ddg_(models, ppi, muts, return_attn=return_attn)
    return ddg_pred.detach().cpu(), attns.detach().cpu()
def process_inputs(inputs, temp_dir):
    """Validate the raw UI inputs and prepare the files needed for prediction.

    Args:
        inputs: tuple ``(pdb_code, pdb_path, partners, muts, muts_path)`` holding
            the Gradio component values in that order.
        temp_dir: ``Path`` of a writable working directory. Structures are placed
            under ``temp_dir / 'pdb'`` and extracted interfaces under
            ``temp_dir / 'ppi'``.

    Returns:
        ``(pdb_path, ppi_path, muts, muts_on_interface)``. ``pdb_path`` is the
        prepared full-complex file (renamed to include the partner chains),
        ``ppi_path`` the extracted interface (10 Å radius), ``muts`` the
        normalised mutation strings, and ``muts_on_interface`` one boolean per
        mutation: True if its wild type is found in the interface file, False if
        it is only found in the full structure.

    Raises:
        gr.Error: on missing inputs, failed download or extraction, unparsable
            mutations, mutations on chains outside ``partners``, or wild types
            absent from the structure.
    """
    pdb_code, pdb_path, partners, muts, muts_path = inputs

    # Check inputs
    if not pdb_code and not pdb_path:
        raise gr.Error("PPI structure not specified.")
    if pdb_code and pdb_path:
        gr.Warning("Both PDB code and PDB file specified. Using PDB file.")
    if not partners:
        raise gr.Error("Partners not specified.")
    if not muts and not muts_path:
        raise gr.Error("Mutations not specified.")
    if muts and muts_path:
        gr.Warning("Both mutations and mutations file specified. Using mutations file.")

    # Prepare PDB input
    if pdb_path:
        # Copy the upload into the working dir, converting the file name to the
        # PPIRef format ('_' -> '-').
        # NOTE(review): `.name` is whatever the Gradio file value exposes; with some
        # Gradio versions it is the full temporary path rather than the base name.
        new_pdb_path = temp_dir / f"pdb/{pdb_path.name.replace('_', '-')}"
        new_pdb_path.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy(str(pdb_path), str(new_pdb_path))
        pdb_path = Path(new_pdb_path)
    else:
        try:
            pdb_code = pdb_code.strip().lower()
            pdb_path = temp_dir / f'pdb/{pdb_code}.pdb'
            download_pdb(pdb_code, path=pdb_path)
        except Exception as e:  # not bare: let KeyboardInterrupt/SystemExit through
            raise gr.Error("PDB download failed.") from e

    # Parse partners, ignoring empty entries such as the one produced by "A,B,"
    partners = [p.strip() for p in partners.split(',') if p.strip()]
    if not partners:
        raise gr.Error("Partners not specified.")

    # Add partners to file name
    pdb_path = pdb_path.rename(pdb_path.with_stem(f"{pdb_path.stem}-{'-'.join(partners)}"))

    # Extract PPI into temp dir
    try:
        ppi_dir = temp_dir / 'ppi'
        extractor = PPIExtractor(out_dir=ppi_dir, nest_out_dir=True, join=True, radius=10.0)
        extractor.extract(pdb_path, partners=partners)
        ppi_path = PPIPath.construct(ppi_dir, pdb_path.stem, partners)
    except Exception as e:
        raise gr.Error("PPI extraction failed.") from e

    # Prepare mutations input (a mutations file takes precedence over the textbox)
    if muts_path:
        muts = Path(muts_path).read_text()

    # Basic format: ';'-separated (multi-point) mutations; surrounding whitespace
    # and newlines (e.g. "SC16A; FC47A" or one mutation per line) are tolerated.
    try:
        muts = [Mutation.from_str(m.strip()) for m in muts.strip().split(';') if m.strip()]
    except Exception as e:
        raise gr.Error(f'Mutations parsing failed: {e}') from e
    if not muts:
        # e.g. ";" or whitespace only: nothing to predict
        raise gr.Error("Mutations not specified.")

    # Partners: every point mutation must lie on one of the selected chains
    for mut in muts:
        for pmut in mut.muts:
            if pmut.chain not in partners:
                raise gr.Error(f'Chain of point mutation {pmut} is not in the list of partners {partners}.')

    # Consistency with provided .pdb: prefer the interface, fall back to the full complex
    muts_on_interface = []
    for mut in muts:
        if mut.wt_in_pdb(ppi_path):
            muts_on_interface.append(True)
        elif mut.wt_in_pdb(pdb_path):
            muts_on_interface.append(False)
        else:
            raise gr.Error(f'Wild-type of mutation {mut} is not in the provided .pdb file.')

    muts = [str(m) for m in muts]
    return pdb_path, ppi_path, muts, muts_on_interface
def plot_3dmol(pdb_path, ppi_path, mut, attn, attn_mut_id=0):
    """Render a 3Dmol.js view of a complex highlighting a mutation and model attention.

    Args:
        pdb_path: path of the full complex; its text is embedded into the viewer.
        ppi_path: path of the structure whose residues get stick styles (the
            extracted interface, or the full complex for off-interface mutations).
        mut: mutation string, parsed with ``Mutation.from_str``.
        attn: attention tensor from ``predict_ddg(..., return_attn=True)``. It is
            indexed as ``attn[:, attn_mut_id, 0, :, 0, :, :, :]``, which the code
            describes as (models, layers, heads, tokens, tokens); the full layout
            is defined by ppiformer -- TODO confirm.
        attn_mut_id: index of ``mut`` along the second axis of ``attn``.

    Returns:
        An ``<iframe>`` HTML string whose ``srcdoc`` holds a standalone viewer page.
    """
    import json
    from html import escape as html_escape

    # NOTE 3DMol.js adapted from https://huggingface.co/spaces/huhlim/cg2all/blob/main/app.py
    # Read PDB for 3Dmol.js
    with open(pdb_path, "r") as fp:
        mol = fp.read()
    # Rename terminal oxygens. The replacements keep the field width (3 chars) so
    # the fixed-column PDB records are not shifted.
    mol = mol.replace("OT1", "O  ")
    mol = mol.replace("OT2", "OXT")
    # The structure is embedded in a JS template literal: escape the characters
    # that would terminate it or trigger interpolation.
    mol_js = mol.replace("\\", "\\\\").replace("`", "\\`").replace("${", "\\${")

    # Read PPI to customize 3Dmol.js visualization: keep the CA atom of every residue
    ppi_df = PandasPdb().read_pdb(ppi_path).df['ATOM']
    ppi_df = ppi_df.groupby(list(Residue._fields)).apply(
        lambda df: df[df['atom_name'] == 'CA'].iloc[0]
    ).reset_index(drop=True)
    ppi_df['id'] = ppi_df.apply(
        lambda row: ':'.join([row['residue_name'], row['chain_id'], str(row['residue_number']), row['insertion']]),
        axis=1,
    )
    # Drop the trailing ':' left behind when the insertion code is empty
    ppi_df['id'] = ppi_df['id'].apply(lambda x: x[:-1] if x[-1] == ':' else x)
    muts_id = Mutation.from_str(mut).wt_to_graphein()  # flatten ids of all sp muts
    ppi_df['mutated'] = ppi_df.apply(lambda row: row['id'] in muts_id, axis=1)

    # Prepare attention coefficients per residue (normalized sum of direct attention from mutated residues)
    attn = torch.nan_to_num(attn, nan=1e-10)
    attn_sub = attn[:, attn_mut_id, 0, :, 0, :, :, :]  # models, layers, heads, tokens, tokens
    idx_mutated = torch.from_numpy(ppi_df.index[ppi_df['mutated']].to_numpy())
    attn_sub = fill_diagonal(attn_sub, 1e-10)
    attn_mutated = attn_sub[..., idx_mutated, :]
    attns_per_token = torch.sum(attn_mutated, dim=(0, 1, 2, 3))
    # Min-max normalize to [0, 1]. A constant vector would divide by zero and
    # yield NaN saturation/opacity values, so it is mapped to 0 instead.
    attn_min, attn_max = attns_per_token.min(), attns_per_token.max()
    attn_span = attn_max - attn_min
    if attn_span > 0:
        attns_per_token = (attns_per_token - attn_min) / attn_span
    else:
        attns_per_token = torch.zeros_like(attns_per_token)
    attns_per_token += 1e-10
    ppi_df['attn'] = attns_per_token.numpy()
    # Chains ordered by their highest-attention residue; the first gets the first colour
    chains = ppi_df.sort_values('attn', ascending=False)['chain_id'].unique()

    # Customize 3Dmol.js visualization https://3dmol.csb.pitt.edu/doc/
    styles = []
    zoom_atoms = []

    # Cartoon chains: preferred colours first, then the other named colours shuffled
    preferred_colors = ['LimeGreen', 'HotPink', 'RoyalBlue']
    other_colors = [c[0] for c in RGB_TO_COLOR_NAMES.values()]
    other_colors = [c for c in other_colors if c not in preferred_colors + ['Black', 'White']]
    random.shuffle(other_colors)
    all_colors = [Color(c) for c in preferred_colors + other_colors]
    chain_to_color = dict(zip(chains, all_colors))
    for chain in chains:
        styles.append([{"chain": chain}, {"cartoon": {"color": chain_to_color[chain].hex_l, "opacity": 0.6}}])

    # Stick PPI and atoms for zoom
    # TODO Insertions
    for _, row in ppi_df.iterrows():
        selection = {'chain': row['chain_id'], 'resi': str(row['residue_number'])}
        if row['mutated']:
            styles.append([selection, {'stick': {'color': 'red', 'radius': 0.2, 'opacity': 1.0}}])
            zoom_atoms.append(row['atom_number'])
        else:
            # Chain colour desaturated by attention: brighter/thicker = more attention
            color = copy.deepcopy(chain_to_color[row['chain_id']])
            color.saturation = row['attn']
            styles.append([
                selection,
                {'stick': {'color': color.hex_l, 'radius': row['attn'] / 5, 'opacity': row['attn']}},
            ])

    # Convert style dicts to JS lines (json.dumps yields valid JS object literals)
    styles = ''.join(
        'viewer.addStyle(' + ', '.join(json.dumps(s) for s in dcts) + ');\n'
        for dcts in styles
    )

    # Convert zoom atoms to 3DMol.js selection and add labels for mutated residues
    zoom_animation_duration = 500
    sel = '{"or": [' + ', '.join('{"serial": ' + str(a) + '}' for a in zoom_atoms) + ']}'
    zoom = 'viewer.zoomTo(' + sel + ',' + f'{zoom_animation_duration});'
    for atom in zoom_atoms:
        sel = '{"serial": ' + str(atom) + '}'
        row = ppi_df[ppi_df['atom_number'] == atom].iloc[0]
        label = protein_letters_3to1[row['residue_name']] + row['chain_id'] + str(row['residue_number']) + row['insertion']
        styles += 'viewer.addLabel(' + f"\"{label}\"," + "{fontSize:16, fontColor:\"red\", backgroundOpacity: 0.0}," + sel + ');\n'

    # Construct 3Dmol.js visualization script embedded in HTML
    html_doc = (
        """<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html; charset=UTF-8" />
<style>
body{
font-family:sans-serif
}
.mol-container {
width: 100%;
height: 600px;
position: relative;
}
.mol-container select{
background-image:None;
}
</style>
<script src="https://cdnjs.cloudflare.com/ajax/libs/jquery/3.6.3/jquery.min.js" integrity="sha512-STof4xm1wgkfm7heWqFJVn58Hm3EtS31XFaagaa8VMReCXAkQnJZ+jEy8PCC/iT18dFy95WcExNHFTqLyp72eQ==" crossorigin="anonymous" referrerpolicy="no-referrer"></script>
<script src="https://3Dmol.csb.pitt.edu/build/3Dmol-min.js"></script>
</head>
<body>
<div id="container" class="mol-container"></div>
<script>
let pdb = `"""
        + mol_js
        + """`
$(document).ready(function () {
let element = $("#container");
let config = { backgroundColor: "white" };
let viewer = $3Dmol.createViewer(element, config);
viewer.addModel(pdb, "pdb");
viewer.setStyle({"model": 0}, {"ray_opaque_background": "off"}, {"stick": {"color": "lightgrey", "opacity": 0.5}});
"""
        + styles
        + zoom
        + """
viewer.render();
})
</script>
</body></html>"""
    )

    # The page goes into a single-quoted srcdoc attribute: HTML-escape it so quotes
    # and '&' in the PDB text cannot terminate the attribute or become entities.
    # The browser decodes the entities, so the iframe receives `html_doc` verbatim.
    srcdoc = html_escape(html_doc, quote=True)
    return f"""<iframe style="width: 100%; height: 600px" name="result" allow="midi; geolocation; microphone; camera;
display-capture; encrypted-media;" sandbox="allow-modals allow-forms
allow-scripts allow-same-origin allow-popups
allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
allowpaymentrequest="" frameborder="0" srcdoc='{srcdoc}'></iframe>"""
def predict(models, temp_dir, *inputs):
    """Predict ddG values for the requested mutations and build all UI outputs.

    Mutations whose wild type is found in the extracted interface are scored on
    the interface file; the others are scored on the whole complex, and a UI
    warning is shown for them.

    Args:
        models: loaded models forwarded to ``predict_ddg``.
        temp_dir: working directory forwarded to ``process_inputs``.
        *inputs: raw Gradio values ``(pdb_code, pdb_path, partners, muts, muts_path)``.

    Returns:
        ``(table, csv_path, dropdown, plot_args)``:
        - ``table``: a ``gr.Dataframe`` with the predictions;
        - ``csv_path``: the CSV written to the current working directory;
        - ``dropdown``: a dropdown to choose the mutation to visualize;
        - ``plot_args``: a dict mapping each mutation to the positional
          arguments of ``plot_3dmol``.

    Raises:
        gr.Error: if the inputs are invalid or the prediction fails.
    """
    logging.info('Starting prediction')

    # Process input
    pdb_path, ppi_path, muts, muts_on_interface = process_inputs(inputs, temp_dir)
    if not muts:
        # Without this guard `df['Mutation'].iloc[0]` below raises IndexError
        raise gr.Error("Mutations not specified.")

    # Create dataframe
    df = pd.DataFrame({
        'Mutation': muts,
        'ddG [kcal/mol]': len(muts) * [np.nan],
        '10A Interface': muts_on_interface,
        'Attn Id': len(muts) * [np.nan],
    })

    # Show warning if some mutations are not on the interface
    muts_not_on_interface = df[~df['10A Interface']]['Mutation'].tolist()
    n_muts_not_on_interface = len(muts_not_on_interface)
    if n_muts_not_on_interface:
        n_muts_warn = 5
        muts_not_on_interface = ';'.join(muts_not_on_interface[:n_muts_warn])
        if n_muts_not_on_interface > n_muts_warn:
            muts_not_on_interface += f'... (and {n_muts_not_on_interface - n_muts_warn} more)'
        gr.Warning((
            f"{muts_not_on_interface} {'is' if n_muts_not_on_interface == 1 else 'are'} not on the interface. "
            f"The model will predict the effect{'s' if n_muts_not_on_interface > 1 else ''} of "
            f"mutation{'s' if n_muts_not_on_interface > 1 else ''} on the whole complex. "
            f"This may lead to less accurate predictions."
        ))
    logging.info('Inputs processed')

    # Predict using the interface for mutations on the interface and the whole
    # complex otherwise. Attention tensors are kept per group; 'Attn Id' is the
    # row of each mutation within its group's tensor.
    attn_by_interface = {True: None, False: None}
    for on_interface, path in ((True, ppi_path), (False, pdb_path)):
        mask = df['10A Interface'] if on_interface else ~df['10A Interface']
        df_sub = df[mask]
        if df_sub.empty:
            continue
        try:
            ddg, attn = predict_ddg(models, path, df_sub['Mutation'].tolist(), return_attn=True)
        except Exception as e:
            logging.exception('Prediction failed.')
            raise gr.Error(f"Prediction failed. {str(e)}") from e
        ddg = ddg.detach().numpy().tolist()
        logging.info(f'Predictions made for {path}')

        # Update dataframe and attention tensor
        idx = df_sub.index
        df.loc[idx, 'ddG [kcal/mol]'] = ddg
        df.loc[idx, 'Attn Id'] = np.arange(len(idx))
        attn_by_interface[on_interface] = attn
    df['Attn Id'] = df['Attn Id'].astype(int)

    # Round ddG values
    df['ddG [kcal/mol]'] = df['ddG [kcal/mol]'].round(3)

    # Create PPI-specific dropdown
    dropdown = gr.Dropdown(
        df['Mutation'].tolist(), value=df['Mutation'].iloc[0],
        interactive=True, visible=True, label="Mutation to visualize",
    )

    # Predefine plot arguments for all dropdown choices in a single pass; for a
    # mutation listed more than once, its first occurrence is used.
    dropdown_choices_to_plot_args = {}
    for mut, on_interface, attn_id in zip(df['Mutation'], df['10A Interface'], df['Attn Id']):
        if mut in dropdown_choices_to_plot_args:
            continue
        dropdown_choices_to_plot_args[mut] = (
            pdb_path,
            ppi_path if on_interface else pdb_path,
            mut,
            attn_by_interface[bool(on_interface)],
            attn_id,
        )

    # Create dataframe file; the interface column is only shown when it is informative
    path = 'ppiformer_ddg_predictions.csv'
    columns = ['Mutation', 'ddG [kcal/mol]']
    datatypes = ['str', 'number']
    if n_muts_not_on_interface:
        columns.append('10A Interface')
        datatypes.append('bool')
    df = df[columns]
    df.to_csv(path, index=False)
    df = gr.Dataframe(
        value=df,
        headers=list(columns),
        datatype=datatypes,
        col_count=(len(columns), 'fixed'),
    )
    logging.info('Prediction results prepared')

    return df, path, dropdown, dropdown_choices_to_plot_args
def update_plot(dropdown, dropdown_choices_to_plot_args):
    """Render the 3D view for the mutation selected in the dropdown.

    Args:
        dropdown: the selected mutation string (possibly ``None`` before any prediction).
        dropdown_choices_to_plot_args: mapping from mutation to ``plot_3dmol``
            arguments as produced by ``predict``; the UI state starts as an empty list.

    Returns:
        The viewer HTML, or an empty string when there are no stored plot
        arguments for the selection (nothing predicted yet, or a stale choice).
    """
    if not dropdown_choices_to_plot_args or dropdown not in dropdown_choices_to_plot_args:
        return ""
    return plot_3dmol(*dropdown_choices_to_plot_args[dropdown])
app = gr.Blocks(theme=gr.themes.Default(primary_hue="green", secondary_hue="pink"))
with app:

    # Input GUI
    gr.Markdown(value="""
# PPIformer Web (CPU version)
### Computational Design of Protein-Protein Interactions
""")
    gr.Image("assets/readme-dimer-close-up.png")
    # Blank lines separate Markdown paragraphs; without them the sections merge.
    gr.Markdown(value="""
[PPIformer](https://github.com/anton-bushuiev/PPIformer/tree/main) is a state-of-the-art predictor of the effects of mutations
on protein-protein interactions (PPIs), as quantified by the binding free energy changes (ddG). PPIformer was shown to successfully
identify known favourable mutations of the [staphylokinase thrombolytics](https://pubmed.ncbi.nlm.nih.gov/10942387/)
and a [human antibody](https://www.pnas.org/doi/10.1073/pnas.2122954119) against the SARS-CoV-2 spike protein. The model was pre-trained
on the [PPIRef](https://github.com/anton-bushuiev/PPIRef)
dataset via a coarse-grained structural masked modeling and fine-tuned on the [SKEMPI v2.0](https://life.bsc.es/pid/skempi2) dataset via log odds.
Please see more details in [our ICLR 2024 paper](https://arxiv.org/abs/2310.18515).

**Inputs.** To use PPIformer on your data, please specify the PPI structure (PDB code or .pdb file), interacting proteins of interest
(chain codes in the file) and mutations (semicolon-separated list or file with mutations in the
[standard format](https://foldxsuite.crg.eu/parameter/mutant-file): wild-type residue, chain, residue number, mutant residue).
For inspiration, you can use one of the examples below: click on one of the rows to pre-fill the inputs. After specifying the inputs,
press the button to predict the effects of mutations on the PPI. Currently the model runs on CPU, so the predictions may take a few minutes.

**Outputs.** After making a prediction with the model, you will see binding free energy changes for each mutation (ddG values in kcal/mol).
A more negative value indicates an improvement in affinity, whereas a more positive value means a reduction in affinity.
Below you will also see a 3D visualization of the PPI with wild types of mutated residues highlighted in red. The visualization additionally shows
the attention coefficients of the model for the nearest neighboring residues, which quantify the contribution of the residues
to the predicted ddG value. The brighter and thicker a residue is, the more attention the model paid to it.
""")

    with gr.Row(equal_height=True):
        with gr.Column():
            gr.Markdown("## PPI structure")
            with gr.Row(equal_height=True):
                pdb_code = gr.Textbox(placeholder="1BUI", label="PDB code", info="Protein Data Bank identifier for the structure (https://www.rcsb.org/)")
                partners = gr.Textbox(placeholder="A,B,C", label="Partners", info="Protein chain identifiers in the PDB file forming the PPI interface (two or more)")
            pdb_path = gr.File(file_count="single", label="Or .pdb file instead of PDB code (your structure will only be used for this prediction and not stored anywhere)")

        with gr.Column():
            gr.Markdown("## Mutations")
            muts = gr.Textbox(placeholder="SC16A;FC47A;SC16A,FC47A", label="List of (multi-point) mutations", info="SC16A;FC47A;SC16A,FC47A for three mutations: serine to alanine at position 16 in chain C, phenylalanine to alanine at position 47 in chain C, and their double-point combination")
            muts_path = gr.File(file_count="single", label="Or file with mutations")

    examples = gr.Examples(
        examples=[
            ["1BUI", "A,B,C", "SC16A,FC47A;SC16A;FC47A"],
            ["3QIB", "A,B,P,C,D", "YP7F,TP12S;YP7F;TP12S"],
            ["1KNE", "A,P", ';'.join([f"TP6{a}" for a in AMINO_ACID_CODES_1])]
        ],
        inputs=[pdb_code, partners, muts],
        label="Examples (click on a line to pre-fill the inputs)",
        cache_examples=False
    )

    # Predict GUI
    predict_button = gr.Button(value="Predict effects of mutations on PPI", variant="primary")

    # Output GUI
    gr.Markdown("## Predictions")
    df_file = gr.File(label="Download predictions as .csv", interactive=False, visible=True)
    df = gr.Dataframe(
        headers=["Mutation", "ddG [kcal/mol]"],
        datatype=["str", "number"],
        col_count=(2, "fixed"),
    )
    dropdown = gr.Dropdown(interactive=True, visible=False)
    dropdown_choices_to_plot_args = gr.State([])
    plot = gr.HTML()

    # Bottom info box
    gr.Markdown(value="""
<br/>

## About this web

**Use cases**. The predictor can be used in: (i) Drug Discovery for the development of novel drugs and vaccines for various diseases such as cancer,
neurodegenerative disorders, and infectious diseases, (ii) Biotechnological Applications to develop new biocatalysts for biofuels,
industrial chemicals, and pharmaceuticals (iii) Therapeutic Protein Design to develop therapeutic proteins with enhanced stability,
specificity, and efficacy, and (iv) Mechanistic Studies to gain insights into fundamental biological processes, such as signal transduction,
gene regulation, and immune response.

**Acknowledgement**. Please use the following citation to acknowledge the use of our service. The web server is provided free of charge for non-commercial use.

> Bushuiev, Anton, Roman Bushuiev, Petr Kouba, Anatolii Filkin, Marketa Gabrielova, Michal Gabriel, Jiri Sedlar, Tomas Pluskal, Jiri Damborsky, Stanislav Mazurenko, Josef Sivic.
> "Learning to design protein-protein interactions with enhanced generalization". The Twelfth International Conference on Learning Representations (ICLR 2024).
> [https://arxiv.org/abs/2310.18515](https://arxiv.org/abs/2310.18515).

**Contact**. Please share your feedback or report any bugs through [GitHub Issues](https://github.com/anton-bushuiev/PPIformer/issues/new), or feel free to contact us directly at [anton.bushuiev@cvut.cz](mailto:anton.bushuiev@cvut.cz).
""")
    gr.Image("assets/logos.png")

    # Download weights from Zenodo
    download_from_zenodo('weights.zip')

    # Set device
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"[INFO] Device on start: {device}")

    # Load the three regression checkpoints on CPU first, then move them to the device
    models = [
        DDGPPIformer.load_from_checkpoint(
            PPIFORMER_WEIGHTS_DIR / f'ddg_regression/{i}.ckpt',
            map_location=torch.device('cpu')
        ).eval()
        for i in range(3)
    ]
    models = [model.to(device) for model in models]

    # Create temporary directory for storing downloaded PDBs and extracted PPIs
    temp_dir_obj = tempfile.TemporaryDirectory()
    temp_dir = Path(temp_dir_obj.name)

    # Main logic: bind models and working dir so the click handler only receives UI inputs
    inputs = [pdb_code, pdb_path, partners, muts, muts_path]
    outputs = [df, df_file, dropdown, dropdown_choices_to_plot_args]
    predict_fn = partial(predict, models, temp_dir)
    predict_button.click(predict_fn, inputs=inputs, outputs=outputs)

    # Update plot on dropdown change
    dropdown.change(update_plot, inputs=[dropdown, dropdown_choices_to_plot_args], outputs=[plot])

app.launch(allowed_paths=['./assets'])
|