#!/usr/bin/env python3
"""
Simple adpr-llama Gradio app for ADP-ribosylation site prediction
Uses PEFT adapter model with Zero GPU support
"""
import re
from typing import List, Tuple
import io
import base64
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import AutoPeftModelForCausalLM
import numpy as np
import spaces
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
# Model configuration
# Hugging Face Hub repository holding the PEFT adapter and its tokenizer.
MODEL_REPO = "jbenbudd/ADPrLlama"
# Commit of MODEL_REPO that is passed as ``revision=`` when loading the model and
# tokenizer, pinned for reproducibility. Set to None (or "") to load the repo's
# default revision instead; the startup message below handles either case.
MODEL_REVISION = "bb35aa92145ba2b6eba78542ae65e7bc7bdb06bc"
# Window length (in residues) used by chunk_sequence() and remap_sites().
CHUNK_SIZE = 21
# Character used to right-pad the final, shorter window up to CHUNK_SIZE.
PAD_CHAR = "-"
print(f"Loading model from {MODEL_REPO}" + (f" at revision {MODEL_REVISION}" if MODEL_REVISION else ""))
# Global variables for model caching: both start as None and are populated lazily
# by generate_prediction() on its first call.
model = None
tokenizer = None
@spaces.GPU
def generate_prediction(prompt: str) -> str:
    """Run the model on a single prompt and return only the newly generated text.

    The PEFT adapter model and its tokenizer are loaded on the first call, inside the
    ``@spaces.GPU`` context, and cached in the module-level ``model`` / ``tokenizer``
    globals for later calls.

    Args:
        prompt: Complete text prompt to feed to the model.

    Returns:
        The decoded continuation (prompt excluded), with surrounding whitespace
        stripped.

    Raises:
        gr.Error: If loading the model or generating the prediction fails; the
            original exception is chained as the cause.
    """
    global model, tokenizer
    try:
        # Load into locals and publish both together, so a failure part-way through
        # (e.g. the tokenizer download) cannot leave a cached model with a None
        # tokenizer that every later call would trip over.
        if model is None or tokenizer is None:
            print("Loading model...")
            loaded_model = AutoPeftModelForCausalLM.from_pretrained(
                MODEL_REPO,
                revision=MODEL_REVISION,
                device_map="auto",  # Zero GPU will handle device placement
                torch_dtype=torch.float16,
                low_cpu_mem_usage=True,
            )
            loaded_tokenizer = AutoTokenizer.from_pretrained(
                MODEL_REPO,
                revision=MODEL_REVISION,
                use_fast=True,
            )
            model, tokenizer = loaded_model, loaded_tokenizer
            print("Model loaded successfully!")

        print(f"Generating prediction for prompt length: {len(prompt)}")
        inputs = tokenizer(prompt, return_tensors="pt")
        if torch.cuda.is_available():
            inputs = {k: v.cuda() for k, v in inputs.items()}
        prompt_token_count = inputs["input_ids"].shape[-1]

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=50,
                do_sample=False,  # greedy decoding, deterministic output
                pad_token_id=tokenizer.eos_token_id,
            )

        # generate() returns prompt + continuation for a causal LM. Cut at the token
        # level: slicing the decoded text by len(prompt) is unreliable because
        # decode(encode(prompt)) need not reproduce the prompt character-for-character.
        new_tokens = outputs[0][prompt_token_count:]
        response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
        print(f"Generated response: {response}")
        return response
    except Exception as e:
        print(f"Error in generate_prediction: {e}")
        raise gr.Error(f"Model prediction failed: {str(e)}") from e
def clean_sequence(sequence: str) -> str:
    """Upper-case *sequence* and keep only the 20 standard amino-acid letters.

    Any other character (whitespace, digits, punctuation, and letters outside the
    standard alphabet such as B, J, O, U, X, Z) is discarded.
    """
    standard_residues = frozenset("ACDEFGHIKLMNPQRSTVWY")
    upper = sequence.upper()
    return "".join(letter for letter in upper if letter in standard_residues)
def chunk_sequence(sequence: str, chunk_size: int = CHUNK_SIZE) -> List[str]:
    """Cut *sequence* into consecutive windows of exactly *chunk_size* characters.

    The last window is right-padded with PAD_CHAR when the length is not a multiple
    of *chunk_size*; an empty sequence produces an empty list.
    """
    window_starts = range(0, len(sequence), chunk_size)
    # ljust is a no-op on full-length windows, so only the tail actually gets padded.
    return [
        sequence[start:start + chunk_size].ljust(chunk_size, PAD_CHAR)
        for start in window_starts
    ]
