esmfold / app.py
Witold Wydmański
feat: Add get_esm2_embeddings function
8ee7dbf
raw
history blame
7.85 kB
import gradio as gr
from transformers import AutoTokenizer, EsmForProteinFolding
from transformers.models.esm.openfold_utils.protein import to_pdb, Protein as OFProtein
from transformers.models.esm.openfold_utils.feats import atom14_to_atom37
import torch
from logging import getLogger
logger = getLogger(__name__)
def convert_outputs_to_pdb(outputs):
    """Convert folding-model outputs to PDB text for the first batch entry.

    Args:
        outputs: Mapping of tensors produced by the folding model. The keys
            read here are ``positions``, ``aatype``, ``atom37_atom_exists``,
            ``residue_index``, ``plddt`` and, optionally, ``chain_index``.

    Returns:
        The PDB string of batch element 0. A PDB string is built for every
        batch element, but only the first one is returned.
    """
    # The atom14 -> atom37 conversion needs the original tensors, so it is
    # evaluated before the dictionary is converted to numpy arrays.
    atom37_positions = atom14_to_atom37(outputs["positions"][-1], outputs)
    arrays = {key: tensor.to("cpu").numpy() for key, tensor in outputs.items()}
    atom37_positions = atom37_positions.cpu().numpy()
    atom37_mask = arrays["atom37_atom_exists"]

    batch_size = arrays["aatype"].shape[0]
    pdb_strings = [
        to_pdb(
            OFProtein(
                aatype=arrays["aatype"][idx],
                atom_positions=atom37_positions[idx],
                atom_mask=atom37_mask[idx],
                # Shifted by one; presumably 0-based -> 1-based residue
                # numbering for the PDB file (confirm against the model docs).
                residue_index=arrays["residue_index"][idx] + 1,
                # The per-entry pLDDT values are written as B-factors.
                b_factors=arrays["plddt"][idx],
                chain_index=arrays["chain_index"][idx] if "chain_index" in arrays else None,
            )
        )
        for idx in range(batch_size)
    ]
    return pdb_strings[0]
def fold_prot_locally(sequence):
    """Fold one sequence with the module-level model and return PDB text.

    Args:
        sequence: Amino-acid sequence as a string.

    Returns:
        The PDB string produced by ``convert_outputs_to_pdb``.
    """
    logger.info("Folding: " + sequence)
    encoded = tokenizer([sequence], return_tensors="pt", add_special_tokens=False)
    input_ids = encoded["input_ids"].cuda()
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        prediction = model(input_ids)
    return convert_outputs_to_pdb(prediction)
def get_esm2_embeddings(sequence):
    """Return the language-model representations for one sequence.

    Only the language-model part of the module-level ``model`` is run: the
    token ids are remapped with ``model.af2_idx_to_esm_idx`` and passed to
    ``model.compute_language_model_representations``. No structure is
    predicted.

    Args:
        sequence: Amino-acid sequence as a string.

    Returns:
        ``{"res": nested_lists}`` where ``nested_lists`` is the representation
        tensor moved to the CPU and converted with ``tolist()``.
    """
    logger.info("Getting embeddings for: " + sequence)
    token_ids = tokenizer([sequence], return_tensors="pt", add_special_tokens=False)["input_ids"].cuda()
    with torch.no_grad():
        # One sequence, no padding: every position is marked as valid.
        # ``ones_like`` already allocates on the device of ``token_ids``.
        attention_mask = torch.ones_like(token_ids)
        esm_ids = model.af2_idx_to_esm_idx(token_ids, attention_mask)
        esm_s = model.compute_language_model_representations(esm_ids)
    return {"res": esm_s.cpu().tolist()}
def get_esmfold_embeddings(sequence):
    """Run the full folding model and return its ``s_s`` output.

    Args:
        sequence: Amino-acid sequence as a string.

    Returns:
        ``{"res": nested_lists}`` where ``nested_lists`` is the ``s_s`` entry
        of the model output (presumably the per-residue sequence state;
        confirm against the model documentation) converted with ``tolist()``.
    """
    logger.info("Getting embeddings for: " + sequence)
    encoded = tokenizer([sequence], return_tensors="pt", add_special_tokens=False)
    input_ids = encoded["input_ids"].cuda()
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        prediction = model(input_ids)
    sequence_state = prediction["s_s"]
    return {"res": sequence_state.cpu().tolist()}
