|
|
""" |
|
|
AbMelt Complete Pipeline - Hugging Face Space Implementation |
|
|
Full molecular dynamics simulation pipeline for antibody thermostability prediction |
|
|
""" |
|
|
|
|
|
import gradio as gr |
|
|
import os |
|
|
import sys |
|
|
import logging |
|
|
import tempfile |
|
|
import threading |
|
|
import time |
|
|
import json |
|
|
from pathlib import Path |
|
|
import pandas as pd |
|
|
import traceback |
|
|
|
|
|
|
|
|
# Prepend the bundled ``src/`` directory to the import path; this must run
# before the project imports that follow (structure_generator, gromacs_pipeline,
# descriptor_calculator, ml_predictor).
sys.path.insert(0, os.fspath(Path(__file__).parent / "src"))
|
|
|
|
|
from structure_generator import StructureGenerator |
|
|
from gromacs_pipeline import GromacsPipeline, GromacsError |
|
|
from descriptor_calculator import DescriptorCalculator |
|
|
from ml_predictor import ThermostabilityPredictor |
|
|
|
|
|
|
|
|
# Configure root logging at INFO level (basicConfig is a no-op if the root
# logger already has handlers) and create this module's logger.
logging.basicConfig(level="INFO")
logger = logging.getLogger(__name__)
|
|
|
|
|
class AbMeltPipeline:
    """Complete AbMelt pipeline for HF Space.

    Chains structure generation (``StructureGenerator``), GROMACS system
    preparation and multi-temperature MD (``GromacsPipeline``), descriptor
    calculation (``DescriptorCalculator``) and, when the models could be
    loaded, ML thermostability prediction (``ThermostabilityPredictor``).
    """

    def __init__(self):
        self.structure_gen = StructureGenerator()
        # The ML predictor is optional: if loading it fails, the pipeline still
        # runs MD and descriptor calculation, just without predictions.
        self.predictor = None
        # NOTE(review): current_job / job_status are not read in the visible
        # code; kept so external callers relying on them keep working.
        self.current_job = None
        self.job_status = {}

        try:
            models_dir = Path(__file__).parent / "models"
            self.predictor = ThermostabilityPredictor(models_dir)
            logger.info("ML predictor initialized")
        except Exception as e:
            logger.error(f"Failed to initialize ML predictor: {e}")

    @staticmethod
    def _parse_temperatures(temperatures):
        """Return the simulation temperatures (K) as a list of ints.

        Args:
            temperatures (str | Iterable): Comma-separated string such as
                ``"300,350,400"``, or an iterable of temperature values.
                Blank string entries (e.g. from a trailing comma) are skipped.

        Raises:
            ValueError: if an entry is not an integer, or no temperature is given.
        """
        parts = temperatures.split(',') if isinstance(temperatures, str) else temperatures
        temp_list = []
        for part in parts:
            if isinstance(part, str):
                part = part.strip()
                if not part:
                    continue  # tolerate "300,350," and similar input
            temp_list.append(int(part))
        if not temp_list:
            raise ValueError("At least one simulation temperature is required")
        return temp_list

    @staticmethod
    def _report(progress_callback, percent, message):
        """Forward a progress update to *progress_callback* when one was supplied."""
        if progress_callback:
            progress_callback(percent, message)

    @staticmethod
    def _safe_cleanup(resource, label):
        """Call ``resource.cleanup()`` best-effort: failures are logged, not raised."""
        try:
            resource.cleanup()
        except Exception:
            logger.warning("Cleanup of %s failed", label, exc_info=True)

    def run_complete_pipeline(self, heavy_chain, light_chain, sim_time_ns=10,
                              temperatures="300,350,400", progress_callback=None):
        """
        Run the complete AbMelt pipeline

        Args:
            heavy_chain (str): Heavy chain variable region sequence
            light_chain (str): Light chain variable region sequence
            sim_time_ns (int): Simulation time in nanoseconds
            temperatures (str | Iterable): Comma-separated temperatures in K
                (e.g. ``"300,350,400"``) or an iterable of temperatures.
                Blank entries are ignored.
            progress_callback (callable): Optional ``f(percent, message)``
                called at each pipeline stage.

        Returns:
            dict: ``success`` (bool), ``predictions``, ``intermediate_files``
            (paths keyed by stage), ``descriptors``, ``error`` (str or None)
            and ``logs`` (list of str). Exceptions raised by any stage,
            including invalid temperatures, are reported in ``error``/``logs``
            rather than propagated.
        """
        results = {
            'success': False,
            'predictions': {},
            'intermediate_files': {},
            'descriptors': {},
            'error': None,
            'logs': []
        }
        # Stays None until the GROMACS stage starts, so cleanup below only
        # touches a pipeline that was actually created.
        md_pipeline = None

        try:
            temp_list = self._parse_temperatures(temperatures)
            self._report(progress_callback, 0, "Starting AbMelt pipeline...")

            # Stage 1: 3D structure from sequence.
            self._report(progress_callback, 10, "Generating antibody structure with ImmuneBuilder...")
            structure_path = self.structure_gen.generate_structure(heavy_chain, light_chain)
            results['intermediate_files']['structure'] = structure_path
            results['logs'].append("β Structure generation completed")

            # Stage 2: GROMACS system preparation.
            self._report(progress_callback, 20, "Preparing GROMACS molecular dynamics system...")
            md_pipeline = GromacsPipeline()
            prepared_system = md_pipeline.prepare_system(structure_path)
            results['intermediate_files']['prepared_system'] = prepared_system
            results['logs'].append("β GROMACS system preparation completed")

            # Stage 3: one MD run per requested temperature.
            self._report(progress_callback, 30, f"Running MD simulations at {len(temp_list)} temperatures...")
            trajectories = md_pipeline.run_md_simulations(
                temperatures=temp_list,
                sim_time_ns=sim_time_ns
            )
            results['intermediate_files']['trajectories'] = trajectories
            results['logs'].append(f"β MD simulations completed for {len(temp_list)} temperatures")

            # Stage 4: descriptors from the trajectories.
            self._report(progress_callback, 80, "Calculating molecular descriptors...")
            work_dir = md_pipeline.work_dir
            descriptor_calc = DescriptorCalculator(work_dir)
            # assumes GromacsPipeline writes one md_<T>.tpr per temperature
            # into its work_dir — TODO confirm against gromacs_pipeline.
            topology_files = {temp: os.path.join(work_dir, f"md_{temp}.tpr")
                              for temp in temp_list}
            descriptors = descriptor_calc.calculate_all_descriptors(
                trajectories, topology_files
            )
            results['descriptors'] = descriptors
            results['logs'].append("β Descriptor calculation completed")

            desc_csv_path = os.path.join(work_dir, "descriptors.csv")
            descriptor_calc.export_descriptors_csv(descriptors, desc_csv_path)
            results['intermediate_files']['descriptors_csv'] = desc_csv_path

            # Stage 5: ML predictions (skipped if the models failed to load).
            self._report(progress_callback, 90, "Making thermostability predictions...")
            if self.predictor is not None:
                predictions = self.predictor.predict_thermostability(descriptors)
                results['predictions'] = predictions
                results['logs'].append("β Thermostability predictions completed")
            else:
                results['logs'].append("β ML predictor not available")

            self._report(progress_callback, 100, "Pipeline completed successfully!")
            results['success'] = True

        except GromacsError as e:
            error_msg = f"GROMACS error: {e}"
            results['error'] = error_msg
            results['logs'].append(f"β {error_msg}")
            logger.error(error_msg)

        except Exception as e:
            error_msg = f"Pipeline error: {e}"
            results['error'] = error_msg
            results['logs'].append(f"β {error_msg}")
            logger.exception("Pipeline failed")

        finally:
            # NOTE(review): if GromacsPipeline.cleanup() / StructureGenerator.cleanup()
            # delete their working files, the paths stored in
            # results['intermediate_files'] (structure, descriptors CSV) may no
            # longer exist when the caller offers them for download — verify.
            if md_pipeline is not None:
                self._safe_cleanup(md_pipeline, "GROMACS pipeline")
            self._safe_cleanup(self.structure_gen, "structure generator")

        return results
