#fmt: off
import streamlit as st
import pandas as pd
import os
import tempfile
import subprocess
import requests
import csv
from models.polybert import polymer2psmiles
import py3Dmol
# fmt: off
# Streamlit may lack write access to its default config location inside a
# container; point it at a writable directory and pre-write a config.toml
# that disables usage stats, so startup does not fail with a permission error.
os.environ.setdefault('STREAMLIT_CONFIG_DIR', '/tmp/.streamlit')
streamlit_dir = os.environ.get('STREAMLIT_CONFIG_DIR', '/tmp/.streamlit')
os.makedirs(streamlit_dir, exist_ok=True)

# Only write the config file on first run; never clobber an existing one.
config_path = os.path.join(streamlit_dir, 'config.toml')
if not os.path.exists(config_path):
    with open(config_path, 'w') as f:
        f.write("""[browser]
gatherUsageStats = false
[server]
headless = true
enableCORS = false
enableXsrfProtection = false
""")
# fmt: on
# One-letter amino-acid code -> three-letter residue name (the 20 standard
# amino acids, as used in PDB/mmCIF residue records).
aa2resn = dict(zip(
    "ACDEFGHIKLMNPQRSTVWY",
    ("ALA CYS ASP GLU PHE GLY HIS ILE LYS LEU "
     "MET ASN PRO GLN ARG SER THR VAL TRP TYR").split(),
))
# Page header and input instructions.
st.markdown("""
🧬 Plastic Degradation Predictor
Predict the degradability of plastics using protein sequences and polymer SMILES
""", unsafe_allow_html=True)
st.write("Enter a UniProt ID or paste a protein sequence. Select a polymer from the list below.")

# Build the dropdown entries ("name | SMILES"), restricted to polymers that
# have a SMILES mapping in polymer2psmiles.
polymer_csv = os.path.join(os.path.dirname(__file__), 'data/polymer2tok.csv')
with open(polymer_csv, newline='') as handle:
    named = ((row['polymer'], polymer2psmiles.get(row['polymer'], ''))
             for row in csv.DictReader(handle))
    polymer_options = [f"{name} | {smiles}" for name, smiles in named if smiles]
# --- Protein input: a UniProt accession (fetched via the REST API) or a
# --- pasted raw sequence. Produces `sequence` (and `uniprot_id` when used).
input_type = st.radio("Input type", ["UniProt ID", "Protein Sequence"])
if input_type == "UniProt ID":
    # strip() tolerates accidental whitespace around a pasted accession.
    uniprot_id = st.text_input("Enter UniProt ID", "P69905").strip()
    sequence = ""
    if uniprot_id:
        # Fetch the FASTA record and drop the header line to get the bare sequence.
        url = f"https://rest.uniprot.org/uniprotkb/{uniprot_id}.fasta"
        # timeout keeps the Streamlit script from hanging indefinitely when
        # UniProt is slow or unreachable (the original call had no timeout).
        resp = requests.get(url, timeout=15)
        if resp.status_code == 200:
            fasta = resp.text
            sequence = "".join(fasta.split("\n")[1:])
            st.success(f"Fetched sequence for {uniprot_id}")
            st.code(sequence)
        else:
            st.error("Failed to fetch sequence from UniProt.")
else:
    sequence = st.text_area("Paste protein sequence",
                            "MVLSPADKTNVKAAWGKVGAHAGEYGAEALERMFLSFPTTKTYFPHFDLSHGSAQVKGHG")

polymer = st.selectbox("Select polymer", polymer_options)
# Dropdown entries are formatted "name | SMILES"; keep only the polymer name.
selected_polymer = polymer.split('|')[0].strip() if '|' in polymer else polymer