def suggest(option):
    """Return the sample sequence that belongs to a dropdown label.

    Args:
        option: Label chosen in the sample-protein dropdown.

    Returns:
        The matching amino-acid sequence, or an empty string for any label
        without a sample (for example ``"custom"``).
    """
    samples = {
        "Plastic degradation protein": "MGSSHHHHHHSSGLVPRGSHMRGPNPTAASLEASAGPFTVRSFTVSRPSGYGAGTVYYPTNAGGTVGAIAIVPGYTARQSSIKWWGPRLASHGFVVITIDTNSTLDQPSSRSSQQMAALRQVASLNGTSSSPIYGKVDTARMGVMGWSMGGGGSLISAANNPSLKAAAPQAPWDSSTNFSSVTVPTLIFACENDSIAPVNSSALPIYDSMSRNAKQFLEINGGSHSCANSGNSNQALIGKKGVAWMKRFMDNDTRYSTFACENPNSTRVSDFRTANCSLEDPAANKARKEAELAAATAEQ",
        "Antifreeze protein": "QCTGGADCTSCTGACTGCGNCPNAVTCTNSQHCVKANTCTGSTDCNTAQTCTNSKDCFEANTCTDSTNCYKATACTNSSGCPGH",
        "AI Generated protein": "MSGMKKLYEYTVTTLDEFLEKLKEFILNTSKDKIYKLTITNPKLIKDIGKAIAKAAEIADVDPKEIEEMIKAVEENELTKLVITIEQTDDKYVIKVELENEDGLVHSFEIYFKNKEEMEKFLELLEKLISKLSGS",
        "7-bladed propeller fold": "VKLAGNSSLCPINGWAVYSKDNSIRIGSKGDVFVIREPFISCSHLECRTFFLTQGALLNDKHSNGTVKDRSPHRTLMSCPVGEAPSPYNSRFESVAWSASACHDGTSWLTIGISGPDNGAVAVLKYNGIITDTIKSWRNNILRTQESECACVNGSCFTVMTDGPSNGQASYKIFKMEKGKVVKSVELDAPNYHYEECSCYPNAGEITCVCRDNWHGSNRPWVSFNQNLEYQIGYICSGVFGDNPRPNDGTGSCGPVSSNGAYGVKGFSFKYGNGVWIGRTKSTNSRSGFEMIWDPNGWTETDSSFSVKQDIVAITDWSGYSGSFVQHPELTGLDCIRPCFWVELIRGRPKESTIWTSGSSISFCGVNSDTVGWSWPDGAELPFTIDK",
    }
    # Anything that is not a string (None, a list, ...) can never equal one
    # of the labels; answer "" without attempting a dictionary lookup.
    if not isinstance(option, str):
        return ""
    return samples.get(option, "")
