|
|
|
import gradio as gr |
|
import numpy as np |
|
import os, tempfile |
|
import torch |
|
import py3Dmol |
|
from huggingface_hub import login |
|
|
|
|
|
from esm.utils.structure.protein_chain import ProteinChain |
|
from esm.models.esm3 import ESM3 |
|
from esm.sdk.api import ( |
|
ESMProtein, |
|
GenerationConfig, |
|
) |
|
|
|
from gradio_molecule3d import Molecule3D |
|
|
|
|
|
# Grey monochrome theme applied to the top-level Blocks layout (esm_app).
theme = gr.themes.Monochrome(primary_hue="gray")
|
|
|
|
|
|
|
def get_model(model_name, token):
    """Load an ESM3 checkpoint, on the GPU when CUDA is available, else on the CPU.

    Args:
        model_name: Checkpoint name passed to ``ESM3.from_pretrained``.
        token: Hugging Face access token. When empty or ``None`` the explicit
            ``login`` call is skipped, so the download relies on whatever
            credentials huggingface_hub can already find (e.g. a cached login
            or an ``HF_TOKEN`` environment variable).

    Returns:
        The loaded ``ESM3`` model.
    """
    # login(token=None) falls back to an interactive prompt, which cannot be
    # answered from inside a web request; an empty string is not a token either.
    if token:
        login(token=token)

    # Only the device differs between the GPU and CPU cases.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return ESM3.from_pretrained(model_name, device=device)
|
|
|
|
|
def get_pdb(pdb_id, chain_id):
    """Fetch chain ``chain_id`` of entry ``pdb_id`` via ``ProteinChain.from_rcsb``."""
    return ProteinChain.from_rcsb(pdb_id, chain_id)
|
|
|
|
|
def make_reps(res_start=None, res_end=None, main_color="whiteCarbon", highlight_color="redCarbon", main_style="cartoon", highlight_style="cartoon"):
    """Build the two Molecule3D representation dicts: a base layer and a highlight.

    The first representation styles the whole model with ``main_style`` /
    ``main_color``. The second applies ``highlight_style`` /
    ``highlight_color`` to the residue range ``"{res_start}-{res_end}"``. If
    ``res_start == res_end``, that range is the empty string.

    Returns:
        A list of two dicts with identical keys, base layer first.
    """
    highlight_range = "" if res_start == res_end else f"{res_start}-{res_end}"

    def _representation(style, color, residue_range):
        # Every field except style/color/residue_range is the same for both layers.
        return {
            "model": 0,
            "chain": "",
            "resname": "",
            "style": style,
            "color": color,
            "residue_range": residue_range,
            "around": 0,
            "byres": False,
            "visible": True,
        }

    base_layer = _representation(main_style, main_color, "")
    highlight_layer = _representation(highlight_style, highlight_color, highlight_range)
    return [base_layer, highlight_layer]
|
|
|
|
|
def render_pdb(pdb_id, chain_id, res_start, res_end, pdb_string=None):
    """Write a structure to a temporary ``.pdb`` file and wrap it in a Molecule3D viewer.

    Args:
        pdb_id: PDB entry ID. It is fetched when ``pdb_string`` is ``None`` and
            is always used in the temp-file name prefix.
        chain_id: Chain to fetch; also used in the temp-file name prefix.
        res_start, res_end: Residue range to highlight (see ``make_reps``).
        pdb_string: PDB-format text to render. When ``None``, the chain is
            fetched with ``get_pdb`` and converted with ``to_pdb_string()``.

    Returns:
        A ``Molecule3D`` component pointing at the temporary file.
    """
    if pdb_string is None:
        pdb_string = get_pdb(pdb_id, chain_id).to_pdb_string()

    # delete=False keeps the file on disk after the handle is closed, because
    # Molecule3D only receives its path. Closing it via `with` guarantees the
    # full string is flushed before the path is used. The original code never
    # closed the handle and relied on garbage collection to flush the buffer.
    with tempfile.NamedTemporaryFile(
        delete=False,
        prefix=f"{pdb_id}_chain{chain_id}_",
        suffix=".pdb",
    ) as tmp_pdb:
        tmp_pdb.write(pdb_string.encode("utf-8"))

    return Molecule3D(tmp_pdb.name, reps=make_reps(res_start=res_start, res_end=res_end))