# Fixed model configuration handed to the prediction subprocess below.
ckpt = "src/checkpoints/weights.ckpt"
plm = "esm2_t33_650M_UR50D"
# --- Run the prediction subprocess and display results + top-attention residues ---
if st.button("Predict degradation", type="primary"):
    if not sequence or not selected_polymer:
        st.error("Please provide both sequence and polymer.")
    else:
        # Write a one-row input CSV for the prediction script. csv.writer
        # handles quoting, so a polymer name containing a comma cannot corrupt
        # the file (the original f-string write could). Output is identical to
        # the original for comma-free values.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".csv",
                                         mode="w", newline="") as tmp:
            writer = csv.writer(tmp)
            writer.writerow(["sequence", "polymer"])
            writer.writerow([sequence, selected_polymer])
            tmp_path = tmp.name
        output_path = os.path.join(tempfile.gettempdir(), "predictions.csv")
        st.write("Running prediction...")
        try:
            result = subprocess.run([
                "python", "src/predict.py",
                "--ckpt", ckpt,
                "--plm", plm,
                "--csv", tmp_path,
                "--output", output_path,
                "--attn"
            ], capture_output=True, text=True)
        finally:
            # The temp file was created with delete=False; remove it so repeated
            # predictions don't leak files in the temp directory.
            try:
                os.remove(tmp_path)
            except OSError:
                pass
        if result.returncode == 0 and os.path.exists(output_path):
            df = pd.read_csv(output_path)
            if 'time' in df.columns:
                df = df.rename(columns={'time': 'running time'})
            st.markdown(f"""
✅ Prediction Complete!
Your input has been processed. See the results below:
Degradation: {df['pred'].values[0]} (Probability: {df['prob'].values[0]:.4f})
""", unsafe_allow_html=True)
            st.download_button("⬇️ Download Results", data=df.to_csv(
                index=False), file_name="predictions.csv", type="primary")
            # Show top-N attention residues if the attention dump exists
            # (written by predict.py next to the output CSV when --attn is set).
            attn_dir = os.path.join(os.path.dirname(
                output_path), "predictions.attn")
            attn_path = os.path.join(attn_dir, "0.pt")
            if os.path.exists(attn_path):
                import torch
                # map_location="cpu" lets attention tensors saved on a GPU
                # machine load in this (possibly CPU-only) Streamlit process;
                # without it torch.load fails when CUDA is unavailable.
                attn = torch.load(attn_path, map_location="cpu")
                # attn[0][0]: shape (num_heads, seq_len, seq_len) or (1, seq_len, seq_len)
                attn_matrix = attn[0][0] if isinstance(
                    attn[0], (list, tuple)) else attn[0]
                if attn_matrix.ndim == 3:
                    attn_matrix = attn_matrix.mean(0)  # average over heads
                # Per-residue importance: total attention received by each residue.
                residue_scores = attn_matrix.sum(0).cpu().numpy()
                topN = min(10, len(residue_scores))
                top_idx = residue_scores.argsort()[::-1][:topN]
                st.markdown(f"**Top {topN} high-attention residues:**")
                st.write(pd.DataFrame({
                    "Amino Acid": [sequence[i] for i in top_idx],
                    "Residue Index": top_idx + 1,  # 1-based residue numbering
                    "Attention Score": residue_scores[top_idx]
                }))
            else:
                st.info("No attention file found for visualization.")
        else:
            st.error("Prediction failed. See details below:")
            st.text(result.stderr)
# --- If the input was a UniProt ID, try to download its AlphaFold structure ---
structure_path = None
if input_type == "UniProt ID" and uniprot_id:
    af_url = f"https://alphafold.ebi.ac.uk/files/AF-{uniprot_id}-F1-model_v4.cif"
    # Location of the attention dump written by the prediction step; the
    # visualization section below checks these same paths.
    # NOTE(review): removed a dead recomputation here that loaded the attention
    # file just to build `highlight_residues`, which was never used anywhere.
    attn_dir = os.path.join(tempfile.gettempdir(), "predictions.attn")
    attn_path = os.path.join(attn_dir, "0.pt")
    structure_path = os.path.join(
        tempfile.gettempdir(), f"AF-{uniprot_id}-F1-model_v4.cif")
    try:
        # timeout prevents the app from hanging if AlphaFoldDB is unreachable
        # (the original call had no timeout).
        r = requests.get(af_url, timeout=30)
        if r.status_code == 200:
            with open(structure_path, "wb") as f:
                f.write(r.content)
            st.success(
                f"AlphaFold structure downloaded: {structure_path}")
        else:
            st.warning(
                "AlphaFoldDB structure not found for this UniProt ID.")
    except Exception as e:
        # Best-effort download: surface the problem but keep the app running.
        st.warning(f"AlphaFoldDB download error: {e}")
# --- 3D visualization: AlphaFold structure with top-attention residues in red ---
if input_type == "UniProt ID" and uniprot_id and os.path.exists(attn_path) and os.path.exists(structure_path):
    st.markdown("### 3D Structure Visualization (stmol)")
    import torch
    from stmol import showmol
    # map_location="cpu" so attention tensors saved on a GPU machine still
    # load in this (possibly CPU-only) process.
    attn = torch.load(attn_path, map_location="cpu")
    attn_matrix = attn[0][0] if isinstance(
        attn[0], (list, tuple)) else attn[0]
    if attn_matrix.ndim == 3:
        attn_matrix = attn_matrix.mean(0)  # average over attention heads
    residue_scores = attn_matrix.sum(0).cpu().numpy()
    topN = min(10, len(residue_scores))
    top_idx = residue_scores.argsort()[::-1][:topN]
    # NOTE(review): a `labels` list ("A12: 0.53" strings) was built here but
    # never passed to the viewer; removed as dead code. It may have been
    # intended as text for addResLabels — confirm with the author.
    with open(structure_path, "r") as cif_file:
        cif_data = cif_file.read()
    view = py3Dmol.view(width=600, height=400)
    view.addModel(cif_data, "cif")
    view.setStyle({"cartoon": {"color": "lightgray"}})
    for idx in top_idx:
        resi_num = int(idx + 1)  # structure residue numbering is 1-based
        view.setStyle(
            {"resi": resi_num}, {
                "cartoon": {"color": "red"}})
        view.addResLabels(
            {"resi": resi_num},
            {
                "font": 'Arial', "fontColor": 'black',
                "showBackground": False, "screenOffset": {"x": 0, "y": 0}})
    view.zoomTo()
    showmol(view, height=600, width='100%')
# --- Footer: License and References ---
_FOOTER_MD = """
---
License
Attribution-NonCommercial-ShareAlike 4.0 International (CC BY-NC-SA 4.0)
View full license details
"""
st.markdown(_FOOTER_MD, unsafe_allow_html=True)