def molecule(mol):
    """Build an ``<iframe>`` that renders PDB text with 3Dmol.js.

    The PDB text is embedded in a self-contained HTML page (jQuery and 3Dmol
    are loaded from their CDNs), and that page is placed in the iframe's
    ``srcdoc`` attribute.

    Args:
        mol: PDB file contents as text. ``None`` is treated as empty text.

    Returns:
        HTML markup of the iframe as a string.
    """
    # Imported here so the function brings its own standard-library helpers.
    import html
    import json

    pdb_text = "" if mol is None else str(mol)
    # Emit the PDB text as a JavaScript string literal rather than pasting it
    # between backticks: a backtick or "${" in the input can no longer end or
    # interpolate the literal. "<" is escaped as well so the input cannot
    # close the surrounding <script> element.
    pdb_literal = json.dumps(pdb_text).replace("<", "\\u003c")

    page = (
        """<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html; charset=UTF-8" />
<style>
body{
    font-family:sans-serif
}
.mol-container {
    width: 100%;
    height: 600px;
    position: relative;
}
.mol-container select{
    background-image:None;
}
</style>
<script src="https://cdnjs.cloudflare.com/ajax/libs/jquery/3.6.3/jquery.min.js" integrity="sha512-STof4xm1wgkfm7heWqFJVn58Hm3EtS31XFaagaa8VMReCXAkQnJZ+jEy8PCC/iT18dFy95WcExNHFTqLyp72eQ==" crossorigin="anonymous" referrerpolicy="no-referrer"></script>
<script src="https://3Dmol.csb.pitt.edu/build/3Dmol-min.js"></script>
</head>
<body>
<div id="container" class="mol-container"></div>
<script>
    let pdb = """
        + pdb_literal
        + """;
    $(document).ready(function () {
        let element = $("#container");
        let config = { backgroundColor: "white" };
        let viewer = $3Dmol.createViewer(element, config);
        viewer.addModel(pdb, "pdb");
        viewer.getModel(0).setStyle({}, { cartoon: { colorscheme:"whiteCarbon" } });
        viewer.zoomTo();
        viewer.render();
        viewer.zoom(0.8, 2000);
    })
</script>
</body></html>"""
    )

    # The page is an attribute value, so it must be HTML-escaped; otherwise a
    # quote character anywhere in the page would terminate ``srcdoc`` early.
    srcdoc = html.escape(page, quote=True)
    return f"""<iframe style="width: 100%; height: 600px" name="result" allow="midi; geolocation; microphone; camera;
display-capture; encrypted-media;" sandbox="allow-modals allow-forms
allow-scripts allow-same-origin allow-popups
allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
allowpaymentrequest="" frameborder="0" srcdoc="{srcdoc}"></iframe>"""
# Client-side usage example shown in the "Sample code" panel of the UI.
# It is only displayed, never executed by this app. ``tempfile`` is imported
# in the snippet because ``fold_huggingface`` uses it when no file name is given.
sample_code = """
import tempfile

from gradio_client import Client

client = Client("https://wwydmanski-esmfold.hf.space/")

def fold_huggingface(sequence, fname=None):
    result = client.predict(
        sequence, # str in 'sequence' Textbox component
        api_name="/pdb")
    if fname is None:
        with tempfile.NamedTemporaryFile("w", delete=False, suffix=".pdb", prefix="esmfold_") as fp:
            fp.write(result)
            fp.flush()
            return fp.name
    else:
        with open(fname, "w") as fp:
            fp.write(result)
            fp.flush()
            return fname

pdb_fname = fold_huggingface("MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGGGPGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN")
"""
# Load the tokenizer and the folding model once at import time; the request
# handlers in this file use these module-level objects.
_CHECKPOINT = "facebook/esmfold_v1"
tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT)
model = EsmForProteinFolding.from_pretrained(_CHECKPOINT, low_cpu_mem_usage=True)
model = model.cuda()
# Cast only the ``esm`` submodule to half precision; the rest of the model
# keeps the dtype it was loaded with.
model.esm = model.esm.half()
# Allow TF32 for CUDA matrix multiplications.
torch.backends.cuda.matmul.allow_tf32 = True
# Sample selected in the dropdown when the page first loads.
_DEFAULT_SAMPLE = "Plastic degradation protein"

# UI: sequence input and sample picker, the client sample code, the PDB text
# next to its 3D view, and a hidden row that only serves the embeddings API.
with gr.Blocks() as demo:
    gr.Markdown("# ESMFold")
    with gr.Row():
        with gr.Column():
            # Pre-fill the box with the sequence of the default sample. The
            # ``name.change`` listener below only runs when the selection
            # changes, so without this the box would start empty while the
            # dropdown already shows a sample.
            inp = gr.Textbox(lines=1, label="Sequence", value=suggest(_DEFAULT_SAMPLE))
            name = gr.Dropdown(
                label="Choose a Sample Protein",
                value=_DEFAULT_SAMPLE,
                choices=[
                    "Antifreeze protein",
                    "Plastic degradation protein",
                    "AI Generated protein",
                    "7-bladed propeller fold",
                    "custom",
                ],
            )
            btn = gr.Button("🔬 Predict Structure ")

    with gr.Row():
        with gr.Column():
            gr.Markdown("## Sample code")
            gr.Code(sample_code, label="Sample usage", language="python", interactive=False)

    with gr.Row():
        gr.Markdown("## Output")

    with gr.Row():
        with gr.Column():
            out = gr.Code(label="Output", interactive=False)
        with gr.Column():
            out_mol = gr.HTML(label="3D Structure")

    # Not shown in the page; ``embs`` is the output of both embeddings listeners.
    with gr.Row(visible=False):
        with gr.Column():
            gr.Markdown("## Embeddings")
            embs = gr.JSON(label="Embeddings")

    # Picking a sample replaces the content of the sequence box.
    name.change(fn=suggest, inputs=name, outputs=inp)
    # One button click runs all three listeners; each is also exposed under
    # its own ``api_name``.
    btn.click(fold_prot_locally, inputs=[inp], outputs=[out], api_name="pdb")
    btn.click(get_esmfold_embeddings, inputs=[inp], outputs=[embs], api_name="embeddings")
    btn.click(get_esm2_embeddings, inputs=[inp], outputs=[embs], api_name="esm2_embeddings")
    # Re-render the 3D view whenever the PDB text changes.
    out.change(fn=molecule, inputs=[out], outputs=[out_mol], api_name="3d_fold")

demo.launch()