lhallee committed · Commit b015e7a · verified · Parent(s): 13e4436

Upload README.md with huggingface_hub

Files changed (1):
  1. README.md +114 -59

README.md CHANGED
@@ -12,36 +12,40 @@ The GitHub with the implementation and requirements.txt can be found [here](http
 
  # FastESMFold
 
- FastESMFold is a self-contained, HuggingFace-compatible reimplementation of ESMFold with **built-in Test-Time Training (TTT)** and multi-backend attention (SDPA, Flash, Flex).
 
  No dependency on `fair-esm`, `proteinttt`, or `openfold`. Just `transformers`, `torch`, and `einops`.
 
- ## Key Features
 
- - **Always-on TTT**: Runs 10 steps of masked language model adaptation via LoRA before folding. Returns the structure with the highest pLDDT across all steps.
- - **Best structure selection**: Folds after each TTT step, tracks per-step pLDDT, returns the best.
- - **FastESM2 attention**: SDPA/Flash/Flex backends for the 3B ESM2 backbone.
- - **Self-contained LoRA**: lora_diffusion-compatible implementation (no peft dependency). `Normal(0, 1/r)` initialization, `scale=alpha`.
- - **3.5B parameters**: Full ESMFold v1 architecture (ESM2-3B backbone + folding trunk).
 
- ## Benchmark
 
- Tested on 10 difficult sequences on A10G GPU:
 
- | Metric | Value |
- |--------|-------|
- | Mean baseline pLDDT | 0.549 |
- | Mean best TTT pLDDT | 0.637 |
- | Mean improvement | +0.088 |
- | Sequences improved >5pt | 5/10 |
- | Time per sequence | ~20-45s |
- | GPU memory peak | 18.3 GB |
 
- On the hardest sequence (baseline pLDDT 0.38), TTT achieves 0.72 (+34 points).
 
  ## Use with transformers
 
- ### Basic usage
  ```python
  import torch
  from transformers import AutoModel
@@ -52,16 +56,38 @@ model = AutoModel.from_pretrained(
  torch_dtype=torch.float32,
  ).cuda().eval()
 
  result = model.fold_protein("MKTLLILAVVAAALA...")
  print(f"pLDDT: {result['plddt']:.3f}")
- print(f"Best step: {result['best_step']}")
- print(f"Step pLDDTs: {result['step_plddts']}")
- print(f"PDB length: {len(result['pdb_string'])} chars")
  ```
 
  ### Return values
 
- `fold_protein(sequence)` returns a dict with:
 
  | Key | Type | Description |
  |-----|------|-------------|
@@ -71,28 +97,38 @@ print(f"PDB length: {len(result['pdb_string'])} chars")
  | `step_plddts` | list[float] | pLDDT at each step [baseline, s1, ..., s10] |
  | `best_step` | int | Which step produced the best structure (0=baseline) |
 
- ### Loading from Synthyra/ESMFold-v1 with custom config
 
  ```python
- from esmfold.modeling_fast_esmfold import FastEsmFoldConfig, FastEsmForProteinFolding
-
- config = FastEsmFoldConfig.from_pretrained("Synthyra/ESMFold-v1")
- config.attn_backend = "sdpa"
- config.ttt_config = {
-     "lr": 4e-4,
-     "steps": 10,
-     "ags": 4,
-     "batch_size": 4,
-     "lora_rank": 8,
-     "lora_alpha": 32.0,
-     "seed": 0,
- }
- model = FastEsmForProteinFolding.from_pretrained(
-     "Synthyra/ESMFold-v1",
-     config=config,
-     torch_dtype=torch.float32,
- ).cuda().eval()
  ```
 
  ## Attention backends
 
  The ESM2 backbone supports multiple attention backends via `config.attn_backend`:
@@ -104,41 +140,60 @@ The ESM2 backbone supports multiple attention backends via `config.attn_backend`
  | Flex Attention | `"flex"` | Skips padding tokens via block mask. First use compiles a Triton kernel. |
  | Auto | `"auto"` | Picks best available: `kernels_flash` > `flex` > `sdpa`. |
 
  ## TTT Configuration
 
- TTT parameters can be customized via `config.ttt_config`:
 
  | Parameter | Default | Description |
  |-----------|---------|-------------|
  | `lr` | 4e-4 | Learning rate for SGD optimizer |
- | `steps` | 10 | Number of optimizer steps |
  | `ags` | 4 | Gradient accumulation steps per optimizer step |
  | `batch_size` | 4 | Batch size for masked language model training |
  | `mask_ratio` | 0.15 | Fraction of tokens to mask |
- | `lora_rank` | 8 | LoRA rank (0 for full fine-tuning) |
- | `lora_alpha` | 32.0 | LoRA scaling factor |
- | `seed` | 0 | Random seed for reproducibility |
 
  ## How TTT Works
 
  1. **Baseline fold** (step 0): Standard ESMFold prediction
- 2. **LoRA injection**: Rank-8 LoRA adapters on ESM2 attention Q/K/V projections
- 3. **Masked LM training**: 10 optimizer steps of BERT-style masked language modeling on the input sequence
  4. **Per-step folding**: After each optimizer step, fold the sequence and record pLDDT
  5. **Best selection**: Return the structure with the highest pLDDT
  6. **Reset**: Restore LoRA weights to initial state for the next sequence
 
- This is based on the ProteinTTT paper (test-time compute for protein structure prediction).
 
- ### Citation
- If you use this implementation please cite it:
- ```
- @misc {FastPLMs,
-   author = { Hallee, Logan and Bichara, David and Gleghorn, Jason P.},
-   title = { FastPLMs: Fast, efficient, protein language model inference from Huggingface AutoModel.},
    year = {2024},
-   url = { https://huggingface.co/Synthyra/ESMplusplus_small },
-   DOI = { 10.57967/hf/3726 },
-   publisher = { Hugging Face }
  }
  ```
 
 
  # FastESMFold
 
+ FastESMFold is a self-contained, HuggingFace-compatible reimplementation of ESMFold with optional **Test-Time Training (TTT)** and multi-backend attention (SDPA, Flash, Flex).
 
  No dependency on `fair-esm`, `proteinttt`, or `openfold`. Just `transformers`, `torch`, and `einops`.
 
+ ## Why Test-Time Training?
+
+ Protein language models like ESM2 are trained on millions of sequences, but at inference time they process each new protein in a single forward pass with no adaptation. This is a missed opportunity: the input sequence itself contains structural signal that the model could learn from.
+
+ **Test-Time Training (TTT)** adapts the model to each individual protein before predicting its structure. The idea is simple: before folding, we briefly train the ESM2 backbone on the input sequence using masked language modeling (the same objective it was pretrained with). This forces the model to "study" the specific sequence, strengthening its internal representation of that protein's structural features.
+
+ The adaptation uses **LoRA** (Low-Rank Adaptation) for efficiency: only small adapter weights are trained (~4.4M parameters out of 3.5B), and the base model is restored after each prediction. This takes 20-45 seconds per sequence on an A10G GPU but can dramatically improve structure prediction quality, especially on difficult targets where standard ESMFold produces low-confidence predictions.
+
+ **When is TTT most useful?**
+ - Sequences with low baseline pLDDT (< 0.5): TTT can improve pLDDT by 10-30+ points
+ - Novel proteins with limited homology in training data
+ - Disordered or multi-domain proteins where ESMFold struggles
+
+ **When is TTT unnecessary?**
+ - Sequences that already fold well (baseline pLDDT > 0.7): TTT rarely helps and may slightly degrade predictions
+ - High-throughput screening where speed matters more than accuracy
+
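Schematically, the adapt-then-fold loop described above can be sketched as follows. `ttt_fold`, `train_step`, and `fold` are stand-in stubs invented for illustration, not the repository's API; the real model performs a LoRA-based masked-LM update and an ESMFold forward pass in their place:

```python
def ttt_fold(sequence, train_step, fold, steps=10):
    """Fold once at baseline and again after each adaptation step,
    keeping whichever structure scored the highest pLDDT."""
    best_structure, best_plddt, best_step = None, -1.0, 0
    step_plddts = []
    for step in range(steps + 1):          # step 0 = baseline (no training)
        if step > 0:
            train_step(sequence)           # one masked-LM optimizer step (LoRA)
        structure, plddt = fold(sequence)  # predict a structure and score it
        step_plddts.append(plddt)
        if plddt > best_plddt:
            best_structure, best_plddt, best_step = structure, plddt, step
    return {"structure": best_structure, "plddt": best_plddt,
            "best_step": best_step, "step_plddts": step_plddts}

# Toy run: pLDDT rises then falls; the loop picks the peak at step 2.
scores = iter([0.38, 0.55, 0.72, 0.60, 0.58])
result = ttt_fold("MKTLLILAVVAAALA", train_step=lambda s: None,
                  fold=lambda s: ("pdb", next(scores)), steps=4)
print(result["best_step"], result["plddt"])  # 2 0.72
```

Because folding happens after every optimizer step, TTT can never return something worse than the baseline: step 0 is always a candidate.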
+ ## Key Features
 
+ - **Standard ESMFold**: Full ESMFold v1 structure prediction, loadable via `AutoModel`
+ - **Optional TTT**: Enable test-time training for improved structure prediction on difficult sequences
+ - **Best structure selection**: When TTT is enabled, folds after each step and returns the structure with the highest pLDDT
+ - **FastESM2 attention**: SDPA/Flash/Flex backends for the 3B ESM2 backbone
+ - **Self-contained LoRA**: lora_diffusion-compatible implementation (no peft dependency)
+ - **3.5B parameters**: Full ESMFold v1 architecture (ESM2-3B backbone + folding trunk)
 
  ## Use with transformers
 
+ ### Standard structure prediction (no TTT)
+
  ```python
  import torch
  from transformers import AutoModel
 
  torch_dtype=torch.float32,
  ).cuda().eval()
 
+ # Standard fold (no TTT)
+ with torch.no_grad():
+     output = model.infer("MKTLLILAVVAAALA...")
+ pdb_strings = model.output_to_pdb(output)
+ plddt = output["plddt"].mean().item()
+ print(f"pLDDT: {plddt:.3f}")
+ ```
+
+ ### Structure prediction with TTT
+
+ TTT adapts the ESM2 backbone to a specific input sequence via masked language modeling before folding. This can dramatically improve pLDDT on difficult sequences (e.g., 0.38 to 0.72).
+
+ ```python
+ # Configure TTT
+ model._ttt_cfg.steps = 10       # 10 optimizer steps (default)
+ model._ttt_cfg.lora_rank = 8    # LoRA rank (default)
+ model._ttt_cfg.lora_alpha = 32  # LoRA scale (default)
+
+ # fold_protein() runs TTT, folds after each step, returns best structure
  result = model.fold_protein("MKTLLILAVVAAALA...")
  print(f"pLDDT: {result['plddt']:.3f}")
+ print(f"Best step: {result['best_step']} (0=baseline, 1-10=TTT steps)")
+ print(f"Step pLDDTs: {[f'{p:.2f}' for p in result['step_plddts']]}")
+
+ # Save PDB
+ with open("structure.pdb", "w") as f:
+     f.write(result["pdb_string"])
  ```
 
  ### Return values
 
+ `fold_protein(sequence)` returns a dict:
 
  | Key | Type | Description |
  |-----|------|-------------|
 
  | `step_plddts` | list[float] | pLDDT at each step [baseline, s1, ..., s10] |
  | `best_step` | int | Which step produced the best structure (0=baseline) |
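The per-step trace makes it easy to check whether TTT actually helped. The dict below uses illustrative stub values with the documented keys; real values come from `fold_protein`:

```python
# Hypothetical result dict with the documented keys (values are illustrative).
result = {
    "pdb_string": "ATOM ...",
    "plddt": 0.72,
    "step_plddts": [0.38, 0.51, 0.72, 0.60],
    "best_step": 2,
}

baseline = result["step_plddts"][0]          # step 0 is the unadapted fold
improvement = result["plddt"] - baseline     # gain from test-time training
print(f"TTT improved pLDDT by {improvement:+.2f} (step {result['best_step']})")
# TTT improved pLDDT by +0.34 (step 2)
```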
 
+ ### Disabling TTT
+
+ To use FastESMFold as a standard ESMFold (no TTT), set `steps=0` or call `infer()` directly:
+
  ```python
+ from transformers import AutoConfig, AutoModel
+
+ # Option 1: Set TTT steps to 0
+ config = AutoConfig.from_pretrained("Synthyra/FastESMFold", trust_remote_code=True)
+ config.ttt_config = {"steps": 0}
+ model = AutoModel.from_pretrained("Synthyra/FastESMFold", config=config, trust_remote_code=True)
+ result = model.fold_protein("MKTLLILAVVAAALA...")  # No TTT, just baseline fold
+
+ # Option 2: Call infer() directly (inherited from EsmForProteinFolding)
+ with torch.no_grad():
+     output = model.infer("MKTLLILAVVAAALA...")
+ pdb_strings = model.output_to_pdb(output)
  ```
 
+ ## TTT Benchmark
+
+ Tested on 10 difficult sequences on A10G GPU:
+
+ | Metric | Value |
+ |--------|-------|
+ | Mean baseline pLDDT | 0.549 |
+ | Mean best TTT pLDDT | 0.637 |
+ | Mean improvement | +0.088 |
+ | Sequences improved >5pt | 5/10 |
+ | Time per sequence | ~20-45s |
+ | GPU memory peak | 18.3 GB |
+
+ On the hardest sequence (baseline pLDDT 0.38), TTT improves to 0.72 (+34 points).
+
  ## Attention backends
 
  The ESM2 backbone supports multiple attention backends via `config.attn_backend`:
 
  | Flex Attention | `"flex"` | Skips padding tokens via block mask. First use compiles a Triton kernel. |
  | Auto | `"auto"` | Picks best available: `kernels_flash` > `flex` > `sdpa`. |
 
+ ```python
+ from transformers import AutoConfig, AutoModel
+
+ config = AutoConfig.from_pretrained("Synthyra/FastESMFold", trust_remote_code=True)
+ config.attn_backend = "kernels_flash"
+ model = AutoModel.from_pretrained("Synthyra/FastESMFold", config=config, trust_remote_code=True)
+ ```
+
  ## TTT Configuration
 
+ TTT parameters are set via `config.ttt_config` (a dict) or by modifying `model._ttt_cfg` after loading:
 
  | Parameter | Default | Description |
  |-----------|---------|-------------|
  | `lr` | 4e-4 | Learning rate for SGD optimizer |
+ | `steps` | 10 | Number of optimizer steps (0 to disable TTT) |
  | `ags` | 4 | Gradient accumulation steps per optimizer step |
  | `batch_size` | 4 | Batch size for masked language model training |
  | `mask_ratio` | 0.15 | Fraction of tokens to mask |
+ | `lora_rank` | 8 | LoRA rank (0 for full backbone fine-tuning) |
+ | `lora_alpha` | 32.0 | LoRA scaling factor (applied as `scale=alpha`, matching lora_diffusion) |
+ | `seed` | 0 | Random seed for reproducible LoRA initialization and masking |
+ | `lora_target_class` | `"EsmSelfAttention"` | Which module class to inject LoRA into |
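To illustrate the scheme the table describes, here is a minimal, dependency-free LoRA sketch. It is not the repository's implementation, and it reads `Normal(0, 1/r)` as a standard deviation of `1/r`; the point it demonstrates is that with a zero-initialized up-projection the adapted layer equals the base layer at initialization, so step 0 reproduces the baseline fold.

```python
import random

class LoRALinear:
    """Toy LoRA layer: y = W x + alpha * (up @ (down @ x)).
    down ~ Normal(0, 1/rank), up starts at zero, scale = alpha
    (the lora_diffusion convention the README references)."""
    def __init__(self, weight, rank=8, alpha=32.0, seed=0):
        rng = random.Random(seed)
        d_out, d_in = len(weight), len(weight[0])
        self.weight = weight
        self.alpha = alpha
        self.down = [[rng.gauss(0.0, 1.0 / rank) for _ in range(d_in)]
                     for _ in range(rank)]              # rank x d_in, random
        self.up = [[0.0] * rank for _ in range(d_out)]  # d_out x rank, zeros

    def __call__(self, x):
        base = [sum(w * v for w, v in zip(row, x)) for row in self.weight]
        mid = [sum(w * v for w, v in zip(row, x)) for row in self.down]
        delta = [sum(w * v for w, v in zip(row, mid)) for row in self.up]
        return [b + self.alpha * d for b, d in zip(base, delta)]

layer = LoRALinear([[1.0, 2.0], [3.0, 4.0]], rank=2)
print(layer([1.0, 1.0]))  # [3.0, 7.0] -- identical to the base layer at init
```

Only `down` and `up` are trained during TTT, which is why resetting them (step 6 below) restores the original model exactly.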
 
  ## How TTT Works
 
  1. **Baseline fold** (step 0): Standard ESMFold prediction
+ 2. **LoRA injection**: Rank-8 LoRA adapters on all `nn.Linear` layers inside ESM2 attention modules
+ 3. **Masked LM training**: 10 optimizer steps (each with 4 gradient accumulation sub-steps) of BERT-style masked language modeling on the input sequence
  4. **Per-step folding**: After each optimizer step, fold the sequence and record pLDDT
  5. **Best selection**: Return the structure with the highest pLDDT
  6. **Reset**: Restore LoRA weights to initial state for the next sequence
176
+ ## Citations
177
 
178
+ If you use this implementation, please cite FastPLMs and the original ProteinTTT paper:
179
+
180
+ ```bibtex
181
+ @misc{FastPLMs,
182
+ author = {Hallee, Logan and Bichara, David and Gleghorn, Jason P.},
183
+ title = {FastPLMs: Fast, efficient, protein language model inference from Huggingface AutoModel.},
184
  year = {2024},
185
+ url = {https://huggingface.co/Synthyra/ESMplusplus_small},
186
+ DOI = {10.57967/hf/3726},
187
+ publisher = {Hugging Face}
188
+ }
189
+
190
+ @misc{bushuiev2026proteinneed,
191
+ title = {One protein is all you need},
192
+ author = {Anton Bushuiev and Roman Bushuiev and Olga Pimenova and Nikola Zadorozhny and Raman Samusevich and Elisabet Manaskova and Rachel Seongeun Kim and Hannes St\"ark and Jiri Sedlar and Martin Steinegger and Tom\'a\v{s} Pluskal and Josef Sivic},
193
+ year = {2026},
194
+ eprint = {2411.02109},
195
+ archivePrefix= {arXiv},
196
+ primaryClass = {cs.LG},
197
+ url = {https://arxiv.org/abs/2411.02109},
198
  }
199
  ```