|
|
|
|
|
def scaffold(model_name, token, pdb_id, chain_id, motif_start, motif_end, prompt_length, insert_size):
    """Scaffold a motif taken from an existing chain into a newly generated protein.

    The motif consists of chain positions ``motif_start`` up to, but excluding,
    ``motif_end``. It is written, with both its sequence and its atom37
    coordinates, at offset ``insert_size`` of a ``prompt_length``-long prompt.
    Every other position is masked (``"_"`` / NaN coordinates). ESM3 first
    generates the sequence track. It then predicts a structure from that
    sequence alone, and the prediction is aligned onto the original chain on
    the motif residues.

    Args:
        model_name: Checkpoint name passed to ``get_model``.
        token: Hugging Face token passed to ``get_model``.
        pdb_id, chain_id: Source structure, fetched with ``get_pdb``.
        motif_start, motif_end: Motif slice of the chain (0-based, end exclusive).
        prompt_length: Length of the protein to generate.
        insert_size: 0-based prompt position where the motif is placed.

    Returns:
        ``[original sequence, motif sequence, original-structure viewer,
        sequence prompt, generated sequence, generated-structure viewer]``.

    Raises:
        gr.Error: If the motif range is empty or the motif does not fit in the
            prompt at ``insert_size``.
    """
    # gr.Number may deliver floats; indices, list repetition and tensor shapes need ints.
    motif_start, motif_end = int(motif_start), int(motif_end)
    prompt_length, insert_size = int(prompt_length), int(insert_size)
    if motif_end <= motif_start:
        raise gr.Error("Motif End must be greater than Motif Start.")

    pdb = get_pdb(pdb_id, chain_id)

    motif_inds = np.arange(motif_start, motif_end)
    motif_chain = pdb[motif_inds]
    motif_sequence = motif_chain.sequence
    motif_atom37_positions = motif_chain.atom37_positions

    motif_length = len(motif_sequence)
    motif_end_in_prompt = insert_size + motif_length
    # A list slice assignment that runs past the end silently grows the list.
    # The sequence prompt would then be longer than the structure prompt, so
    # reject such inputs up front.
    if insert_size < 0 or motif_end_in_prompt > prompt_length:
        raise gr.Error(
            f"The motif ({motif_length} residues) inserted at position {insert_size} "
            f"does not fit in a prompt of length {prompt_length}."
        )

    # Fully masked sequence with the motif residues written in place.
    sequence_prompt = ["_"] * prompt_length
    sequence_prompt[insert_size:motif_end_in_prompt] = list(motif_sequence)
    sequence_prompt = "".join(sequence_prompt)

    # NaN coordinates everywhere except the motif (presumably read by ESM3 as "unknown").
    structure_prompt = torch.full((prompt_length, 37, 3), np.nan)
    structure_prompt[insert_size:motif_end_in_prompt] = torch.tensor(motif_atom37_positions, dtype=torch.float32)

    protein_prompt = ESMProtein(sequence=sequence_prompt, coordinates=structure_prompt)
    sequence_generation_config = GenerationConfig(
        track="sequence",
        num_steps=sequence_prompt.count("_") // 2,
        temperature=0.5,
    )

    model = get_model(model_name, token)

    # Step 1: fill in the masked sequence positions.
    sequence_generation = model.generate(protein_prompt, sequence_generation_config)
    generated_sequence = sequence_generation.sequence

    # Step 2: predict a structure from the generated sequence alone.
    structure_prediction_config = GenerationConfig(
        track="structure",
        num_steps=len(sequence_generation) // 8,
        temperature=0.7,
    )
    structure_prediction_prompt = ESMProtein(sequence=generated_sequence)
    structure_prediction = model.generate(structure_prediction_prompt, structure_prediction_config)

    # Align the prediction onto the original chain using the motif residues.
    structure_prediction_chain = structure_prediction.to_protein_chain()
    motif_inds_in_generation = np.arange(insert_size, motif_end_in_prompt)
    structure_prediction_chain.align(pdb, mobile_inds=motif_inds_in_generation, target_inds=motif_inds)

    structure_orig_highlight = render_pdb(pdb_id, chain_id, res_start=motif_start, res_end=motif_end)
    structure_new_highlight = render_pdb(
        pdb_id,
        chain_id,
        res_start=insert_size,
        res_end=motif_end_in_prompt,
        pdb_string=structure_prediction_chain.to_pdb_string(),
    )

    return [
        pdb.sequence,
        motif_sequence,
        structure_orig_highlight,
        sequence_prompt,
        generated_sequence,
        structure_new_highlight,
    ]
