import os
import sys
import gc

import gradio as gr
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use("Agg")  # headless backend for server-side rendering
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import torch
import jax
from numba import cuda
from GPUtil import showUtilization as gpu_usage
from transformers import pipeline as pl
print("GPU available:", torch.cuda.is_available())
print("CUDA device name:", torch.cuda.get_device_name(0))
print(os.getcwd())

# Make the vendored AlphaFold package importable.
if "/home/user/app/alphafold" not in sys.path:
    sys.path.append("/home/user/app/alphafold")
from alphafold.common import protein
from alphafold.data import pipeline
from alphafold.data import templates
from alphafold.model import data
from alphafold.model import config
from alphafold.model import model
def mk_mock_template(query_sequence):
    """create blank template"""
    ln = len(query_sequence)
    output_templates_sequence = "-" * ln
    templates_all_atom_positions = np.zeros(
        (ln, templates.residue_constants.atom_type_num, 3)
    )
    templates_all_atom_masks = np.zeros((ln, templates.residue_constants.atom_type_num))
    templates_aatype = templates.residue_constants.sequence_to_onehot(
        output_templates_sequence, templates.residue_constants.HHBLITS_AA_TO_ID
    )
    template_features = {
        "template_all_atom_positions": templates_all_atom_positions[None],
        "template_all_atom_masks": templates_all_atom_masks[None],
        "template_aatype": np.array(templates_aatype)[None],
        "template_domain_names": ["none".encode()],
    }
    return template_features
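# Usage sketch (hypothetical 10-residue sequence):
#   feats = mk_mock_template("MKTAYIAKQR")
#   feats["template_all_atom_positions"].shape  # (1, 10, atom_type_num, 3), all zeros
# Because the coordinates and masks are zeroed, AlphaFold effectively ignores the
# template and folds from the sequence features alone.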
def predict_structure(prefix, feature_dict, model_runners, random_seed=0):
    """Predicts structure using AlphaFold for the given sequence."""
    # Run the models (currently only model_1 is used).
    plddts = {}
    for model_name, model_runner in model_runners.items():
        processed_feature_dict = model_runner.process_features(
            feature_dict, random_seed=random_seed
        )
        prediction_result = model_runner.predict(processed_feature_dict)
        # Store per-residue pLDDT in the B-factor column of the output PDB.
        b_factors = (
            prediction_result["plddt"][:, None]
            * prediction_result["structure_module"]["final_atom_mask"]
        )
        unrelaxed_protein = protein.from_prediction(
            processed_feature_dict, prediction_result, b_factors
        )
        unrelaxed_pdb_path = f"{prefix}_unrelaxed_{model_name}.pdb"
        plddts[model_name] = prediction_result["plddt"]
        print(f"{model_name} mean pLDDT: {plddts[model_name].mean():.2f}")
        with open(unrelaxed_pdb_path, "w") as f:
            f.write(protein.to_pdb(unrelaxed_protein))
    return plddts
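# Usage sketch: called from run_alphafold() below with a single-model runner dict.
#   plddts = predict_structure("test", feature_dict, model_runners)
#   plddts["model_1"]   # per-residue pLDDT array, one value per residue
# Side effect: writes "{prefix}_unrelaxed_{model_name}.pdb" with pLDDT in the B-factors.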
def run_protgpt2(startsequence, length, repetitionPenalty, top_k_poolsize, max_seqs):
    protgpt2 = pl("text-generation", model="nferruz/ProtGPT2")
    sequences = protgpt2(
        startsequence,
        max_length=length,
        do_sample=True,
        top_k=top_k_poolsize,
        repetition_penalty=repetitionPenalty,
        num_return_sequences=max_seqs,
        eos_token_id=0,
    )
    print("Cleaning up after ProtGPT2")
    print(gpu_usage())
    del protgpt2
    # torch.cuda.empty_cache()
    # Hard-reset the GPU so JAX/AlphaFold can claim the freed memory afterwards.
    device = cuda.get_current_device()
    device.reset()
    print(gpu_usage())
    return sequences
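# Usage sketch (illustrative, not tuned, parameter values):
#   seqs = run_protgpt2("M", length=100, repetitionPenalty=1.2, top_k_poolsize=950, max_seqs=5)
#   seqs[0]["generated_text"]   # raw generated text; may contain newlines
# Note: cuda.get_current_device().reset() tears down the CUDA context, which frees the
# whole GPU but also invalidates any other live CUDA allocations in this process.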
def run_alphafold(startsequence):
    print(gpu_usage())
    # Free any GPU memory left over from ProtGPT2 before loading AlphaFold.
    device = cuda.get_current_device()
    device.reset()
    model_runners = {}
    models = ["model_1"]  # , "model_2", "model_3", "model_4", "model_5"]
    for model_name in models:
        model_config = config.model_config(model_name)
        model_config.data.eval.num_ensemble = 1
        model_params = data.get_model_haiku_params(model_name=model_name, data_dir=".")
        model_runner = model.RunModel(model_config, model_params)
        model_runners[model_name] = model_runner
    query_sequence = startsequence.replace("\n", "")
    # Build features from the single query sequence: no MSA search, mock template only.
    feature_dict = {
        **pipeline.make_sequence_features(
            sequence=query_sequence, description="none", num_res=len(query_sequence)
        ),
        **pipeline.make_msa_features(
            msas=[[query_sequence]], deletion_matrices=[[[0] * len(query_sequence)]]
        ),
        **mk_mock_template(query_sequence),
    }
    plddts = predict_structure("test", feature_dict, model_runners)
    print("Cleaning up after AF2")
    print(gpu_usage())
    # backend = jax.lib.xla_bridge.get_backend()
    # for buf in backend.live_buffers(): buf.delete()
    # device = cuda.get_current_device()
    # device.reset()
    # print(gpu_usage())
    return plddts["model_1"]
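# Usage sketch: single-sequence mode (no MSA, mock template), so accuracy is lower than
# full AlphaFold and pLDDT is the main confidence signal.
#   plddt = run_alphafold("MKTAYIAKQR")   # hypothetical sequence
#   plddt.mean()                          # overall confidence, roughly on a 0-100 scale
# Side effect: writes "test_unrelaxed_model_1.pdb" for the structure viewer.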
def update_protGPT2(inp, length, repetitionPenalty, top_k_poolsize, max_seqs):
    startsequence = inp
    seqlen = length
    generated_seqs = run_protgpt2(startsequence, seqlen, repetitionPenalty, top_k_poolsize, max_seqs)
    gen_seqs = [x["generated_text"] for x in generated_seqs]
    print(gen_seqs)
    # Format the generated sequences as FASTA, wrapping lines at 70 characters.
    sequencestxt = ""
    for i, seq in enumerate(gen_seqs):
        s = seq.replace("\n", "")
        s = "\n".join([s[j:j + 70] for j in range(0, len(s), 70)])
        sequencestxt += f">seq{i}\n{s}\n"
    return sequencestxt
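# Gradio callback sketch: the returned string is FASTA-formatted (">seq0", ">seq1", ...)
# with sequences wrapped at 70 characters, ready to display in a text output component
# (the actual wiring presumably lives in the UI definition further down this file).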
def update(inp):
    print("Running AF on", inp)
    startsequence = inp
    plddts = run_alphafold(startsequence)
    print(plddts)
    # plt.style.use(["seaborn-ticks", "seaborn-talk"])
    # fig = plt.figure()
    # ax = fig.add_subplot(111)
    # ax.plot(plddts)
    # ax.set_ylabel("predicted LDDT")
    # ax.set_xlabel("positions")
    # ax.set_title("pLDDT")
    fig = go.Figure(
        data=go.Scatter(
            x=np.arange(len(plddts)),
            y=plddts,
            hovertemplate="pLDDT: %{y:.2f}<br>Residue index: %{x}",
        )
    )
    fig.update_layout(
        title="pLDDT",
        xaxis_title="Residue index",
        yaxis_title="pLDDT",
        height=500,
        template="simple_white",
    )
    return (
        molecule("test_unrelaxed_model_1.pdb"),
        fig,
        f"{np.mean(plddts):.1f} ± {np.std(plddts):.1f}",
    )
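# Gradio callback sketch: returns (viewer markup from molecule(), a Plotly figure, and a
# "mean ± std" pLDDT string); these presumably map onto three output components in the
# interface defined later in the file.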
def read_mol(molpath):
    """Read a PDB file into a single string."""
    with open(molpath, "r") as fp:
        return fp.read()
def molecule(pdb):
    mol = read_mol(pdb)
    x = (
        """