|
|
|
|
|
def create_interface():
    """Create the Gradio interface.

    Builds a ``gr.Blocks`` app with three tabs: the full MD pipeline, a quick
    prediction from an uploaded descriptor CSV, and background information.
    All event handlers share a single ``AbMeltPipeline`` instance.

    Returns:
        gr.Blocks: the configured (not yet launched) demo.
    """
    pipeline = AbMeltPipeline()

    with gr.Blocks(title="AbMelt: Complete MD Pipeline", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
        # 𧬠AbMelt: Complete Molecular Dynamics Pipeline

        **Predict antibody thermostability through multi-temperature molecular dynamics simulations**

        This space implements the complete AbMelt protocol from sequence to thermostability predictions:
        - Structure generation with ImmuneBuilder
        - Multi-temperature MD simulations (300K, 350K, 400K)
        - Comprehensive descriptor calculation
        - Machine learning predictions for Tagg, Tm,on, and Tm

        β οΈ **Note**: Full pipeline takes 2-4 hours per antibody due to MD simulation requirements.
        """)

        with gr.Tab("π Complete Pipeline"):
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("### Input Sequences")
                    heavy_chain = gr.Textbox(
                        label="Heavy Chain Variable Region",
                        placeholder="Enter VH amino acid sequence (e.g., QVQLVQSGAEVKKPG...)",
                        lines=3,
                        info="Variable region of heavy chain (VH)"
                    )
                    light_chain = gr.Textbox(
                        label="Light Chain Variable Region",
                        placeholder="Enter VL amino acid sequence (e.g., DIQMTQSPSSLSASVGDR...)",
                        lines=3,
                        info="Variable region of light chain (VL)"
                    )

                    gr.Markdown("### Simulation Parameters")
                    sim_time = gr.Slider(
                        minimum=10,
                        maximum=100,
                        value=10,
                        step=10,
                        label="Simulation time (ns)",
                        info="Longer simulations are more accurate but take more time"
                    )
                    temperatures = gr.Textbox(
                        label="Temperatures (K)",
                        value="300,350,400",
                        info="Comma-separated temperatures for MD simulations"
                    )

                with gr.Column(scale=1):
                    gr.Markdown("### Pipeline Progress")
                    status_text = gr.Textbox(
                        label="Current Status",
                        value="Ready to start...",
                        interactive=False
                    )

                    run_button = gr.Button("π¬ Run Complete Pipeline", variant="primary")

                    gr.Markdown("### Estimated Time")
                    time_estimate = gr.Textbox(
                        label="Estimated Completion Time",
                        value="Not calculated",
                        interactive=False
                    )

            with gr.Row():
                gr.Markdown("### π Results")

            with gr.Row():
                with gr.Column():
                    gr.Markdown("#### Thermostability Predictions")
                    tagg_result = gr.Number(
                        label="Tagg - Aggregation Temperature (°C)",
                        info="Temperature at which aggregation begins",
                        interactive=False
                    )
                    tmon_result = gr.Number(
                        label="Tm,on - Melting Temperature On-pathway (°C)",
                        info="On-pathway melting temperature",
                        interactive=False
                    )
                    tm_result = gr.Number(
                        label="Tm - Overall Melting Temperature (°C)",
                        info="Overall thermal melting temperature",
                        interactive=False
                    )

                with gr.Column():
                    gr.Markdown("#### Pipeline Logs")
                    pipeline_logs = gr.Textbox(
                        label="Execution Log",
                        lines=8,
                        info="Real-time pipeline progress and status",
                        interactive=False
                    )

            with gr.Row():
                gr.Markdown("### π Download Results")

            with gr.Row():
                structure_download = gr.File(
                    label="Generated Structure (PDB)"
                )
                descriptors_download = gr.File(
                    label="Calculated Descriptors (CSV)"
                )
                trajectory_info = gr.Textbox(
                    label="Trajectory Information",
                    interactive=False
                )

        with gr.Tab("β‘ Quick Prediction"):
            gr.Markdown("""
            ### Upload Pre-calculated Descriptors
            If you have already calculated MD descriptors, upload them here for quick predictions.
            """)

            descriptor_upload = gr.File(
                label="Upload Descriptor CSV",
                file_types=[".csv"]
            )
            quick_predict_btn = gr.Button("π― Quick Predict", variant="secondary")

            with gr.Row():
                quick_tagg = gr.Number(label="Tagg (°C)", interactive=False)
                quick_tmon = gr.Number(label="Tm,on (°C)", interactive=False)
                quick_tm = gr.Number(label="Tm (°C)", interactive=False)

        with gr.Tab("π Information"):
            gr.Markdown("""
            ### About AbMelt

            AbMelt is a computational protocol for predicting antibody thermostability using molecular dynamics simulations and machine learning.

            #### Method Overview:
            1. **Structure Generation**: Uses ImmuneBuilder to generate 3D antibody structures from sequences
            2. **System Preparation**: Prepares molecular dynamics simulation system with GROMACS
            3. **Multi-temperature MD**: Runs simulations at 300K, 350K, and 400K
            4. **Descriptor Calculation**: Computes structural and dynamic descriptors
            5. **ML Prediction**: Uses Random Forest models to predict thermostability

            #### Predictions:
            - **Tagg**: Aggregation temperature - when antibodies start to clump together
            - **Tm,on**: On-pathway melting temperature - structured unfolding temperature
            - **Tm**: Overall melting temperature - general thermal stability

            #### Citation:
            ```
            @article{rollins2024,
              title = {{AbMelt}: {Learning} {antibody} {thermostability} from {molecular} {dynamics}},
              journal = {preprint},
              author = {Rollins, Zachary A and Widatalla, Talal and Cheng, Alan C and Metwally, Essam},
              month = feb,
              year = {2024}
            }
            ```

            #### Computational Requirements:
            - Full pipeline: 2-4 hours per antibody
            - Memory: ~8GB for typical antibody
            - Storage: ~2GB for trajectory files
            """)

        def extract_prediction_values(predictions):
            """Return (tagg, tmon, tm) from a predictor result; missing entries give None."""
            predictions = predictions or {}
            return tuple(
                (predictions.get(key) or {}).get('value')
                for key in ('tagg', 'tmon', 'tm')
            )

        def error_outputs(message):
            """Outputs for run_button.click when only the log box should be filled."""
            return (None, None, None, message, None, None, None)

        def update_time_estimate(sim_time_val, temps_str):
            """Rough run-time estimate: 15 min per simulated ns per temperature, plus 30 min."""
            try:
                temp_count = len([t for t in temps_str.split(',') if t.strip()])
                # The slider value may arrive as a float; round so the display
                # reads "~2h 30m" rather than "~2.0h 30.0m".
                total_time = int(round(sim_time_val * temp_count * 15)) + 30
            except (AttributeError, TypeError, ValueError, OverflowError):
                return "Unable to estimate"

            hours, minutes = divmod(total_time, 60)
            if hours > 0:
                return f"~{hours}h {minutes}m"
            return f"~{minutes}m"

        def run_pipeline_wrapper(heavy, light, sim_time_val, temps_str):
            """Validate inputs, run the full pipeline and map results onto the UI.

            Returns a 7-tuple matching the ``run_button.click`` outputs:
            (tagg, tmon, tm, log text, structure file, descriptor CSV, trajectory info).
            """
            # Remove whitespace/line breaks that often come with pasted sequences.
            heavy = "".join((heavy or "").split())
            light = "".join((light or "").split())

            if not heavy or not light:
                return error_outputs("β Error: Both heavy and light chain sequences are required")

            if len(heavy) < 50 or len(light) < 50:
                return error_outputs(
                    "β Error: Sequences seem too short. Please provide complete variable regions (>50 residues each)"
                )

            def progress_callback(percent, message):
                # Progress is only logged server-side; the UI updates when the run returns.
                logger.info("[%s%%] %s", percent, message)

            try:
                results = pipeline.run_complete_pipeline(
                    heavy, light, sim_time_val, temps_str, progress_callback
                )

                log_lines = list(results.get('logs', []))
                error = results.get('error')
                # run_complete_pipeline normally records the error in its logs
                # already; only append it when it is missing, to avoid duplicates.
                if error and not any(error in line for line in log_lines):
                    log_lines.append(f"β {error}")
                logs = "\n".join(log_lines)

                files = results.get('intermediate_files', {})
                structure_file = files.get('structure')
                desc_file = files.get('descriptors_csv')
                trajectories = files.get('trajectories')
                traj_info = f"Generated {len(trajectories)} trajectory files" if trajectories else None

                tagg_val, tmon_val, tm_val = extract_prediction_values(results.get('predictions'))

                return (
                    tagg_val, tmon_val, tm_val,
                    logs,
                    structure_file, desc_file, traj_info
                )

            except Exception as e:
                logger.error(f"Pipeline wrapper failed: {traceback.format_exc()}")
                return error_outputs(f"β Pipeline failed: {str(e)}")

        def quick_prediction(desc_file):
            """Predict thermostability from the first row of an uploaded descriptor CSV.

            Returns (tagg, tmon, tm); all None when no predictor is loaded or the
            file cannot be used.
            """
            if desc_file is None:
                # This handler has only three numeric outputs, so the message is
                # shown to the user as an error instead of returned as a 4th value.
                raise gr.Error("Please upload a descriptor CSV file")

            # Depending on the Gradio version, gr.File yields either a tempfile
            # wrapper (path in .name) or a plain filepath string.
            csv_path = getattr(desc_file, "name", desc_file)

            try:
                df = pd.read_csv(csv_path)
                if df.empty:
                    logger.error("Quick prediction failed: descriptor CSV has no rows")
                    return None, None, None
                descriptors = df.iloc[0].to_dict()

                if not pipeline.predictor:
                    return None, None, None

                predictions = pipeline.predictor.predict_thermostability(descriptors)
                return extract_prediction_values(predictions)

            except Exception as e:
                logger.error(f"Quick prediction failed: {e}")
                return None, None, None

        sim_time.change(
            update_time_estimate,
            inputs=[sim_time, temperatures],
            outputs=time_estimate
        )

        temperatures.change(
            update_time_estimate,
            inputs=[sim_time, temperatures],
            outputs=time_estimate
        )

        run_button.click(
            run_pipeline_wrapper,
            inputs=[heavy_chain, light_chain, sim_time, temperatures],
            outputs=[
                tagg_result, tmon_result, tm_result,
                pipeline_logs,
                structure_download, descriptors_download, trajectory_info
            ]
        )

        quick_predict_btn.click(
            quick_prediction,
            inputs=descriptor_upload,
            outputs=[quick_tagg, quick_tmon, quick_tm]
        )

    return demo
|
|
|
|
|
if __name__ == "__main__":
    # Build the UI, cap the pending-request queue at 3 entries, then serve it;
    # share=True asks Gradio for a public share link.
    app = create_interface()
    app.queue(max_size=3)
    app.launch(share=True)