|
|
|
|
|
def ss_edit(model_name, token, pdb_id, chain_id, region_start, region_end, shortened_region_length, shortening_ss8):
    """Replace a region of a chain with a shorter helix-loop-helix and regenerate it.

    Chain positions ``region_start`` up to, but excluding, ``region_end`` are
    replaced by ``shortened_region_length`` masked residues. The SS8 prompt for
    those residues has three parts: a helix, a coil of up to 3 residues, and a
    second helix. The rest of the SS8 prompt is copied from ``shortening_ss8``.
    ESM3 then generates the sequence and predicts a structure from it.

    Args:
        model_name: Checkpoint name passed to ``get_model``.
        token: Hugging Face token passed to ``get_model``.
        pdb_id, chain_id: Source structure, fetched with ``get_pdb``.
        region_start, region_end: Region to replace (0-based, end exclusive).
        shortened_region_length: Number of residues in the replacement region.
        shortening_ss8: SS8 string for the original chain; assumed to be aligned
            with the chain sequence (NOTE(review): not checked here).

    Returns:
        ``[original sequence, original SS8, original SS8 of the edit region
        (space-padded), original-structure viewer, sequence prompt, SS8
        prompt, proposed SS8 of the edit region (space-padded), generated
        sequence, generated-structure viewer]``.

    Raises:
        gr.Error: If the edit region is empty or the new length is < 1.
    """
    # gr.Number may deliver floats; slicing and string repetition need ints.
    region_start, region_end = int(region_start), int(region_end)
    shortened_region_length = int(shortened_region_length)
    if region_end <= region_start:
        raise gr.Error("Edit Region End must be greater than Edit Region Start.")
    if shortened_region_length < 1:
        raise gr.Error("Shortened Region Length must be at least 1.")

    pdb = get_pdb(pdb_id, chain_id)
    original_sequence = pdb.sequence
    original_ss8 = shortening_ss8

    sequence_prompt = (
        original_sequence[:region_start]
        + "_" * shortened_region_length
        + original_sequence[region_end:]
    )

    # Helix / coil / helix whose total length is exactly shortened_region_length.
    # Sizing both helices as (length - 3) // 2 dropped a residue whenever
    # length - 3 was odd, which misaligned the SS8 prompt with the sequence prompt.
    loop_length = min(3, shortened_region_length)
    first_helix = (shortened_region_length - loop_length) // 2
    second_helix = shortened_region_length - loop_length - first_helix
    new_region_ss8 = "H" * first_helix + "C" * loop_length + "H" * second_helix
    ss8_prompt = shortening_ss8[:region_start] + new_region_ss8 + shortening_ss8[region_end:]

    # Left-padded with spaces, presumably so each region lines up under its
    # full-length string when displayed.
    original_ss8_region = " " * region_start + shortening_ss8[region_start:region_end]
    proposed_ss8_region = " " * region_start + ss8_prompt[region_start:region_start + shortened_region_length]

    protein_prompt = ESMProtein(sequence=sequence_prompt, secondary_structure=ss8_prompt)

    model = get_model(model_name, token)
    sequence_generation = model.generate(
        protein_prompt,
        GenerationConfig(track="sequence", num_steps=protein_prompt.sequence.count("_") // 2, temperature=0.5),
    )
    structure_prediction = model.generate(
        ESMProtein(sequence=sequence_generation.sequence),
        GenerationConfig(track="structure", num_steps=len(protein_prompt) // 4, temperature=0),
    )
    structure_prediction_chain = structure_prediction.to_protein_chain()

    structure_orig_highlight = render_pdb(pdb_id, chain_id, res_start=region_start, res_end=region_end)
    # In the generated protein, the edited region spans only
    # shortened_region_length residues starting at region_start.
    structure_new_highlight = render_pdb(
        pdb_id,
        chain_id,
        res_start=region_start,
        res_end=region_start + shortened_region_length,
        pdb_string=structure_prediction_chain.to_pdb_string(),
    )

    return [
        original_sequence,
        original_ss8,
        original_ss8_region,
        structure_orig_highlight,
        sequence_prompt,
        ss8_prompt,
        proposed_ss8_region,
        # The "Generated Sequence" textbox needs the string, not the ESMProtein object.
        sequence_generation.sequence,
        structure_new_highlight,
    ]
|
|
|
|
|
def sasa_edit(model_name, token, pdb_id, chain_id, span_start, span_end, n_samples):
    """Redesign a chain around a fixed span conditioned on a SASA value of 40.0.

    The prompt has the same length as the chain and a fully masked sequence.
    Only the span ``span_start`` up to, but excluding, ``span_end`` gets
    coordinates (from the original chain) and a SASA value of 40.0. All other
    positions get NaN coordinates and ``None`` SASA. For each of ``n_samples``
    samples, ESM3 generates a sequence and then predicts a structure from it.
    The sample with the highest ``ptm`` is shown.

    Args:
        model_name: Checkpoint name passed to ``get_model``.
        token: Hugging Face token passed to ``get_model``.
        pdb_id, chain_id: Source structure, fetched with ``get_pdb``.
        span_start, span_end: Span to keep (0-based, end exclusive).
        n_samples: Number of generate-then-fold samples to draw (>= 1).

    Returns:
        ``[masked sequence prompt, original-structure viewer,
        best-generated-structure viewer]``.

    Raises:
        gr.Error: If ``n_samples < 1`` or the span lies outside the chain.
    """
    # gr.Number may deliver floats; slicing and range() need ints.
    span_start, span_end, n_samples = int(span_start), int(span_end), int(n_samples)
    if n_samples < 1:
        raise gr.Error("Number of Samples must be at least 1.")

    pdb = get_pdb(pdb_id, chain_id)
    chain_length = len(pdb)
    # A span running past the chain end would grow the SASA list (slice
    # assignment extends lists) while the coordinate tensor stays clipped.
    if not 0 <= span_start <= span_end <= chain_length:
        raise gr.Error(f"Span must satisfy 0 <= start <= end <= {chain_length} (chain length).")

    structure_prompt = torch.full((chain_length, 37, 3), torch.nan)
    structure_prompt[span_start:span_end] = torch.tensor(pdb[span_start:span_end].atom37_positions, dtype=torch.float32)

    sasa_prompt = [None] * chain_length
    sasa_prompt[span_start:span_end] = [40.0] * (span_end - span_start)

    protein_prompt = ESMProtein(sequence="_" * chain_length, coordinates=structure_prompt, sasa=sasa_prompt)

    model = get_model(model_name, token)

    generated_proteins = []
    for _ in range(n_samples):
        sequence_generation = model.generate(
            protein_prompt,
            GenerationConfig(track="sequence", num_steps=len(protein_prompt) // 8, temperature=0.7),
        )
        structure_prediction = model.generate(
            ESMProtein(sequence=sequence_generation.sequence),
            GenerationConfig(track="structure", num_steps=len(protein_prompt) // 32),
        )
        generated_proteins.append(structure_prediction)

    # Highest ptm wins; ties go to the earliest sample (same as a stable
    # descending sort followed by [0]).
    best_protein = max(generated_proteins, key=lambda protein: protein.ptm.item())

    structure_orig_highlight = render_pdb(pdb_id, chain_id, res_start=span_start, res_end=span_end)
    structure_new_highlight = render_pdb(
        pdb_id,
        chain_id,
        res_start=span_start,
        res_end=span_end,
        pdb_string=best_protein.to_protein_chain().to_pdb_string(),
    )

    return [
        protein_prompt.sequence,
        structure_orig_highlight,
        structure_new_highlight,
    ]
|
|
|
|
|
|
|
# Tab: scaffold a motif from an existing structure into a new protein (calls `scaffold`).
scaffold_app = gr.Interface(
    fn=scaffold,
    inputs=[
        gr.Dropdown(label="Model Name", choices=["esm3_sm_open_v1"], value="esm3_sm_open_v1", allow_custom_value=True),
        # No default token: a component's default value is sent to every
        # visitor's browser (type="password" only masks it on screen), so a
        # credential must never be placed here.
        gr.Textbox(label="Hugging Face Token", type="password"),
        gr.Textbox(value="1ITU", label="PDB Code"),
        gr.Textbox(value="A", label="Chain"),
        # precision=0 makes Gradio deliver ints, which the indexing in `scaffold` needs.
        gr.Number(value=123, label="Motif Start", precision=0),
        gr.Number(value=146, label="Motif End", precision=0),
        gr.Number(value=200, label="Prompt Length", precision=0),
        gr.Number(value=72, label="Insert Size", precision=0),
    ],
    outputs=[
        gr.Textbox(label="Sequence"),
        gr.Textbox(label="Motif Sequence"),
        Molecule3D(label="Original Structure"),
        gr.Textbox(label="Sequence Prompt"),
        gr.Textbox(label="Generated Sequence"),
        Molecule3D(label="Generated Structure"),
    ],
)
|
|
|
|
|
# Tab: shorten a region into a helix-loop-helix via SS8 editing (calls `ss_edit`).
ss_app = gr.Interface(
    fn=ss_edit,
    inputs=[
        gr.Dropdown(label="Model Name", choices=["esm3_sm_open_v1"], value="esm3_sm_open_v1", allow_custom_value=True),
        # No default token: a component's default value is sent to every
        # visitor's browser (type="password" only masks it on screen), so a
        # credential must never be placed here.
        gr.Textbox(label="Hugging Face Token", type="password"),
        gr.Textbox(value="7XBQ", label="PDB ID"),
        gr.Textbox(value="A", label="Chain ID"),
        # precision=0 makes Gradio deliver ints, which the slicing in `ss_edit` needs.
        gr.Number(value=38, label="Edit Region Start", precision=0),
        gr.Number(value=111, label="Edit Region End", precision=0),
        gr.Number(value=45, label="Shortened Region Length", precision=0),
        gr.Textbox(value="CCCSHHHHHHHHHHHTTCHHHHHHHHHHHHHTCSSCCCCHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHTTCHHHHHHHHHHHHHHHHHHHHHHHHHHHHIIIIIGGGCCSHHHHHHHHHHHHHHHHHHHHHCCHHHHHHHHHHHHHHHHHHHHHHHHHSCTTCHHHHHHHHHHHHHIIIIICCHHHHHHHHHHHHHHHHTTCTTCCSSHHHHHHHHHHHHHHHHHHHC", label="SS8 Shortening"),
    ],
    outputs=[
        gr.Textbox(label="Original Sequence"),
        gr.Textbox(label="Original SS8"),
        gr.Textbox(label="Original SS8 Edit Region"),
        Molecule3D(label="Original Structure"),
        gr.Textbox(label="Sequence Prompt"),
        gr.Textbox(label="Edited SS8 Prompt"),
        gr.Textbox(label="Proposed SS8 of Edit Region"),
        gr.Textbox(label="Generated Sequence"),
        Molecule3D(label="Generated Structure"),
    ],
)
|
|
|
|
|
# Tab: redesign a protein around a span conditioned on SASA (calls `sasa_edit`).
sasa_app = gr.Interface(
    fn=sasa_edit,
    inputs=[
        gr.Dropdown(label="Model Name", choices=["esm3_sm_open_v1"], value="esm3_sm_open_v1", allow_custom_value=True),
        # No default token: a component's default value is sent to every
        # visitor's browser (type="password" only masks it on screen), so a
        # credential must never be placed here.
        gr.Textbox(label="Hugging Face Token", type="password"),
        gr.Textbox(value="1LBS", label="PDB ID"),
        gr.Textbox(value="A", label="Chain ID"),
        # precision=0 makes Gradio deliver ints, which slicing/range() in `sasa_edit` need.
        gr.Number(value=105, label="Span Start", precision=0),
        gr.Number(value=116, label="Span End", precision=0),
        gr.Number(value=1, label="Number of Samples", precision=0),
    ],
    outputs=[
        gr.Textbox(label="Protein Prompt"),
        Molecule3D(label="Original Structure"),
        Molecule3D(label="Best Generated Structure"),
    ],
)
|
|
|
|
|
# Tab: plain RCSB structure viewer with an optional highlighted residue range.
protein_viewer = gr.Interface(
    render_pdb,
    [
        gr.Textbox(label="PDB ID", value="1LBS"),
        gr.Textbox(label="Chain ID", value="A"),
        gr.Number(label="Residue Highlight Start", value=10),
        gr.Number(label="Residue Highlight End", value=20),
    ],
    [Molecule3D(label="3D Structure")],
)
|
|
|
|
|
|
|
# Top-level layout: a header row of Markdown, then one tab per demo Interface.
with gr.Blocks(theme=theme) as esm_app:
    with gr.Row():
        gr.Markdown(
            value="""
# ESM3: A frontier language model for biology.
Model Created By: [EvolutionaryScale](https://www.evolutionaryscale.ai)
- Press Release: https://www.evolutionaryscale.ai/blog/esm3-release
- GitHub: https://github.com/evolutionaryscale/esm
- HuggingFace Model: https://huggingface.co/EvolutionaryScale/esm3-sm-open-v1

Spaces App By: [[Colby T. Ford](https://colbyford.com)] from [Tuple, The Cloud Genomics Company](https://tuple.xyz)

NOTE: You will need to agree to EvolutionaryScale's [license agreement](https://huggingface.co/EvolutionaryScale/esm3-sm-open-v1) to use the model. Then, create and paste your HuggingFace token in the appropriate field.
"""
        )

    with gr.Row():
        # Tab names are matched to interfaces by position.
        gr.TabbedInterface(
            interface_list=[scaffold_app, ss_app, sasa_app, protein_viewer],
            tab_names=[
                "Scaffolding Example",
                "Secondary Structure Editing Example",
                "SASA Editing Example",
                "PDB Viewer",
            ],
        )

if __name__ == "__main__":
    esm_app.launch()
|
|