esunAI committed (verified)
Commit 37158e8 · 1 Parent(s): 016900a

Add generate_amps.py

Files changed (1)
  1. src/generate_amps.py +389 -0
src/generate_amps.py ADDED
@@ -0,0 +1,389 @@
+ import torch
+ import torch.nn.functional as F
+ import numpy as np
+ from tqdm import tqdm
+ import os
+ from datetime import datetime
+ # Import torchdiffeq for proper ODE solving
+ try:
+     from torchdiffeq import odeint
+     TORCHDIFFEQ_AVAILABLE = True
+     print("✓ torchdiffeq available for proper ODE solving")
+ except ImportError:
+     TORCHDIFFEQ_AVAILABLE = False
+     print("⚠️ torchdiffeq not available, using manual Euler integration")
+ 
+ # Import project components
+ from compressor_with_embeddings import Compressor, Decompressor
+ from final_flow_model import AMPFlowMatcherCFGConcat, AMPProtFlowPipelineCFG
+ 
+ class AMPGenerator:
+     """
+     Generate AMP samples using a trained ProtFlow model.
+     """
+ 
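+     # Minimal usage sketch (hedged; the checkpoint path is this script's own
+     # default from main(), not verified here):
+     #   generator = AMPGenerator('/data2/edwardsun/flow_checkpoints/amp_flow_model_best_optimized.pth')
+     #   embeddings = generator.generate_amps(num_samples=32, cfg_scale=7.5)
+ 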
+     def __init__(self, model_path, device='cuda'):
+         self.device = device
+ 
+         # Load models
+         self._load_models(model_path)
+ 
+         # Load preprocessing statistics
+         self.stats = torch.load('normalization_stats.pt', map_location=device)
+ 
+     def _load_models(self, model_path):
+         """Load trained models."""
+         print("Loading trained models...")
+ 
+         # Load compressor and decompressor
+         self.compressor = Compressor().to(self.device)
+         self.decompressor = Decompressor().to(self.device)
+ 
+         self.compressor.load_state_dict(torch.load('/data2/edwardsun/flow_amp/models/final_compressor_model.pth', map_location=self.device))
+         self.decompressor.load_state_dict(torch.load('/data2/edwardsun/flow_amp/models/final_decompressor_model.pth', map_location=self.device))
+ 
+         # Load flow matching model with CFG
+         self.flow_model = AMPFlowMatcherCFGConcat(
+             hidden_dim=480,
+             compressed_dim=80,  # 1280 // 16
+             n_layers=12,
+             n_heads=16,
+             dim_ff=3072,
+             max_seq_len=25,
+             use_cfg=True
+         ).to(self.device)
+ 
+         checkpoint = torch.load(model_path, map_location=self.device, weights_only=False)
+ 
+         # Handle PyTorch compilation wrapper (torch.compile prefixes keys with '_orig_mod.')
+         state_dict = checkpoint['flow_model_state_dict']
+         new_state_dict = {}
+ 
+         for key, value in state_dict.items():
+             # Remove _orig_mod prefix if present
+             if key.startswith('_orig_mod.'):
+                 new_key = key[10:]  # Remove '_orig_mod.' prefix
+             else:
+                 new_key = key
+             new_state_dict[new_key] = value
+ 
+         self.flow_model.load_state_dict(new_state_dict)
+ 
+         print(f"✓ All models loaded successfully from step {checkpoint['step']}!")
+         print(f" Loss at checkpoint: {checkpoint['loss']:.6f}")
+ 
+         # Report ODE solving capabilities
+         if TORCHDIFFEQ_AVAILABLE:
+             print("✓ Enhanced with proper ODE solving (torchdiffeq)")
+         else:
+             print("⚠️ Using fallback Euler integration")
+ 
+     def _create_ode_func(self, cfg_scale=7.5):
+         """Create ODE function for torchdiffeq integration."""
+ 
+         def ode_func(t, x):
+             """
+             ODE function: dx/dt = v_theta(x, t)
+ 
+             Args:
+                 t: scalar time (0-dim tensor supplied by the solver)
+                 x: state tensor [B*L*D] (flattened)
+             Returns:
+                 dx/dt: derivative [B*L*D] (flattened)
+             """
+             # Reshape x back to [B, L, D]
+             batch_size, seq_len, dim = self.current_shape
+             x = x.view(batch_size, seq_len, dim)
+ 
+             # Create time tensor for batch (t arrives as a 0-dim tensor)
+             t_tensor = torch.full((batch_size,), float(t), device=self.device, dtype=x.dtype)
+ 
+             # Compute vector field with CFG
+             if cfg_scale > 0:
+                 # With AMP condition
+                 amp_labels = torch.full((batch_size,), 0, device=self.device)  # 0 = AMP
+                 vt_cond = self.flow_model(x, t_tensor, labels=amp_labels)
+ 
+                 # Without condition (mask)
+                 mask_labels = torch.full((batch_size,), 2, device=self.device)  # 2 = Mask
+                 vt_uncond = self.flow_model(x, t_tensor, labels=mask_labels)
+ 
+                 # CFG interpolation: v = v_uncond + scale * (v_cond - v_uncond)
+                 vt = vt_uncond + cfg_scale * (vt_cond - vt_uncond)
+             else:
+                 # No CFG, use mask label
+                 mask_labels = torch.full((batch_size,), 2, device=self.device)
+                 vt = self.flow_model(x, t_tensor, labels=mask_labels)
+ 
+             # Return flattened derivative
+             return vt.view(-1)
+ 
+         return ode_func
+ 
+     def generate_amps(self, num_samples=100, num_steps=25, batch_size=32, cfg_scale=7.5,
+                       ode_method='dopri5', rtol=1e-5, atol=1e-6):
+         """
+         Generate AMP samples using flow matching with CFG and improved ODE solving.
+ 
+         Args:
+             num_samples: Number of AMP samples to generate
+             num_steps: Number of integration steps (25 for good quality, 1 for reflow)
+             batch_size: Batch size for generation
+             cfg_scale: CFG guidance scale (higher = stronger conditioning)
+             ode_method: ODE solver method ('dopri5', 'rk4', 'euler', 'adaptive_heun')
+             rtol: Relative tolerance for adaptive solvers
+             atol: Absolute tolerance for adaptive solvers
+         """
+         method_str = f"{ode_method} ODE solver" if TORCHDIFFEQ_AVAILABLE and ode_method != 'euler' else "manual Euler integration"
+         print(f"Generating {num_samples} AMP samples with {method_str} (CFG scale: {cfg_scale})...")
+         if TORCHDIFFEQ_AVAILABLE and ode_method != 'euler':
+             print(f" Method: {ode_method}, rtol={rtol}, atol={atol}")
+ 
+         self.flow_model.eval()
+         self.compressor.eval()
+         self.decompressor.eval()
+ 
+         all_generated = []
+ 
+         with torch.no_grad():
+             for i in tqdm(range(0, num_samples, batch_size), desc="Generating with improved ODE"):
+                 current_batch = min(batch_size, num_samples - i)
+ 
+                 # Sample random noise (starting point at t=1)
+                 eps = torch.randn(current_batch, 25, 80, device=self.device)  # [B, L', COMP_DIM]
+ 
+                 # Choose ODE solving method
+                 if TORCHDIFFEQ_AVAILABLE and ode_method != 'euler':
+                     # Use proper ODE solver
+                     try:
+                         # Store shape for ODE function
+                         self.current_shape = eps.shape
+ 
+                         # Create ODE function
+                         ode_func = self._create_ode_func(cfg_scale=cfg_scale)
+ 
+                         # Time span: from t=1 (noise) to t=0 (data)
+                         t_span = torch.tensor([1.0, 0.0], device=self.device, dtype=eps.dtype)
+ 
+                         # Flatten initial condition for torchdiffeq
+                         y0 = eps.view(-1)
+ 
+                         # Solve ODE with the requested solver
+                         if ode_method in ['dopri5', 'adaptive_heun']:
+                             # Adaptive solvers
+                             solution = odeint(
+                                 ode_func, y0, t_span,
+                                 method=ode_method,
+                                 rtol=rtol,
+                                 atol=atol,
+                                 options={'max_num_steps': 1000}
+                             )
+                         else:
+                             # Fixed-step solvers: honor the requested num_steps
+                             solution = odeint(
+                                 ode_func, y0, t_span,
+                                 method=ode_method,
+                                 options={'step_size': 1.0 / num_steps}
+                             )
+ 
+                         # Get final solution (at t=0)
+                         xt = solution[-1].view(self.current_shape)
+ 
+                     except Exception as e:
+                         print(f"⚠️ ODE solving failed for batch {i//batch_size + 1}: {e}")
+                         print("Falling back to Euler method...")
+                         # Fall through to Euler method
+                         xt = self._generate_with_euler(eps, current_batch, cfg_scale, num_steps)
+                 else:
+                     # Use manual Euler integration (original method)
+                     xt = self._generate_with_euler(eps, current_batch, cfg_scale, num_steps)
+ 
+                 # Decompress to get embeddings
+                 decompressed = self.decompressor(xt)  # [B, L, ESM_DIM]
+ 
+                 # Apply reverse preprocessing
+                 m, s, mn, mx = self.stats['mean'], self.stats['std'], self.stats['min'], self.stats['max']
+                 decompressed = decompressed * (mx - mn + 1e-8) + mn
+                 decompressed = decompressed * s + m
+ 
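+                 # Assumption: forward preprocessing standardized (mean/std) first
+                 # and min-max scaled second, so the two inverses are applied in
+                 # reverse order above.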
+                 all_generated.append(decompressed.cpu())
+ 
+         # Concatenate all batches
+         generated_embeddings = torch.cat(all_generated, dim=0)
+ 
+         print(f"✓ Generated {generated_embeddings.shape[0]} AMP embeddings")
+         print(f" Shape: {generated_embeddings.shape}")
+         print(f" Stats - Mean: {generated_embeddings.mean():.4f}, Std: {generated_embeddings.std():.4f}")
+ 
+         return generated_embeddings
+ 
+     def _generate_with_euler(self, eps, current_batch, cfg_scale, num_steps):
+         """Fallback Euler integration method (original implementation)."""
+         xt = eps.clone()
+         amp_labels = torch.full((current_batch,), 0, device=self.device)  # 0 = AMP
+         mask_labels = torch.full((current_batch,), 2, device=self.device)  # 2 = Mask
+ 
+         for step in range(num_steps):
+             t = torch.ones(current_batch, device=self.device) * (1.0 - step / num_steps)
+ 
+             # CFG: generate with condition and without condition
+             if cfg_scale > 0:
+                 # With AMP condition
+                 vt_cond = self.flow_model(xt, t, labels=amp_labels)
+ 
+                 # Without condition (mask)
+                 vt_uncond = self.flow_model(xt, t, labels=mask_labels)
+ 
+                 # CFG interpolation
+                 vt = vt_uncond + cfg_scale * (vt_cond - vt_uncond)
+             else:
+                 # No CFG, use mask label
+                 vt = self.flow_model(xt, t, labels=mask_labels)
+ 
+             # Euler step for backward integration (t: 1 -> 0)
+             dt = -1.0 / num_steps
+             xt = xt + vt * dt
+ 
+         return xt
+ 
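+     # The Euler fallback above applies x_{k+1} = x_k + v_theta(x_k, t_k) * dt
+     # with dt = -1/num_steps, stepping t from 1 down to 0 in uniform steps.
+ 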
+     def compare_ode_methods(self, num_samples=20, cfg_scale=7.5):
+         """
+         Compare different ODE solving methods for quality assessment.
+         """
+         if not TORCHDIFFEQ_AVAILABLE:
+             print("⚠️ torchdiffeq not available, cannot compare ODE methods")
+             return self.generate_amps(num_samples=num_samples, cfg_scale=cfg_scale)
+ 
+         methods = ['euler', 'rk4', 'dopri5', 'adaptive_heun']
+         results = {}
+ 
+         print("🔬 Comparing ODE solving methods...")
+ 
+         for method in methods:
+             print(f"\n--- Testing {method} ---")
+             try:
+                 start_time = torch.cuda.Event(enable_timing=True)
+                 end_time = torch.cuda.Event(enable_timing=True)
+ 
+                 start_time.record()
+                 embeddings = self.generate_amps(
+                     num_samples=num_samples,
+                     batch_size=10,
+                     cfg_scale=cfg_scale,
+                     ode_method=method
+                 )
+                 end_time.record()
+ 
+                 torch.cuda.synchronize()
+                 elapsed_time = start_time.elapsed_time(end_time) / 1000.0  # Convert to seconds
+ 
+                 results[method] = {
+                     'embeddings': embeddings,
+                     'time': elapsed_time,
+                     'mean': embeddings.mean().item(),
+                     'std': embeddings.std().item(),
+                     'success': True
+                 }
+ 
+                 print(f"✓ {method}: {elapsed_time:.2f}s, mean={embeddings.mean():.4f}, std={embeddings.std():.4f}")
+ 
+             except Exception as e:
+                 print(f"❌ {method} failed: {e}")
+                 results[method] = {'success': False, 'error': str(e)}
+ 
+         return results
+ 
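+     # Note: the CUDA-event timing above assumes a GPU; on CPU, a pair of
+     # time.perf_counter() calls would be a drop-in replacement.
+ 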
+     def generate_with_reflow(self, num_samples=100):
+         """
+         Generate AMP samples using 1-step reflow (if you have a reflow model).
+         """
+         print(f"Generating {num_samples} AMP samples with 1-step reflow...")
+ 
+         # This would use the reflow implementation.
+         # For now, approximate with a single Euler step of the base model
+         # (the adaptive default would ignore num_steps).
+         return self.generate_amps(num_samples=num_samples, num_steps=1, batch_size=32, ode_method='euler')
+ 
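+     # Background: reflow (rectified flow) retrains on (noise, sample) pairs to
+     # straighten ODE trajectories, which is what makes true 1-step generation
+     # possible; the method above only approximates it with one Euler step.
+ 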
+ def main():
+     """Main generation function."""
+     print("=== AMP Generation Pipeline with CFG ===")
+ 
+     # Use the best model from training (lowest validation loss: 0.017183)
+     model_path = '/data2/edwardsun/flow_checkpoints/amp_flow_model_best_optimized.pth'
+ 
+     # Check if checkpoint exists
+     try:
+         checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)
+         print(f"✓ Found best model at step {checkpoint['step']} with loss {checkpoint['loss']:.6f}")
+         print(f" Global step: {checkpoint['global_step']}")
+         print(f" Total samples: {checkpoint['total_samples']:,}")
+     except Exception as e:
+         print(f"❌ Best model not found: {model_path} ({e})")
+         print("Please train the flow matching model first using amp_flow_training.py")
+         return
+ 
+     # Initialize generator
+     generator = AMPGenerator(model_path, device='cuda')
+ 
+     # Compare ODE methods if torchdiffeq is available
+     if TORCHDIFFEQ_AVAILABLE:
+         print("\n🔬 Comparing ODE solving methods...")
+         comparison_results = generator.compare_ode_methods(num_samples=10, cfg_scale=7.5)
+ 
+         # Use best method for generation
+         best_method = 'dopri5'  # Recommended method
+         print(f"\n🚀 Using {best_method} for main generation...")
+     else:
+         best_method = 'euler'
+         print("\n⚠️ Using fallback Euler integration...")
+ 
+     # Generate samples with different CFG scales using improved ODE solving
+     print("\n1. Generating with CFG scale 0.0 (no conditioning)...")
+     samples_no_cfg = generator.generate_amps(num_samples=20, num_steps=25, cfg_scale=0.0, ode_method=best_method)
+ 
+     print("\n2. Generating with CFG scale 3.0 (weak conditioning)...")
+     samples_weak_cfg = generator.generate_amps(num_samples=20, num_steps=25, cfg_scale=3.0, ode_method=best_method)
+ 
+     print("\n3. Generating with CFG scale 7.5 (strong conditioning)...")
+     samples_strong_cfg = generator.generate_amps(num_samples=20, num_steps=25, cfg_scale=7.5, ode_method=best_method)
+ 
+     print("\n4. Generating with CFG scale 15.0 (very strong conditioning)...")
+     samples_very_strong_cfg = generator.generate_amps(num_samples=20, num_steps=25, cfg_scale=15.0, ode_method=best_method)
+ 
+     # Create output directory if it doesn't exist
+     output_dir = '/data2/edwardsun/generated_samples'
+     os.makedirs(output_dir, exist_ok=True)
+ 
+     # Get today's date for filenames
+     today = datetime.now().strftime('%Y%m%d')
+ 
+     # Save generated samples with date
+     torch.save(samples_no_cfg, os.path.join(output_dir, f'generated_amps_best_model_no_cfg_{today}.pt'))
+     torch.save(samples_weak_cfg, os.path.join(output_dir, f'generated_amps_best_model_weak_cfg_{today}.pt'))
+     torch.save(samples_strong_cfg, os.path.join(output_dir, f'generated_amps_best_model_strong_cfg_{today}.pt'))
+     torch.save(samples_very_strong_cfg, os.path.join(output_dir, f'generated_amps_best_model_very_strong_cfg_{today}.pt'))
+ 
+     print("\n✓ Generation complete!")
+     print(f"Generated samples saved (Date: {today}):")
+     print(f" - generated_amps_best_model_no_cfg_{today}.pt (no conditioning)")
+     print(f" - generated_amps_best_model_weak_cfg_{today}.pt (weak CFG)")
+     print(f" - generated_amps_best_model_strong_cfg_{today}.pt (strong CFG)")
+     print(f" - generated_amps_best_model_very_strong_cfg_{today}.pt (very strong CFG)")
+ 
+     print("\nCFG Analysis:")
+     print(" - CFG scale 0.0: No conditioning, generates diverse sequences")
+     print(" - CFG scale 3.0: Weak AMP conditioning")
+     print(" - CFG scale 7.5: Strong AMP conditioning (recommended)")
+     print(" - CFG scale 15.0: Very strong AMP conditioning (may be too restrictive)")
+ 
+     print("\nNext steps:")
+     print("1. Decode embeddings back to sequences using ESM-2 decoder")
+     print("2. Evaluate with ProtFlow metrics (FPD, MMD, ESM-2 perplexity)")
+     print("3. Compare sequences generated with different CFG scales")
+     print("4. Evaluate AMP properties (antimicrobial activity, toxicity)")
+     if TORCHDIFFEQ_AVAILABLE:
+         print(f"5. ✓ Enhanced generation with {best_method} ODE solver")
+     else:
+         print("5. Install torchdiffeq for improved ODE solving: pip install torchdiffeq")
+ 
+ 
+ if __name__ == "__main__":
+     main()