def parse_sites(text: str) -> List[str]:
    """Pull the comma-separated site labels out of a ``Sites=<...>`` marker in *text*.

    Returns an empty list when the marker is absent, empty, or reads ``none`` in any
    letter case; blank entries between commas are dropped and each label is stripped.
    """
    found = re.search(r"Sites=<([^>]*)>", text)
    payload = found.group(1).strip() if found else ""
    if payload.lower() in ("", "none"):
        return []
    pieces = (piece.strip() for piece in payload.split(','))
    return [piece for piece in pieces if piece]
def remap_sites(sites: List[str], chunk_index: int, original_length: int, chunk_size: int = CHUNK_SIZE) -> List[str]:
    """Translate chunk-relative site labels into whole-sequence positions.

    Each label is expected to start with a residue letter followed by a 1-based
    position counted from the start of chunk ``chunk_index`` (e.g. ``"K5"``). The
    global position is ``chunk_index * chunk_size + position``.

    Args:
        sites: Labels parsed from the model output for one chunk.
        chunk_index: 0-based index of the chunk the labels refer to.
        original_length: Length of the unpadded sequence; positions past it fall on
            padding and are dropped.
        chunk_size: Window length used when the sequence was chunked.

    Returns:
        Labels of the form ``<residue><global 1-based position>``. Empty labels,
        labels that do not parse, positions outside ``1..chunk_size`` and positions
        landing on padding are skipped.
    """
    chunk_start = chunk_index * chunk_size
    remapped = []
    for site in sites:
        match = re.match(r'([A-Z])(\d+)', site) if site else None
        if not match:
            continue
        residue, pos_str = match.groups()
        local_pos = int(pos_str)
        # A position outside the chunk (0, or beyond chunk_size) is malformed model
        # output; remapping it would silently attribute the site to position 0 or
        # to a residue that belongs to a neighbouring chunk.
        if not 1 <= local_pos <= chunk_size:
            continue
        global_pos = chunk_start + local_pos
        if global_pos <= original_length:
            remapped.append(f"{residue}{global_pos}")
    return remapped
