chq1155 commited on
Commit
04bd26e
·
verified ·
1 Parent(s): d6e9b56

Add Colab inference demo notebook

Browse files
Files changed (1) hide show
  1. notebooks/TD3B_Inference_Demo.ipynb +455 -0
notebooks/TD3B_Inference_Demo.ipynb ADDED
@@ -0,0 +1,455 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "provenance": [],
7
+ "gpuType": "T4"
8
+ },
9
+ "kernelspec": {
10
+ "name": "python3",
11
+ "display_name": "Python 3"
12
+ },
13
+ "language_info": {
14
+ "name": "python"
15
+ },
16
+ "accelerator": "GPU"
17
+ },
18
+ "cells": [
19
+ {
20
+ "cell_type": "markdown",
21
+ "source": [
22
+ "# TD3B: Transition-Directed Discrete Diffusion for Allosteric Binder Generation\n",
23
+ "\n",
24
+ "This notebook demonstrates **TD3B inference** — generating peptide binders with specified agonist or antagonist behavior for GPCR targets.\n",
25
+ "\n",
26
+ "**What TD3B does:**\n",
27
+ "- Takes a target protein sequence + desired direction (agonist / antagonist)\n",
28
+ "- Generates peptide binder sequences using a finetuned discrete diffusion model\n",
29
+ "- Scores them with a Direction Oracle and binding affinity predictor\n",
30
+ "- Returns the best candidates via weighted resampling (Algorithm 2)\n",
31
+ "\n",
32
+ "**Requirements:** GPU runtime (T4 or better). Click **Runtime → Change runtime type → GPU**."
33
+ ],
34
+ "metadata": {}
35
+ },
36
+ {
37
+ "cell_type": "markdown",
38
+ "source": [
39
+ "## 1. Setup"
40
+ ],
41
+ "metadata": {}
42
+ },
43
+ {
44
+ "cell_type": "code",
45
+ "execution_count": null,
46
+ "metadata": {},
47
+ "outputs": [],
48
+ "source": [
49
+ "# Install dependencies\n",
50
+ "!pip install -q torch torchvision --index-url https://download.pytorch.org/whl/cu121\n",
51
+ "!pip install -q transformers fair-esm SmilesPE rdkit-pypi scipy pandas numpy xgboost pytorch-lightning lightning hydra-core loguru timm huggingface_hub"
52
+ ]
53
+ },
54
+ {
55
+ "cell_type": "code",
56
+ "execution_count": null,
57
+ "metadata": {},
58
+ "outputs": [],
59
+ "source": [
60
+ "# Clone TD3B repository and download checkpoints from HuggingFace\n",
61
+ "!git clone https://github.com/chq1155/TD3B_ICML.git TD3B\n",
62
+ "%cd TD3B\n",
63
+ "\n",
64
+ "from huggingface_hub import hf_hub_download\n",
65
+ "import os\n",
66
+ "\n",
67
+ "REPO_ID = \"ChatterjeeLab/TD3B\"\n",
68
+ "os.makedirs(\"checkpoints\", exist_ok=True)\n",
69
+ "os.makedirs(\"data\", exist_ok=True)\n",
70
+ "\n",
71
+ "# Download checkpoints (this may take a few minutes)\n",
72
+ "for fname in [\"checkpoints/td3b.ckpt\", \"checkpoints/pretrained.ckpt\",\n",
73
+ " \"checkpoints/direction_oracle.pt\",\n",
74
+ " \"scoring/functions/classifiers/binding-affinity.pt\",\n",
75
+ " \"data/test.csv\", \"data/train.csv\"]:\n",
76
+ " print(f\"Downloading {fname}...\")\n",
77
+ " hf_hub_download(repo_id=REPO_ID, filename=fname, local_dir=\".\")\n",
78
+ "\n",
79
+ "print(\"\\nAll files downloaded!\")\n",
80
+ "!ls -lh checkpoints/"
81
+ ]
82
+ },
83
+ {
84
+ "cell_type": "markdown",
85
+ "source": [
86
+ "## 2. Load Model and Oracle"
87
+ ],
88
+ "metadata": {}
89
+ },
90
+ {
91
+ "cell_type": "code",
92
+ "execution_count": null,
93
+ "metadata": {},
94
+ "outputs": [],
95
+ "source": [
96
+ "import sys\n",
97
+ "sys.path.insert(0, \".\")\n",
98
+ "\n",
99
+ "import torch\n",
100
+ "import numpy as np\n",
101
+ "import pandas as pd\n",
102
+ "\n",
103
+ "from diffusion import Diffusion\n",
104
+ "from configs.finetune_config import (\n",
105
+ " DiffusionConfig, RoFormerConfig, NoiseConfig,\n",
106
+ " TrainingConfig, SamplingConfig, EvalConfig, OptimConfig, MCTSConfig,\n",
107
+ ")\n",
108
+ "from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer\n",
109
+ "from td3b.direction_oracle import DirectionalOracle\n",
110
+ "from td3b.td3b_scoring import TD3BRewardFunction, create_td3b_reward_function\n",
111
+ "from scoring.functions.binding import BindingAffinity\n",
112
+ "from utils.app import PeptideAnalyzer\n",
113
+ "\n",
114
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
115
+ "print(f\"Using device: {device}\")\n",
116
+ "if torch.cuda.is_available():\n",
117
+ " print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n",
118
+ " print(f\"Memory: {torch.cuda.get_device_properties(0).total_mem / 1e9:.1f} GB\")"
119
+ ]
120
+ },
121
+ {
122
+ "cell_type": "code",
123
+ "execution_count": null,
124
+ "metadata": {},
125
+ "outputs": [],
126
+ "source": [
127
+ "# Load tokenizer\n",
128
+ "tokenizer = SMILES_SPE_Tokenizer(\"tokenizer/new_vocab.txt\", \"tokenizer/new_splits.txt\")\n",
129
+ "print(f\"Tokenizer vocab size: {len(tokenizer)}\")\n",
130
+ "\n",
131
+ "# Load diffusion model\n",
132
+ "print(\"\\nLoading TD3B model...\")\n",
133
+ "cfg = DiffusionConfig(\n",
134
+ " roformer=RoFormerConfig(hidden_size=768, n_layers=8, n_heads=8),\n",
135
+ " noise=NoiseConfig(),\n",
136
+ " training=TrainingConfig(sampling_eps=1e-3),\n",
137
+ " sampling=SamplingConfig(steps=128, sampling_eps=1e-3),\n",
138
+ " eval_cfg=EvalConfig(), optim=OptimConfig(lr=3e-4), mcts=MCTSConfig(),\n",
139
+ ")\n",
140
+ "model = Diffusion(config=cfg, tokenizer=tokenizer, device=device).to(device)\n",
141
+ "\n",
142
+ "ckpt = torch.load(\"checkpoints/td3b.ckpt\", map_location=device, weights_only=False)\n",
143
+ "state_dict = ckpt.get(\"model_state_dict\") or ckpt.get(\"state_dict\") or ckpt\n",
144
+ "model.load_state_dict(state_dict, strict=False)\n",
145
+ "model.eval()\n",
146
+ "model.tokenizer = tokenizer\n",
147
+ "print(\"TD3B model loaded!\")\n",
148
+ "\n",
149
+ "# Load Direction Oracle\n",
150
+ "print(\"\\nLoading Direction Oracle...\")\n",
151
+ "oracle = DirectionalOracle(\n",
152
+ " model_ckpt=\"checkpoints/direction_oracle.pt\",\n",
153
+ " tr2d2_checkpoint=\"checkpoints/pretrained.ckpt\",\n",
154
+ " tokenizer_vocab=\"tokenizer/new_vocab.txt\",\n",
155
+ " tokenizer_splits=\"tokenizer/new_splits.txt\",\n",
156
+ " device=device,\n",
157
+ ")\n",
158
+ "oracle.eval()\n",
159
+ "print(\"Direction Oracle loaded!\")\n",
160
+ "\n",
161
+ "# Load Affinity Predictor\n",
162
+ "print(\"\\nLoading Affinity Predictor...\")\n",
163
+ "analyzer = PeptideAnalyzer()\n",
164
+ "print(\"\\nAll models loaded!\")"
165
+ ]
166
+ },
167
+ {
168
+ "cell_type": "markdown",
169
+ "source": [
170
+ "## 3. Define Helper Functions"
171
+ ],
172
+ "metadata": {}
173
+ },
174
+ {
175
+ "cell_type": "code",
176
+ "execution_count": null,
177
+ "metadata": {},
178
+ "outputs": [],
179
+ "source": [
180
+ "def sample_sequences(model, batch_size, seq_length, num_steps=128, eps=1e-5):\n",
181
+ " \"\"\"Sample sequences from the diffusion model.\"\"\"\n",
182
+ " x = model.sample_prior(batch_size, seq_length).to(model.device, dtype=torch.long)\n",
183
+ " timesteps = torch.linspace(1, eps, num_steps + 1, device=model.device)\n",
184
+ " dt = torch.tensor((1 - eps) / num_steps, device=model.device)\n",
185
+ "\n",
186
+ " for i in range(num_steps):\n",
187
+ " t = timesteps[i] * torch.ones(x.shape[0], 1, device=model.device)\n",
188
+ " _, x = model.single_reverse_step(x, t=t, dt=dt)\n",
189
+ " x = x.to(model.device)\n",
190
+ "\n",
191
+ " mask_pos = (x == model.mask_index)\n",
192
+ " if mask_pos.any():\n",
193
+ " t = timesteps[-2] * torch.ones(x.shape[0], 1, device=model.device)\n",
194
+ " _, x = model.single_noise_removal(x, t=t, dt=dt)\n",
195
+ " return x\n",
196
+ "\n",
197
+ "\n",
198
+ "def generate_binders(target_seq, direction=\"agonist\", num_pool=32,\n",
199
+ " num_keep=8, alpha=0.1, seq_length=200):\n",
200
+ " \"\"\"\n",
201
+ " Generate directional binders for a target protein.\n",
202
+ " \n",
203
+ " Args:\n",
204
+ " target_seq: Target protein amino acid sequence\n",
205
+ " direction: 'agonist' or 'antagonist'\n",
206
+ " num_pool: Number of candidates to generate\n",
207
+ " num_keep: Number of final samples after resampling\n",
208
+ " alpha: Temperature for weighted resampling\n",
209
+ " seq_length: Binder sequence length (in SMILES tokens)\n",
210
+ " \n",
211
+ " Returns:\n",
212
+ " DataFrame with generated binders and scores\n",
213
+ " \"\"\"\n",
214
+ " d_star = 1.0 if direction == \"agonist\" else -1.0\n",
215
+ " \n",
216
+ " # Build reward function\n",
217
+ " affinity_pred = BindingAffinity(\n",
218
+ " prot_seq=target_seq, tokenizer=tokenizer,\n",
219
+ " base_path=\".\", device=device, emb_model=model.backbone\n",
220
+ " )\n",
221
+ " reward_fn = create_td3b_reward_function(\n",
222
+ " affinity_predictor=affinity_pred,\n",
223
+ " directional_oracle=oracle,\n",
224
+ " target_protein_seq=target_seq,\n",
225
+ " target_direction=direction,\n",
226
+ " peptide_tokenizer=tokenizer,\n",
227
+ " device=device,\n",
228
+ " )\n",
229
+ " \n",
230
+ " # Generate candidates\n",
231
+ " with torch.no_grad():\n",
232
+ " x_pool = sample_sequences(model, num_pool, seq_length)\n",
233
+ " sequences = tokenizer.batch_decode(x_pool)\n",
234
+ " \n",
235
+ " # Score all\n",
236
+ " rewards, info = reward_fn(sequences)\n",
237
+ " affinities = info[\"affinities\"]\n",
238
+ " directions = info[\"directions\"]\n",
239
+ " \n",
240
+ " # Weighted resampling (Algorithm 2)\n",
241
+ " rewards_t = torch.as_tensor(rewards, device=device)\n",
242
+ " weights = torch.softmax(rewards_t / max(alpha, 1e-6), dim=0)\n",
243
+ " idx = torch.multinomial(weights, num_samples=num_keep, replacement=True)\n",
244
+ " chosen = idx.cpu().numpy()\n",
245
+ " \n",
246
+ " # Filter to valid peptides only\n",
247
+ " results = []\n",
248
+ " for i in chosen:\n",
249
+ " is_valid = analyzer.is_peptide(sequences[i])\n",
250
+ " da = float(directions[i] > 0.5) if d_star > 0 else float(directions[i] < 0.5)\n",
251
+ " results.append({\n",
252
+ " \"sequence\": sequences[i],\n",
253
+ " \"direction\": direction,\n",
254
+ " \"is_valid\": is_valid,\n",
255
+ " \"affinity\": float(affinities[i]),\n",
256
+ " \"gated_reward\": float(rewards[i]),\n",
257
+ " \"p_agonist\": float(directions[i]),\n",
258
+ " \"direction_accuracy\": da,\n",
259
+ " })\n",
260
+ " \n",
261
+ " df = pd.DataFrame(results)\n",
262
+ " return df"
263
+ ]
264
+ },
265
+ {
266
+ "cell_type": "markdown",
267
+ "source": [
268
+ "## 4. Generate Binders\n",
269
+ "\n",
270
+ "Let's generate **agonist** and **antagonist** binders for a test target and compare the Direction Oracle predictions."
271
+ ],
272
+ "metadata": {}
273
+ },
274
+ {
275
+ "cell_type": "code",
276
+ "execution_count": null,
277
+ "metadata": {},
278
+ "outputs": [],
279
+ "source": [
280
+ "# Load test targets\n",
281
+ "test_df = pd.read_csv(\"data/test.csv\")\n",
282
+ "print(f\"Test set: {len(test_df)} target-binder pairs\")\n",
283
+ "\n",
284
+ "# Pick first target for demo\n",
285
+ "target_row = test_df.iloc[0]\n",
286
+ "TARGET_SEQ = target_row[\"Target_Sequence\"]\n",
287
+ "TARGET_UID = target_row[\"Target_UniProt_ID\"]\n",
288
+ "print(f\"\\nTarget: {TARGET_UID}\")\n",
289
+ "print(f\"Sequence length: {len(TARGET_SEQ)} aa\")\n",
290
+ "print(f\"Sequence: {TARGET_SEQ[:60]}...\")"
291
+ ]
292
+ },
293
+ {
294
+ "cell_type": "code",
295
+ "execution_count": null,
296
+ "metadata": {},
297
+ "outputs": [],
298
+ "source": [
299
+ "%%time\n",
300
+ "# Generate AGONIST binders\n",
301
+ "print(\"Generating agonist binders (d*=+1)...\")\n",
302
+ "torch.manual_seed(42)\n",
303
+ "np.random.seed(42)\n",
304
+ "df_agonist = generate_binders(TARGET_SEQ, direction=\"agonist\", num_pool=32, num_keep=8)\n",
305
+ "\n",
306
+ "print(f\"\\nGenerated {len(df_agonist)} samples ({df_agonist['is_valid'].sum()} valid)\")\n",
307
+ "print(f\"Mean p(agonist): {df_agonist['p_agonist'].mean():.3f}\")\n",
308
+ "print(f\"Mean affinity: {df_agonist['affinity'].mean():.2f}\")\n",
309
+ "print(f\"Mean gated reward: {df_agonist['gated_reward'].mean():.2f}\")"
310
+ ]
311
+ },
312
+ {
313
+ "cell_type": "code",
314
+ "execution_count": null,
315
+ "metadata": {},
316
+ "outputs": [],
317
+ "source": [
318
+ "%%time\n",
319
+ "# Generate ANTAGONIST binders\n",
320
+ "print(\"Generating antagonist binders (d*=-1)...\")\n",
321
+ "torch.manual_seed(42)\n",
322
+ "np.random.seed(42)\n",
323
+ "df_antagonist = generate_binders(TARGET_SEQ, direction=\"antagonist\", num_pool=32, num_keep=8)\n",
324
+ "\n",
325
+ "print(f\"\\nGenerated {len(df_antagonist)} samples ({df_antagonist['is_valid'].sum()} valid)\")\n",
326
+ "print(f\"Mean p(agonist): {df_antagonist['p_agonist'].mean():.3f}\")\n",
327
+ "print(f\"Mean affinity: {df_antagonist['affinity'].mean():.2f}\")\n",
328
+ "print(f\"Mean gated reward: {df_antagonist['gated_reward'].mean():.2f}\")"
329
+ ]
330
+ },
331
+ {
332
+ "cell_type": "markdown",
333
+ "source": [
334
+ "## 5. Compare Directional Control"
335
+ ],
336
+ "metadata": {}
337
+ },
338
+ {
339
+ "cell_type": "code",
340
+ "execution_count": null,
341
+ "metadata": {},
342
+ "outputs": [],
343
+ "source": [
344
+ "import matplotlib.pyplot as plt\n",
345
+ "\n",
346
+ "fig, axes = plt.subplots(1, 3, figsize=(15, 4))\n",
347
+ "\n",
348
+ "# Plot 1: Direction Oracle p(agonist)\n",
349
+ "axes[0].hist(df_agonist[\"p_agonist\"], bins=20, alpha=0.7, label=\"d*=+1 (agonist)\", color=\"#e74c3c\")\n",
350
+ "axes[0].hist(df_antagonist[\"p_agonist\"], bins=20, alpha=0.7, label=\"d*=-1 (antagonist)\", color=\"#3498db\")\n",
351
+ "axes[0].axvline(0.5, color=\"gray\", linestyle=\"--\", label=\"threshold\")\n",
352
+ "axes[0].set_xlabel(\"p(agonist)\")\n",
353
+ "axes[0].set_ylabel(\"Count\")\n",
354
+ "axes[0].set_title(\"Direction Oracle Predictions\")\n",
355
+ "axes[0].legend()\n",
356
+ "\n",
357
+ "# Plot 2: Binding Affinity\n",
358
+ "axes[1].hist(df_agonist[\"affinity\"], bins=20, alpha=0.7, label=\"Agonist\", color=\"#e74c3c\")\n",
359
+ "axes[1].hist(df_antagonist[\"affinity\"], bins=20, alpha=0.7, label=\"Antagonist\", color=\"#3498db\")\n",
360
+ "axes[1].set_xlabel(\"Predicted Binding Affinity\")\n",
361
+ "axes[1].set_ylabel(\"Count\")\n",
362
+ "axes[1].set_title(\"Binding Affinity Distribution\")\n",
363
+ "axes[1].legend()\n",
364
+ "\n",
365
+ "# Plot 3: Gated Reward\n",
366
+ "axes[2].hist(df_agonist[\"gated_reward\"], bins=20, alpha=0.7, label=\"Agonist\", color=\"#e74c3c\")\n",
367
+ "axes[2].hist(df_antagonist[\"gated_reward\"], bins=20, alpha=0.7, label=\"Antagonist\", color=\"#3498db\")\n",
368
+ "axes[2].set_xlabel(\"Gated Reward\")\n",
369
+ "axes[2].set_ylabel(\"Count\")\n",
370
+ "axes[2].set_title(\"Gated Reward Distribution\")\n",
371
+ "axes[2].legend()\n",
372
+ "\n",
373
+ "plt.tight_layout()\n",
374
+ "plt.savefig(\"td3b_results.png\", dpi=150, bbox_inches=\"tight\")\n",
375
+ "plt.show()\n",
376
+ "\n",
377
+ "print(\"\\nSummary:\")\n",
378
+ "print(f\" Agonist mode: p(agonist)={df_agonist['p_agonist'].mean():.3f} Affinity={df_agonist['affinity'].mean():.2f} Gated={df_agonist['gated_reward'].mean():.2f}\")\n",
379
+ "print(f\" Antagonist mode: p(agonist)={df_antagonist['p_agonist'].mean():.3f} Affinity={df_antagonist['affinity'].mean():.2f} Gated={df_antagonist['gated_reward'].mean():.2f}\")\n",
380
+ "print(f\" Directional gap: Δp = {df_agonist['p_agonist'].mean() - df_antagonist['p_agonist'].mean():.3f}\")"
381
+ ]
382
+ },
383
+ {
384
+ "cell_type": "markdown",
385
+ "source": [
386
+ "## 6. Run on Multiple Targets\n",
387
+ "\n",
388
+ "Generate binders for the first 5 test targets and compute aggregate metrics."
389
+ ],
390
+ "metadata": {}
391
+ },
392
+ {
393
+ "cell_type": "code",
394
+ "execution_count": null,
395
+ "metadata": {},
396
+ "outputs": [],
397
+ "source": [
398
+ "N_TARGETS = 5 # Number of targets to evaluate (increase for full benchmark)\n",
399
+ "\n",
400
+ "all_results = []\n",
401
+ "targets = test_df.drop_duplicates(\"Target_UniProt_ID\").head(N_TARGETS)\n",
402
+ "\n",
403
+ "for i, (_, row) in enumerate(targets.iterrows()):\n",
404
+ " uid = row[\"Target_UniProt_ID\"]\n",
405
+ " seq = row[\"Target_Sequence\"]\n",
406
+ " print(f\"[{i+1}/{N_TARGETS}] {uid} (len={len(seq)})\")\n",
407
+ " \n",
408
+ " for direction in [\"agonist\", \"antagonist\"]:\n",
409
+ " torch.manual_seed(42)\n",
410
+ " np.random.seed(42)\n",
411
+ " df = generate_binders(seq, direction=direction, num_pool=32, num_keep=8)\n",
412
+ " df[\"target_uid\"] = uid\n",
413
+ " all_results.append(df)\n",
414
+ " \n",
415
+ " d_star = 1.0 if direction == \"agonist\" else -1.0\n",
416
+ " da = df[\"direction_accuracy\"].mean()\n",
417
+ " print(f\" {direction:>10s}: DA={da:.2f} Aff={df['affinity'].mean():.2f} Gated={df['gated_reward'].mean():.2f} valid={df['is_valid'].sum()}/{len(df)}\")\n",
418
+ "\n",
419
+ "combined = pd.concat(all_results, ignore_index=True)\n",
420
+ "\n",
421
+ "print(f\"\\n{'='*60}\")\n",
422
+ "print(f\"AGGREGATE METRICS ({N_TARGETS} targets)\")\n",
423
+ "print(f\"{'='*60}\")\n",
424
+ "for d_name, d_val in [(\"Agonist (d*=+1)\", \"agonist\"), (\"Antagonist (d*=-1)\", \"antagonist\")]:\n",
425
+ " sub = combined[combined[\"direction\"] == d_val]\n",
426
+ " valid = sub[sub[\"is_valid\"] == True]\n",
427
+ " print(f\" {d_name}:\")\n",
428
+ " print(f\" Affinity: {sub['affinity'].mean():.2f}\")\n",
429
+ " print(f\" Direction Accuracy: {sub['direction_accuracy'].mean():.3f}\")\n",
430
+ " print(f\" Gated Reward (all): {sub['gated_reward'].mean():.2f}\")\n",
431
+ " if len(valid) > 0:\n",
432
+ " print(f\" Gated Reward (valid): {valid['gated_reward'].mean():.2f}\")\n",
433
+ " print(f\" Valid: {sub['is_valid'].sum()}/{len(sub)}\")\n",
434
+ "\n",
435
+ "# Save\n",
436
+ "combined.to_csv(\"td3b_demo_results.csv\", index=False)\n",
437
+ "print(f\"\\nResults saved to td3b_demo_results.csv\")"
438
+ ]
439
+ },
440
+ {
441
+ "cell_type": "markdown",
442
+ "source": [
443
+ "## Citation\n",
444
+ "\n",
445
+ "```bibtex\n",
446
+ "@article{caotd3b,\n",
447
+ " title={TD3B: Transition-Directed Discrete Diffusion for Allosteric Binder Generation},\n",
448
+ " author={Cao, Hanqun and Pal, Aastha and Tang, Sophia and Zhang, Yinuo and Zhang, Jingjie and Heng, Pheng-Ann and Chatterjee, Pranam}\n",
449
+ "}\n",
450
+ "```"
451
+ ],
452
+ "metadata": {}
453
+ }
454
+ ]
455
+ }