# Hugging Face Spaces page header (scraped): Spaces: Sleeping
import html
import os
import tempfile

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import py3Dmol
import torch
from Bio.PDB import PDBParser, PDBIO
from sklearn.decomposition import PCA
from transformers import AutoTokenizer, EsmModel
# Load the ESM-1b model and tokenizer once at import time; each request reuses them.
MODEL_NAME = "facebook/esm1b_t33_650M_UR50S"
# Only `last_hidden_state` is read downstream, so the per-layer hidden states
# (output_hidden_states=True) are not requested on every forward pass.
model = EsmModel.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Compute PCA and return scaled values for selected components
def compute_scaled_pca_scores(seq, components):
    """Project per-residue ESM-1b embeddings onto PCA components scaled to 0-100.

    Args:
        seq: Protein sequence in one-letter code.
        components: Iterable of non-negative PCA component indices.

    Returns:
        A list with one 1-D numpy array per entry of ``components`` (same
        order). Each array holds one value per residue token, min-max scaled
        to [0, 100]. A component that is constant across residues maps to
        all zeros instead of NaN.

    Raises:
        ValueError: if ``components`` is empty or contains a negative index,
            or if the sequence has too few residues to fit the requested
            number of PCA components.
    """
    components = list(components)
    if not components:
        raise ValueError("At least one PCA component index is required.")
    if min(components) < 0:
        raise ValueError("PCA component indices must be non-negative.")

    inputs = tokenizer(seq, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    # Drop the leading CLS and trailing EOS tokens by position. Slicing by
    # len(seq) would include EOS whenever the tokenizer drops characters of the
    # raw input (e.g. whitespace), shifting every residue score.
    embedding = outputs.last_hidden_state[0, 1:-1]

    n_residues, n_features = embedding.shape
    n_components = max(components) + 1
    max_allowed = min(n_residues, n_features)
    if n_components > max_allowed:
        raise ValueError(
            f"PCA component {n_components - 1} requested, but only "
            f"{max_allowed} components can be computed for a sequence of "
            f"{n_residues} residues."
        )

    pca = PCA(n_components=n_components)
    pca_result = pca.fit_transform(embedding.cpu().numpy())

    scaled_components = []
    for c in components:
        values = pca_result[:, c]
        low = values.min()
        span = values.max() - low
        if span > 0:
            scaled_components.append((values - low) / span * 100)
        else:
            # Constant component (e.g. a single residue): avoid 0/0 -> NaN.
            scaled_components.append(np.zeros_like(values))
    return scaled_components
# Inject scores into B-factor column and save each PDB separately
def inject_bfactors_and_save(pdb_file, scores_list, component_indices):
    """Write one copy of the structure per component with scores as B-factors.

    Args:
        pdb_file: The uploaded PDB, either a path string or a file-like object
            exposing ``.name`` (Gradio may pass either, depending on version
            and component ``type``).
        scores_list: Sequence of per-residue score arrays, one per component.
        component_indices: Component index paired with each entry of
            ``scores_list``; used only to name the output files
            (``*_PC{idx}.pdb``).

    Returns:
        List of paths of the written PDB files, in the order of ``scores_list``.

    Scores are assigned to residues in file order (chain by chain), restarting
    from the first score in every model of the file. Water residues do not
    consume a score. Waters and any residues left once the scores run out get a
    B-factor of 0, so the saved file contains only PCA values rather than a mix
    of PCA values and the original B-factors.
    """
    pdb_path = getattr(pdb_file, "name", pdb_file)
    parser = PDBParser(QUIET=True)
    structure = parser.get_structure("prot", pdb_path)

    output_paths = []
    for scores, idx in zip(scores_list, component_indices):
        for pdb_model in structure:
            residue_index = 0
            for chain in pdb_model:
                for residue in chain:
                    is_water = residue.id[0] == "W"
                    if is_water or residue_index >= len(scores):
                        value = 0.0
                    else:
                        value = float(scores[residue_index])
                    # get_unpacked_list() yields every altloc of disordered
                    # atoms; setting an attribute on the disordered wrapper
                    # would not reach the children that PDBIO writes.
                    for atom in residue.get_unpacked_list():
                        atom.set_bfactor(value)
                    if not is_water:
                        residue_index += 1
        # Only the path is needed; close the handle right away (delete=False
        # keeps the file on disk for PDBIO and for the download).
        with tempfile.NamedTemporaryFile(delete=False, suffix=f"_PC{idx}.pdb") as tmp:
            out_path = tmp.name
        io = PDBIO()
        io.set_structure(structure)
        io.save(out_path)
        output_paths.append(out_path)
    return output_paths
# Render structure with py3Dmol and inject script tag manually
def render_structure(pdb_path):
    """Build HTML showing ``pdb_path`` as a cartoon coloured by B-factor.

    The py3Dmol page, prefixed with an explicit 3Dmol.js ``<script>`` tag, is
    embedded in an ``<iframe srcdoc=...>``. ``gr.HTML`` inserts markup without
    executing ``<script>`` tags, so the viewer has to run in its own document.

    Args:
        pdb_path: Path to a PDB file whose B-factor column holds scores in
            the 0-100 range.

    Returns:
        An HTML string containing a single iframe with the viewer.
    """
    with open(pdb_path, "r", encoding="utf-8") as f:
        pdb_data = f.read()
    view = py3Dmol.view(width=600, height=400)
    view.addModel(pdb_data, "pdb")
    # 3Dmol colours by an atom property through a colorscheme spec; 'b' is the
    # B-factor field, and min/max match the 0-100 scaling of the scores.
    view.setStyle(
        {
            "cartoon": {
                "colorscheme": {"prop": "b", "gradient": "roygb", "min": 0, "max": 100}
            }
        }
    )
    view.zoomTo()
    page = (
        '<script src="https://3Dmol.org/build/3Dmol.js"></script>'
        + view._make_html()
    )
    return (
        '<iframe style="width: 640px; height: 440px; border: none;" '
        f'srcdoc="{html.escape(page, quote=True)}"></iframe>'
    )
# Gradio interface logic
def process(seq, pdb_file, component_string):
    """Gradio callback: validate inputs, project PCA scores, return files and viewer.

    Args:
        seq: Protein sequence in one-letter code.
        pdb_file: Uploaded PDB (path or file-like object), or None when nothing
            was uploaded.
        component_string: Comma-separated non-negative integers, e.g. "0,1,2".

    Returns:
        ``(pdb_paths, html_view)``: paths of the written PDB files and the
        viewer HTML for the first requested component. On invalid input,
        ``([], error_html)``.
    """

    def error(message):
        # Escape so that exception text cannot inject markup into gr.HTML.
        return [], f"<p style='color:red'>Error: {html.escape(message)}</p>"

    invalid_components = "Invalid component list. Use comma-separated integers."
    tokens = [tok.strip() for tok in (component_string or "").split(",") if tok.strip()]
    try:
        components = [int(tok) for tok in tokens]
    except ValueError:
        # Reject the whole list rather than silently dropping bad entries.
        return error(invalid_components)
    if not components or any(c < 0 for c in components):
        return error(invalid_components)

    if not seq or not seq.strip():
        return error("Please enter a protein sequence.")
    if pdb_file is None:
        return error("Please upload a PDB file.")

    try:
        scores_list = compute_scaled_pca_scores(seq.strip(), components)
    except ValueError as err:
        # e.g. more PCA components requested than the sequence can support.
        return error(str(err))
    pdb_paths = inject_bfactors_and_save(pdb_file, scores_list, components)
    html_view = render_structure(pdb_paths[0]) if pdb_paths else ""
    return pdb_paths, html_view
# Gradio UI
demo = gr.Interface(
    fn=process,
    inputs=[
        gr.Textbox(label="Input Protein Sequence (1-letter code)"),
        gr.File(label="Upload PDB File", file_types=[".pdb"]),
        gr.Textbox(label="Comma-separated PCA Components (e.g. 0,1,2)"),
    ],
    outputs=[
        gr.File(label="Download PDBs with PCA Projections", file_types=[".pdb"], file_count="multiple"),
        gr.HTML(label="Interactive Structure Viewer (first PCA component only)"),
    ],
    title="ESM-1b PCA Component Projection with Interactive 3D Structure",
)

# Start the server only when this file is executed as a script (e.g.
# `python app.py`), so importing the module does not block on launch().
if __name__ == "__main__":
    demo.launch()