def create_interactive_visualization(sequence: str, predicted_sites: List[str]):
"""Create realistic interactive 3D protein structure visualization using Plotly"""
if len(sequence) > 1000: # Reasonable limit
# Return a simple text message for very long sequences
fig = go.Figure()
fig.add_annotation(
text=f"Sequence too long for visualization
Length: {len(sequence)} residues (max: 1000)
Sites: {', '.join(predicted_sites) if predicted_sites else 'None'}",
xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False,
font=dict(size=16), align="center"
)
fig.update_layout(
title="Sequence Too Long",
xaxis=dict(visible=False),
yaxis=dict(visible=False),
width=800, height=400
)
return fig
# Parse site positions
site_positions = set()
for site in predicted_sites:
match = re.match(r'[A-Z](\d+)', site)
if match:
site_positions.add(int(match.group(1)) - 1) # Convert to 0-based
# Amino acid properties for coloring and structure prediction
hydrophobic = set('AILMFPWYV')
positive = set('RHK')
negative = set('DE')
polar = set('STYNQC')
special = set('GP')
def get_aa_color(aa):
"""Get color based on amino acid properties"""
if aa in hydrophobic:
return 'orange'
elif aa in positive:
return 'blue'
elif aa in negative:
return 'red'
elif aa in polar:
return 'green'
elif aa in special:
return 'purple'
else:
return 'gray'
def predict_secondary_structure(sequence):
"""Simple secondary structure prediction based on amino acid propensities"""
structure = []
for i, aa in enumerate(sequence):
# Simple heuristic: helix-forming residues tend to form helices
helix_formers = set('AEHILMRTV')
sheet_formers = set('FIVWY')
if aa in helix_formers:
structure.append('H') # Helix
elif aa in sheet_formers and i > 2 and i < len(sequence) - 3:
structure.append('S') # Sheet
else:
structure.append('C') # Coil
return structure
# Generate realistic 3D coordinates
secondary_structure = predict_secondary_structure(sequence)
x_coords = []
y_coords = []
z_coords = []
colors = []
# Starting position
x, y, z = 0.0, 0.0, 0.0
phi, psi, omega = 0.0, 0.0, 0.0 # Backbone dihedral angles
# Realistic bond lengths and angles
ca_ca_distance = 3.8 # Average C-alpha to C-alpha distance
for i, (aa, ss) in enumerate(zip(sequence, secondary_structure)):
x_coords.append(x)
y_coords.append(y)
z_coords.append(z)
colors.append(get_aa_color(aa))
# Calculate next position based on secondary structure
if ss == 'H': # Alpha helix
phi_angle = -60 # degrees
psi_angle = -45
# Helical geometry
phi += np.radians(100) # ~3.6 residues per turn
x += ca_ca_distance * np.cos(phi) * 0.6
y += ca_ca_distance * np.sin(phi) * 0.6
z += 1.5 # Rise per residue in helix
elif ss == 'S': # Beta sheet
phi_angle = -120 # degrees
psi_angle = 120
# Extended conformation
direction = (-1) ** (i // 10) # Alternate direction every 10 residues
x += ca_ca_distance * 0.9 * direction
y += ca_ca_distance * 0.3 * np.sin(i * 0.5)
z += 0.5
else: # Random coil
# More random movement
phi += np.random.uniform(-np.pi/3, np.pi/3)
psi += np.random.uniform(-np.pi/4, np.pi/4)
x += ca_ca_distance * np.cos(phi) * np.random.uniform(0.7, 1.0)
y += ca_ca_distance * np.sin(phi) * np.random.uniform(0.7, 1.0)
z += np.random.uniform(0.5, 2.0)
# Create the plot
fig = go.Figure()
# Add protein backbone as a ribbon/tube
fig.add_trace(go.Scatter3d(
x=x_coords, y=y_coords, z=z_coords,
mode='lines+markers',
line=dict(color='lightblue', width=12),
marker=dict(
size=6,
color=colors,
opacity=0.8,
line=dict(color='white', width=1)
),
name='Protein Backbone',
hovertemplate='Position %{text}
Residue: %{customdata[0]}
Type: %{customdata[1]}
Secondary Structure: %{customdata[2]}
%{text}
Position: %{text}
Residue: %{customdata}
{len(sequence)} residues, {len(predicted_sites)} ADP-ribosylation sites predicted',
x=0.5,
font=dict(size=16)
),
scene=dict(
xaxis_title='X (Ã…)',
yaxis_title='Y (Ã…)',
zaxis_title='Z (Ã…)',
camera=dict(
eye=dict(x=1.2, y=1.2, z=1.2)
),
aspectmode='cube',
bgcolor='rgba(240,240,240,0.1)',
xaxis=dict(
backgroundcolor="rgba(0, 0, 0,0)",
gridcolor="lightgray",
showbackground=True,
zerolinecolor="lightgray",
),
yaxis=dict(
backgroundcolor="rgba(0, 0, 0,0)",
gridcolor="lightgray",
showbackground=True,
zerolinecolor="lightgray",
),
zaxis=dict(
backgroundcolor="rgba(0, 0, 0,0)",
gridcolor="lightgray",
showbackground=True,
zerolinecolor="lightgray",
)
),
width=800,
height=600,
margin=dict(l=0, r=0, t=60, b=0),
legend=dict(
yanchor="top",
y=0.99,
xanchor="left",
x=0.01,
bgcolor="rgba(255,255,255,0.8)"
),
font=dict(family="Arial, sans-serif")
)
return fig
def create_sequence_plot(sequence: str, predicted_sites: List[str]):
"""Create a 2D sequence visualization using Plotly"""
# Parse site positions
site_positions = set()
for site in predicted_sites:
match = re.match(r'[A-Z](\d+)', site)
if match:
site_positions.add(int(match.group(1)) - 1) # Convert to 0-based
# Create sequence grid
residues_per_row = min(50, len(sequence))
rows_needed = (len(sequence) + residues_per_row - 1) // residues_per_row
# Prepare data for heatmap
grid_data = []
annotations = []
for row in range(rows_needed):
row_data = []
for col in range(residues_per_row):
seq_idx = row * residues_per_row + col
if seq_idx < len(sequence):
# 1 for PTM sites, 0 for normal residues
value = 1 if seq_idx in site_positions else 0
row_data.append(value)
# Add annotation for amino acid letter
annotations.append(
dict(
x=col, y=rows_needed - row - 1,
text=sequence[seq_idx],
showarrow=False,
font=dict(color='white' if seq_idx in site_positions else 'black', size=10)
)
)
# Add position number for PTM sites
if seq_idx in site_positions:
annotations.append(
dict(
x=col, y=rows_needed - row - 1 - 0.3,
text=str(seq_idx + 1),
showarrow=False,
font=dict(color='white', size=8)
)
)
else:
row_data.append(-1) # Empty cell
grid_data.append(row_data)
# Create heatmap
fig = go.Figure(data=go.Heatmap(
z=grid_data,
colorscale=[[0, 'lightblue'], [0.5, 'lightgray'], [1, 'red']],
showscale=False,
hovertemplate='Position: %{customdata}
Residue: %{text}
Predicted ADP-ribosylation sites: {sites_text}
" highlighted += f"Sequence: {clean_seq}
" else: highlighted = f"No ADP-ribosylation sites predicted
" highlighted += f"Sequence: {clean_seq}
" # Analysis summary analysis = f"""Original length: {original_length} residues
Chunks processed: {len(chunks)}
Sites found: {len(all_sites)}