# fmt: off
import streamlit as st
import pandas as pd
import os
import tempfile
import subprocess
import requests
import csv
from models.polybert import polymer2psmiles
import py3Dmol

# Fix for permission error - disable usage stats.
# Streamlit tries to write telemetry config under $HOME, which may be
# read-only in containerized deployments; redirect it to /tmp instead.
if 'STREAMLIT_CONFIG_DIR' not in os.environ:
    os.environ['STREAMLIT_CONFIG_DIR'] = '/tmp/.streamlit'

# Create streamlit config directory if it doesn't exist
streamlit_dir = os.environ.get('STREAMLIT_CONFIG_DIR', '/tmp/.streamlit')
os.makedirs(streamlit_dir, exist_ok=True)

# Create config.toml to disable usage stats (written once; TOML is
# line-based, so the payload must keep its newlines)
config_path = os.path.join(streamlit_dir, 'config.toml')
if not os.path.exists(config_path):
    with open(config_path, 'w') as f:
        f.write("""[browser]
gatherUsageStats = false

[server]
headless = true
enableCORS = false
enableXsrfProtection = false
""")
# fmt: on

# One-letter amino-acid code -> PDB three-letter residue name.
aa2resn = {
    'A': 'ALA', 'C': 'CYS', 'D': 'ASP', 'E': 'GLU', 'F': 'PHE',
    'G': 'GLY', 'H': 'HIS', 'I': 'ILE', 'K': 'LYS', 'L': 'LEU',
    'M': 'MET', 'N': 'ASN', 'P': 'PRO', 'Q': 'GLN', 'R': 'ARG',
    'S': 'SER', 'T': 'THR', 'V': 'VAL', 'W': 'TRP', 'Y': 'TYR',
}

# Timeout (seconds) for all outbound HTTP requests so a stalled
# UniProt/AlphaFold endpoint cannot hang the UI forever.
_HTTP_TIMEOUT = 30


def _load_residue_attention(attn_path, top_n=10):
    """Load a saved attention tensor and rank residues by total attention.

    Parameters
    ----------
    attn_path : str
        Path to a ``.pt`` file produced by ``src/predict.py --attn``.
        Expected layout (based on how this app indexes it): ``attn[0]`` is
        either a tensor or a list/tuple whose first element is a tensor of
        shape ``(num_heads, seq_len, seq_len)`` or ``(seq_len, seq_len)``.
    top_n : int, optional
        Maximum number of top-scoring residues to return (default 10).

    Returns
    -------
    (residue_scores, top_idx) : tuple
        ``residue_scores`` is a 1-D numpy array of per-residue attention
        sums; ``top_idx`` holds the 0-based indices of the ``top_n``
        highest-scoring residues, best first.
    """
    import torch  # local import: torch is only needed when an attention file exists

    # map_location="cpu" so a GPU-saved tensor still loads on a CPU host;
    # the scores are moved to CPU below regardless.
    attn = torch.load(attn_path, map_location="cpu")
    attn_matrix = attn[0][0] if isinstance(attn[0], (list, tuple)) else attn[0]
    # Average over heads when a (heads, L, L) stack is present.
    if attn_matrix.ndim == 3:
        attn_matrix = attn_matrix.mean(0)
    # For each residue, sum the attention weights directed at it.
    residue_scores = attn_matrix.sum(0).cpu().numpy()
    top_n = min(top_n, len(residue_scores))
    top_idx = residue_scores.argsort()[::-1][:top_n]
    return residue_scores, top_idx


# Fancy header
st.markdown("""

🧬 Plastic Degradation Predictor

Predict the degradability of plastics using protein sequences and polymer SMILES


""", unsafe_allow_html=True)

st.write("Enter a UniProt ID or paste a protein sequence. Select a polymer from the list below.")

# Load polymer names and SMILES.
# Only show polymers with SMILES in the dropdown.
polymer_csv = os.path.join(os.path.dirname(__file__), 'data/polymer2tok.csv')
polymer_options = []
with open(polymer_csv, newline='') as f:
    reader = csv.DictReader(f)
    for row in reader:
        name = row['polymer']
        smiles = polymer2psmiles.get(name, '')
        if smiles:  # Only include polymers with SMILES
            polymer_options.append(f"{name} | {smiles}")

input_type = st.radio("Input type", ["UniProt ID", "Protein Sequence"])
if input_type == "UniProt ID":
    uniprot_id = st.text_input("Enter UniProt ID", "P69905")
    sequence = ""
    if uniprot_id:
        # Fetch sequence from UniProt; the FASTA header (first line) is dropped.
        url = f"https://rest.uniprot.org/uniprotkb/{uniprot_id}.fasta"
        resp = requests.get(url, timeout=_HTTP_TIMEOUT)
        if resp.status_code == 200:
            fasta = resp.text
            sequence = "".join(fasta.split("\n")[1:])
            st.success(f"Fetched sequence for {uniprot_id}")
            st.code(sequence)
        else:
            st.error("Failed to fetch sequence from UniProt.")
else:
    sequence = st.text_area(
        "Paste protein sequence",
        "MVLSPADKTNVKAAWGKVGAHAGEYGAEALERMFLSFPTTKTYFPHFDLSHGSAQVKGHG")

polymer = st.selectbox("Select polymer", polymer_options)
# Dropdown entries look like "name | SMILES"; recover the bare polymer name.
selected_polymer = polymer.split('|')[0].strip() if '|' in polymer else polymer

ckpt = "src/checkpoints/weights.ckpt"
plm = "esm2_t33_650M_UR50D"

if st.button("Predict degradation", type="primary"):
    if not sequence or not selected_polymer:
        st.error("Please provide both sequence and polymer.")
    else:
        # Create temp CSV with the single (sequence, polymer) pair that
        # src/predict.py consumes.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".csv", mode="w") as tmp:
            tmp.write("sequence,polymer\n")
            tmp.write(f"{sequence},{selected_polymer}\n")
            tmp_path = tmp.name
        output_path = os.path.join(tempfile.gettempdir(), "predictions.csv")
        st.write("Running prediction...")
        # Run prediction out-of-process; --attn also dumps attention tensors
        # next to the output CSV (predictions.attn/0.pt).
        result = subprocess.run([
            "python", "src/predict.py",
            "--ckpt", ckpt,
            "--plm", plm,
            "--csv", tmp_path,
            "--output", output_path,
            "--attn",
        ], capture_output=True, text=True)
        if result.returncode == 0 and os.path.exists(output_path):
            df = pd.read_csv(output_path)
            if 'time' in df.columns:
                df = df.rename(columns={'time': 'running time'})
            st.markdown(f"""

Prediction Complete!

Your input has been processed. See the results below:

Degradation: {df['pred'].values[0]} (Probability: {df['prob'].values[0]:.4f})

""", unsafe_allow_html=True)
            st.download_button(
                "⬇️ Download Results",
                data=df.to_csv(index=False),
                file_name="predictions.csv",
                type="primary")
            # Show top-N attention residues if attention file exists
            attn_dir = os.path.join(os.path.dirname(output_path), "predictions.attn")
            attn_path = os.path.join(attn_dir, "0.pt")
            if os.path.exists(attn_path):
                residue_scores, top_idx = _load_residue_attention(attn_path)
                topN = len(top_idx)
                st.markdown(f"**Top {topN} high-attention residues:**")
                st.write(pd.DataFrame({
                    "Amino Acid": [sequence[i] for i in top_idx],
                    "Residue Index": top_idx + 1,  # display 1-based positions
                    "Attention Score": residue_scores[top_idx],
                }))
            else:
                st.info("No attention file found for visualization.")
        else:
            st.error("Prediction failed. See details below:")
            st.text(result.stderr)

# If UniProt ID, try to download AlphaFold structure
structure_path = None
if input_type == "UniProt ID" and uniprot_id:
    af_url = f"https://alphafold.ebi.ac.uk/files/AF-{uniprot_id}-F1-model_v4.cif"
    # If attention available, highlight top residues.
    # NOTE(review): highlight_residues is computed but not consumed below —
    # the py3Dmol section recomputes its own top residues; kept for
    # compatibility with any external use.
    highlight_residues = None
    attn_dir = os.path.join(tempfile.gettempdir(), "predictions.attn")
    attn_path = os.path.join(attn_dir, "0.pt")
    if os.path.exists(attn_path):
        _, top_idx = _load_residue_attention(attn_path)
        # Molstar selection: list of residue numbers (1-based)
        highlight_residues = [int(i + 1) for i in top_idx]
    structure_path = os.path.join(
        tempfile.gettempdir(), f"AF-{uniprot_id}-F1-model_v4.cif")
    try:
        r = requests.get(af_url, timeout=_HTTP_TIMEOUT)
        if r.status_code == 200:
            with open(structure_path, "wb") as f:
                f.write(r.content)
            st.success(f"AlphaFold structure downloaded: {structure_path}")
        else:
            st.warning("AlphaFoldDB structure not found for this UniProt ID.")
    except Exception as e:
        # Best-effort download: the app still works without a 3D view.
        st.warning(f"AlphaFoldDB download error: {e}")

# 3D view requires both an attention file and a downloaded structure;
# short-circuiting keeps attn_path/structure_path from being evaluated
# when input_type is not "UniProt ID".
if input_type == "UniProt ID" and uniprot_id and os.path.exists(attn_path) and os.path.exists(structure_path):
    st.markdown("### 3D Structure Visualization (stmol)")
    from stmol import showmol
    residue_scores, top_idx = _load_residue_attention(attn_path)
    labels = [f"{sequence[i]}{i+1}: {residue_scores[i]:.4g}" for i in top_idx]
    with open(structure_path, "r") as cif_file:
        cif_data = cif_file.read()
    view = py3Dmol.view(width=600, height=400)
    view.addModel(cif_data, "cif")
    view.setStyle({"cartoon": {"color": "lightgray"}})
    for i, idx in enumerate(top_idx):
        resi_num = int(idx + 1)  # structure residue numbering is 1-based
        view.setStyle(
            {"resi": resi_num},
            {"cartoon": {"color": "red"}})
        view.addResLabels(
            {"resi": resi_num},
            {"font": 'Arial',
             "fontColor": 'black',
             "showBackground": False,
             "screenOffset": {"x": 0, "y": 0}})
    view.zoomTo()
    showmol(view, height=600, width='100%')

# --- Footer: License and References ---
st.markdown(""" ---

License

Attribution-NonCommercial-ShareAlike 4.0 International (CC BY-NC-SA 4.0)
View full license details
""", unsafe_allow_html=True)