File size: 70,899 Bytes
e9e75df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
{
  "cells": [
    {
      "cell_type": "markdown",
      "id": "b01cd0d5",
      "metadata": {
        "id": "b01cd0d5"
      },
      "source": [
        "# Start using Chroma\n",
        "\n",
        "<div style=\"border:2px solid #f7e097; padding:10px; margin-top:20px; background-color:#fefcd5; border-radius: 5px;\">\n",
        "    🔑 <b>Note:</b> To generate proteins with Chroma, you'll need an API key from <a href=\"https://chroma-weights.generatebiomedicines.com\">chroma-weights.generatebiomedicines.com</a>. Execute the cell below and enter your key after accepting the license.\n",
        "</div>\n",
        "\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "id": "c6db90e2",
      "metadata": {
        "id": "c6db90e2"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Note: you may need to restart the kernel to use updated packages.\n"
          ]
        }
      ],
      "source": [
        "import locale\n",
        "locale.getpreferredencoding = lambda: \"UTF-8\"\n",
        "%pip install generate-chroma > /dev/null 2>&1"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "id": "f15f2198",
      "metadata": {
        "id": "f15f2198"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "f8e1ed8ac9014799ad87d5e27c84b5c3",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": []
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "import torch\n",
        "from chroma import Chroma, Protein, conditioners, api\n",
        "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
        "api.register_key(input(\"Enter API key: \"))"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "f46c2848",
      "metadata": {
        "id": "f46c2848"
      },
      "source": [
        "To generate protein samples with Chroma, initialize the model and call the sample method. The sample method generates a protein backbone, designs a sequence, and returns a `Protein` object."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "id": "9728242e",
      "metadata": {
        "id": "9728242e"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Using cached data from /tmp/chroma_weights/90e339502ae6b372797414167ce5a632/weights.pt\n",
            "Loaded from cache\n",
            "Using cached data from /tmp/chroma_weights/03a3a9af343ae74998768a2711c8b7ce/weights.pt\n",
            "Loaded from cache\n"
          ]
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "d52be95afc3b4846a21ef28ccae8729b",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Integrating SDE:   0%|          | 0/500 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "ename": "RuntimeError",
          "evalue": "CUDA out of memory. Tried to allocate 30.00 MiB (GPU 0; 23.69 GiB total capacity; 127.36 MiB already allocated; 13.19 MiB free; 140.00 MiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF",
          "output_type": "error",
          "traceback": [
            "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
            "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
            "Cell \u001b[0;32mIn[3], line 5\u001b[0m\n\u001b[1;32m      2\u001b[0m chroma \u001b[38;5;241m=\u001b[39m Chroma()\n\u001b[1;32m      4\u001b[0m \u001b[38;5;66;03m# Sample a Protein\u001b[39;00m\n\u001b[0;32m----> 5\u001b[0m protein \u001b[38;5;241m=\u001b[39m \u001b[43mchroma\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msample\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
            "File \u001b[0;32m~/anaconda3/envs/mlfold/lib/python3.8/site-packages/chroma/models/chroma.py:229\u001b[0m, in \u001b[0;36mChroma.sample\u001b[0;34m(self, samples, steps, chain_lengths, tspan, protein_init, conditioner, langevin_factor, langevin_isothermal, inverse_temperature, initialize_noise, integrate_func, sde_func, trajectory_length, full_output, design_ban_S, design_method, design_selection, design_t, temperature_S, temperature_chi, top_p_S, regularization, potts_mcmc_depth, potts_proposal, potts_symmetry_order, verbose)\u001b[0m\n\u001b[1;32m    226\u001b[0m design_kwargs \u001b[38;5;241m=\u001b[39m {k: input_args[k] \u001b[38;5;28;01mfor\u001b[39;00m k \u001b[38;5;129;01min\u001b[39;00m input_args \u001b[38;5;28;01mif\u001b[39;00m k \u001b[38;5;129;01min\u001b[39;00m design_keys}\n\u001b[1;32m    228\u001b[0m \u001b[38;5;66;03m# Perform Sampling\u001b[39;00m\n\u001b[0;32m--> 229\u001b[0m sample_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_sample\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mbackbone_kwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    231\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m full_output:\n\u001b[1;32m    232\u001b[0m     protein_sample, output_dictionary \u001b[38;5;241m=\u001b[39m sample_output\n",
            "File \u001b[0;32m~/anaconda3/envs/mlfold/lib/python3.8/site-packages/chroma/models/chroma.py:355\u001b[0m, in \u001b[0;36mChroma._sample\u001b[0;34m(self, samples, steps, chain_lengths, tspan, protein_init, conditioner, langevin_factor, langevin_isothermal, inverse_temperature, initialize_noise, integrate_func, sde_func, trajectory_length, full_output, **kwargs)\u001b[0m\n\u001b[1;32m    352\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m    353\u001b[0m     X_unc, C_unc, S_unc \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_init_backbones(samples, chain_lengths)\n\u001b[0;32m--> 355\u001b[0m outs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackbone_network\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msample_sde\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    356\u001b[0m \u001b[43m    \u001b[49m\u001b[43mC_unc\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    357\u001b[0m \u001b[43m    \u001b[49m\u001b[43mX_init\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mX_unc\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    358\u001b[0m \u001b[43m    \u001b[49m\u001b[43mconditioner\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mconditioner\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    359\u001b[0m \u001b[43m    \u001b[49m\u001b[43mtspan\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtspan\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    360\u001b[0m \u001b[43m    \u001b[49m\u001b[43mlangevin_isothermal\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlangevin_isothermal\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    361\u001b[0m \u001b[43m    \u001b[49m\u001b[43mintegrate_func\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mintegrate_func\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    362\u001b[0m \u001b[43m    
\u001b[49m\u001b[43msde_func\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msde_func\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    363\u001b[0m \u001b[43m    \u001b[49m\u001b[43mlangevin_factor\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlangevin_factor\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    364\u001b[0m \u001b[43m    \u001b[49m\u001b[43minverse_temperature\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minverse_temperature\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    365\u001b[0m \u001b[43m    \u001b[49m\u001b[43mN\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msteps\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    366\u001b[0m \u001b[43m    \u001b[49m\u001b[43minitialize_noise\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minitialize_noise\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    367\u001b[0m \u001b[43m    \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    368\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    370\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m S_unc\u001b[38;5;241m.\u001b[39mshape \u001b[38;5;241m!=\u001b[39m outs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mC\u001b[39m\u001b[38;5;124m\"\u001b[39m]\u001b[38;5;241m.\u001b[39mshape:\n\u001b[1;32m    371\u001b[0m     S \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mzeros_like(outs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mC\u001b[39m\u001b[38;5;124m\"\u001b[39m])\u001b[38;5;241m.\u001b[39mlong()\n",
            "File \u001b[0;32m~/anaconda3/envs/mlfold/lib/python3.8/site-packages/chroma/models/graph_backbone.py:187\u001b[0m, in \u001b[0;36mGraphBackbone.__init__.<locals>.<lambda>\u001b[0;34m(C, **kwargs)\u001b[0m\n\u001b[1;32m    185\u001b[0m \u001b[38;5;66;03m# Wrap sampling functions\u001b[39;00m\n\u001b[1;32m    186\u001b[0m _X0_func \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mlambda\u001b[39;00m X, C, t: \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdenoise(X, C, t)\n\u001b[0;32m--> 187\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msample_sde \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mlambda\u001b[39;00m C, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnoise_perturb\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msample_sde\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    188\u001b[0m \u001b[43m    \u001b[49m\u001b[43m_X0_func\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mC\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\n\u001b[1;32m    189\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    190\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msample_baoab \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mlambda\u001b[39;00m C, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnoise_perturb\u001b[38;5;241m.\u001b[39msample_baoab(\n\u001b[1;32m    191\u001b[0m     _X0_func, C, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs\n\u001b[1;32m    192\u001b[0m )\n\u001b[1;32m    193\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msample_ode \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mlambda\u001b[39;00m C, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: 
\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnoise_perturb\u001b[38;5;241m.\u001b[39msample_ode(\n\u001b[1;32m    194\u001b[0m     _X0_func, C, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs\n\u001b[1;32m    195\u001b[0m )\n",
            "File \u001b[0;32m~/anaconda3/envs/mlfold/lib/python3.8/site-packages/torch/autograd/grad_mode.py:28\u001b[0m, in \u001b[0;36m_DecoratorContextManager.__call__.<locals>.decorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m     25\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[1;32m     26\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecorate_context\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m     27\u001b[0m     \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m():\n\u001b[0;32m---> 28\u001b[0m         \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
            "File \u001b[0;32m~/anaconda3/envs/mlfold/lib/python3.8/site-packages/chroma/layers/structure/diffusion.py:1213\u001b[0m, in \u001b[0;36mDiffusionChainCov.sample_sde\u001b[0;34m(self, X0_func, C, X_init, conditioner, N, tspan, inverse_temperature, langevin_factor, langevin_isothermal, sde_func, integrate_func, initialize_noise, remap_time, remove_drift_translate, remove_noise_translate, align_X0)\u001b[0m\n\u001b[1;32m   1210\u001b[0m     Ct \u001b[38;5;241m=\u001b[39m C\n\u001b[1;32m   1212\u001b[0m \u001b[38;5;66;03m# Integrate\u001b[39;00m\n\u001b[0;32m-> 1213\u001b[0m X_trajectory \u001b[38;5;241m=\u001b[39m \u001b[43mintegrate_func\u001b[49m\u001b[43m(\u001b[49m\u001b[43msdefun\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mX_init\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtspan\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mN\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mN\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mT_grid\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mT_grid\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1215\u001b[0m \u001b[38;5;66;03m# Return constrained coordinates\u001b[39;00m\n\u001b[1;32m   1216\u001b[0m outputs \u001b[38;5;241m=\u001b[39m {\n\u001b[1;32m   1217\u001b[0m     \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mC\u001b[39m\u001b[38;5;124m\"\u001b[39m: Ct,\n\u001b[1;32m   1218\u001b[0m     \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mX_sample\u001b[39m\u001b[38;5;124m\"\u001b[39m: Xt_trajectory[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m],\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m   1221\u001b[0m     \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mXunc_trajectory\u001b[39m\u001b[38;5;124m\"\u001b[39m: X_trajectory,\n\u001b[1;32m   1222\u001b[0m }\n",
            "File \u001b[0;32m~/anaconda3/envs/mlfold/lib/python3.8/site-packages/chroma/layers/sde.py:64\u001b[0m, in \u001b[0;36msde_integrate\u001b[0;34m(sde_func, y0, tspan, N, project_func, T_grid)\u001b[0m\n\u001b[1;32m     61\u001b[0m t \u001b[38;5;241m=\u001b[39m t0\n\u001b[1;32m     62\u001b[0m dT \u001b[38;5;241m=\u001b[39m t1 \u001b[38;5;241m-\u001b[39m t0\n\u001b[0;32m---> 64\u001b[0m f, gZ \u001b[38;5;241m=\u001b[39m \u001b[43msde_func\u001b[49m\u001b[43m(\u001b[49m\u001b[43mt\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     65\u001b[0m y \u001b[38;5;241m=\u001b[39m y \u001b[38;5;241m+\u001b[39m dT \u001b[38;5;241m*\u001b[39m f \u001b[38;5;241m+\u001b[39m dT\u001b[38;5;241m.\u001b[39mabs()\u001b[38;5;241m.\u001b[39msqrt() \u001b[38;5;241m*\u001b[39m gZ\n\u001b[1;32m     66\u001b[0m y \u001b[38;5;241m=\u001b[39m y \u001b[38;5;28;01mif\u001b[39;00m project_func \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m project_func(t, y)\n",
            "File \u001b[0;32m~/anaconda3/envs/mlfold/lib/python3.8/site-packages/chroma/layers/structure/diffusion.py:1174\u001b[0m, in \u001b[0;36mDiffusionChainCov.sample_sde.<locals>.sdefun\u001b[0;34m(_t, _X)\u001b[0m\n\u001b[1;32m   1173\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21msdefun\u001b[39m(_t, _X):\n\u001b[0;32m-> 1174\u001b[0m     f, gZ \u001b[38;5;241m=\u001b[39m \u001b[43msde_func\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m   1175\u001b[0m \u001b[43m        \u001b[49m\u001b[43m_X\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1176\u001b[0m \u001b[43m        \u001b[49m\u001b[43m_X0_func\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1177\u001b[0m \u001b[43m        \u001b[49m\u001b[43mC\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1178\u001b[0m \u001b[43m        \u001b[49m\u001b[43m_t\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1179\u001b[0m \u001b[43m        \u001b[49m\u001b[43mconditioner\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mconditioner\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1180\u001b[0m \u001b[43m        \u001b[49m\u001b[43minverse_temperature\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minverse_temperature\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1181\u001b[0m \u001b[43m        \u001b[49m\u001b[43mlangevin_factor\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlangevin_factor\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1182\u001b[0m \u001b[43m        \u001b[49m\u001b[43mlangevin_isothermal\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlangevin_isothermal\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1183\u001b[0m \u001b[43m        \u001b[49m\u001b[43malign_X0\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43malign_X0\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1184\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1185\u001b[0m     \u001b[38;5;66;03m# Remove net translational component\u001b[39;00m\n\u001b[1;32m   1186\u001b[0m     
\u001b[38;5;28;01mif\u001b[39;00m remove_drift_translate:\n",
            "File \u001b[0;32m~/anaconda3/envs/mlfold/lib/python3.8/site-packages/chroma/data/xcs.py:114\u001b[0m, in \u001b[0;36mvalidate_XCS.<locals>.decorator.<locals>.new_func\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    112\u001b[0m         \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mallclose(tensors[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mO\u001b[39m\u001b[38;5;124m\"\u001b[39m]\u001b[38;5;241m.\u001b[39margmax(dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m2\u001b[39m), tensors[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mS\u001b[39m\u001b[38;5;124m\"\u001b[39m]):\n\u001b[1;32m    113\u001b[0m             \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mS and O are both provided but don\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mt match!\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m--> 114\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
            "File \u001b[0;32m~/anaconda3/envs/mlfold/lib/python3.8/site-packages/chroma/layers/structure/diffusion.py:785\u001b[0m, in \u001b[0;36mDiffusionChainCov.reverse_sde\u001b[0;34m(self, X, X0_func, C, t, conditioner, Z, inverse_temperature, langevin_factor, langevin_isothermal, align_X0)\u001b[0m\n\u001b[1;32m    782\u001b[0m Z \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mrandn_like(X) \u001b[38;5;28;01mif\u001b[39;00m Z \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m Z\n\u001b[1;32m    784\u001b[0m \u001b[38;5;66;03m# X = backbone.center_X(X, C)\u001b[39;00m\n\u001b[0;32m--> 785\u001b[0m score \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mscore\u001b[49m\u001b[43m(\u001b[49m\u001b[43mX\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mX0_func\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mC\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mt\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mconditioner\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43malign_X0\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43malign_X0\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    786\u001b[0m score_transformed \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbase_gaussian\u001b[38;5;241m.\u001b[39mmultiply_covariance(score, C)\n\u001b[1;32m    788\u001b[0m f \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m    789\u001b[0m     beta \u001b[38;5;241m*\u001b[39m (\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m \u001b[38;5;241m/\u001b[39m \u001b[38;5;241m2\u001b[39m) \u001b[38;5;241m*\u001b[39m backbone\u001b[38;5;241m.\u001b[39mcenter_X(X, C)\n\u001b[1;32m    790\u001b[0m     \u001b[38;5;241m-\u001b[39m g\u001b[38;5;241m.\u001b[39mpow(\u001b[38;5;241m2\u001b[39m) \u001b[38;5;241m*\u001b[39m score_scale_t \u001b[38;5;241m*\u001b[39m 
score_transformed\n\u001b[1;32m    791\u001b[0m )\n",
            "File \u001b[0;32m~/anaconda3/envs/mlfold/lib/python3.8/site-packages/chroma/data/xcs.py:114\u001b[0m, in \u001b[0;36mvalidate_XCS.<locals>.decorator.<locals>.new_func\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    112\u001b[0m         \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mallclose(tensors[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mO\u001b[39m\u001b[38;5;124m\"\u001b[39m]\u001b[38;5;241m.\u001b[39margmax(dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m2\u001b[39m), tensors[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mS\u001b[39m\u001b[38;5;124m\"\u001b[39m]):\n\u001b[1;32m    113\u001b[0m             \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mS and O are both provided but don\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mt match!\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m--> 114\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
            "File \u001b[0;32m~/anaconda3/envs/mlfold/lib/python3.8/site-packages/chroma/layers/structure/diffusion.py:952\u001b[0m, in \u001b[0;36mDiffusionChainCov.score\u001b[0;34m(self, X, X0_func, C, t, conditioner, detach_X0, align_X0, U_traj)\u001b[0m\n\u001b[1;32m    949\u001b[0m U_conditioner \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mas_tensor(U_conditioner)\n\u001b[1;32m    951\u001b[0m \u001b[38;5;66;03m# Compute system energy\u001b[39;00m\n\u001b[0;32m--> 952\u001b[0m U_diffusion \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43menergy\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    953\u001b[0m \u001b[43m    \u001b[49m\u001b[43mXt\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mX0_func\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mCt\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mt\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdetach_X0\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdetach_X0\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43malign_X0\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43malign_X0\u001b[49m\n\u001b[1;32m    954\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    956\u001b[0m U_traj\u001b[38;5;241m.\u001b[39mappend(U_diffusion\u001b[38;5;241m.\u001b[39mdetach()\u001b[38;5;241m.\u001b[39mcpu())\n\u001b[1;32m    958\u001b[0m \u001b[38;5;66;03m# Compute score function as negative energy gradient\u001b[39;00m\n",
            "File \u001b[0;32m~/anaconda3/envs/mlfold/lib/python3.8/site-packages/chroma/data/xcs.py:114\u001b[0m, in \u001b[0;36mvalidate_XCS.<locals>.decorator.<locals>.new_func\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    112\u001b[0m         \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mallclose(tensors[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mO\u001b[39m\u001b[38;5;124m\"\u001b[39m]\u001b[38;5;241m.\u001b[39margmax(dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m2\u001b[39m), tensors[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mS\u001b[39m\u001b[38;5;124m\"\u001b[39m]):\n\u001b[1;32m    113\u001b[0m             \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mS and O are both provided but don\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mt match!\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m--> 114\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
            "File \u001b[0;32m~/anaconda3/envs/mlfold/lib/python3.8/site-packages/chroma/layers/structure/diffusion.py:892\u001b[0m, in \u001b[0;36mDiffusionChainCov.energy\u001b[0;34m(self, X, X0_func, C, t, detach_X0, align_X0)\u001b[0m\n\u001b[1;32m    890\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m detach_X0:\n\u001b[1;32m    891\u001b[0m     \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[0;32m--> 892\u001b[0m         X0 \u001b[38;5;241m=\u001b[39m \u001b[43mX0_func\u001b[49m\u001b[43m(\u001b[49m\u001b[43mX\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mC\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mt\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mt\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    893\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m    894\u001b[0m     X0 \u001b[38;5;241m=\u001b[39m X0_func(X, C, t\u001b[38;5;241m=\u001b[39mt)\n",
            "File \u001b[0;32m~/anaconda3/envs/mlfold/lib/python3.8/site-packages/chroma/layers/structure/diffusion.py:1168\u001b[0m, in \u001b[0;36mDiffusionChainCov.sample_sde.<locals>._X0_func\u001b[0;34m(_X, _C, t)\u001b[0m\n\u001b[1;32m   1167\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_X0_func\u001b[39m(_X, _C, t):\n\u001b[0;32m-> 1168\u001b[0m     _X0 \u001b[38;5;241m=\u001b[39m \u001b[43mX0_func\u001b[49m\u001b[43m(\u001b[49m\u001b[43m_X\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m_C\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mt\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1169\u001b[0m     Xt_trajectory\u001b[38;5;241m.\u001b[39mappend(_X\u001b[38;5;241m.\u001b[39mdetach())\n\u001b[1;32m   1170\u001b[0m     Xhat_trajectory\u001b[38;5;241m.\u001b[39mappend(_X0\u001b[38;5;241m.\u001b[39mdetach())\n",
            "File \u001b[0;32m~/anaconda3/envs/mlfold/lib/python3.8/site-packages/chroma/models/graph_backbone.py:186\u001b[0m, in \u001b[0;36mGraphBackbone.__init__.<locals>.<lambda>\u001b[0;34m(X, C, t)\u001b[0m\n\u001b[1;32m    181\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmlp_W \u001b[38;5;241m=\u001b[39m graph\u001b[38;5;241m.\u001b[39mMLP(\n\u001b[1;32m    182\u001b[0m         dim_in\u001b[38;5;241m=\u001b[39margs\u001b[38;5;241m.\u001b[39mdim_nodes, num_layers_hidden\u001b[38;5;241m=\u001b[39margs\u001b[38;5;241m.\u001b[39mnode_mlp_layers, dim_out\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m\n\u001b[1;32m    183\u001b[0m     )\n\u001b[1;32m    185\u001b[0m \u001b[38;5;66;03m# Wrap sampling functions\u001b[39;00m\n\u001b[0;32m--> 186\u001b[0m _X0_func \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mlambda\u001b[39;00m X, C, t: \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdenoise\u001b[49m\u001b[43m(\u001b[49m\u001b[43mX\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mC\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mt\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    187\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msample_sde \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mlambda\u001b[39;00m C, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnoise_perturb\u001b[38;5;241m.\u001b[39msample_sde(\n\u001b[1;32m    188\u001b[0m     _X0_func, C, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs\n\u001b[1;32m    189\u001b[0m )\n\u001b[1;32m    190\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msample_baoab \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mlambda\u001b[39;00m C, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnoise_perturb\u001b[38;5;241m.\u001b[39msample_baoab(\n\u001b[1;32m    
191\u001b[0m     _X0_func, C, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs\n\u001b[1;32m    192\u001b[0m )\n",
            "File \u001b[0;32m~/anaconda3/envs/mlfold/lib/python3.8/site-packages/chroma/data/xcs.py:114\u001b[0m, in \u001b[0;36mvalidate_XCS.<locals>.decorator.<locals>.new_func\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    112\u001b[0m         \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mallclose(tensors[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mO\u001b[39m\u001b[38;5;124m\"\u001b[39m]\u001b[38;5;241m.\u001b[39margmax(dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m2\u001b[39m), tensors[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mS\u001b[39m\u001b[38;5;124m\"\u001b[39m]):\n\u001b[1;32m    113\u001b[0m             \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mS and O are both provided but don\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mt match!\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m--> 114\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
            "File \u001b[0;32m~/anaconda3/envs/mlfold/lib/python3.8/site-packages/chroma/models/graph_backbone.py:239\u001b[0m, in \u001b[0;36mGraphBackbone.denoise\u001b[0;34m(self, X, C, t, return_geometry)\u001b[0m\n\u001b[1;32m    235\u001b[0m X_update \u001b[38;5;241m=\u001b[39m X\n\u001b[1;32m    237\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_graph_cycles):\n\u001b[1;32m    238\u001b[0m     \u001b[38;5;66;03m# Encode as graph\u001b[39;00m\n\u001b[0;32m--> 239\u001b[0m     node_h, edge_h, edge_idx, mask_i, mask_ij \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mencoders\u001b[49m\u001b[43m[\u001b[49m\u001b[43mi\u001b[49m\u001b[43m]\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    240\u001b[0m \u001b[43m        \u001b[49m\u001b[43mX_update\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    241\u001b[0m \u001b[43m        \u001b[49m\u001b[43mC\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    242\u001b[0m \u001b[43m        \u001b[49m\u001b[43mnode_h_aux\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mnode_h\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    243\u001b[0m \u001b[43m        \u001b[49m\u001b[43medge_h_aux\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43medge_h\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    244\u001b[0m \u001b[43m        \u001b[49m\u001b[43medge_idx\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43medge_idx\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    245\u001b[0m \u001b[43m        \u001b[49m\u001b[43mmask_ij\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmask_ij\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    246\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    247\u001b[0m     \u001b[38;5;66;03m# Update backbone\u001b[39;00m\n\u001b[1;32m    248\u001b[0m     X_update, R_ji, t_ji, logit_ji \u001b[38;5;241m=\u001b[39m 
\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbackbone_updates[i](\n\u001b[1;32m    249\u001b[0m         X_update, C, node_h, edge_h, edge_idx, mask_i, mask_ij\n\u001b[1;32m    250\u001b[0m     )\n",
            "File \u001b[0;32m~/anaconda3/envs/mlfold/lib/python3.8/site-packages/torch/nn/modules/module.py:1102\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1098\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1099\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1100\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1101\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1102\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1103\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m   1104\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
            "File \u001b[0;32m~/anaconda3/envs/mlfold/lib/python3.8/site-packages/chroma/data/xcs.py:114\u001b[0m, in \u001b[0;36mvalidate_XCS.<locals>.decorator.<locals>.new_func\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    112\u001b[0m         \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mallclose(tensors[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mO\u001b[39m\u001b[38;5;124m\"\u001b[39m]\u001b[38;5;241m.\u001b[39margmax(dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m2\u001b[39m), tensors[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mS\u001b[39m\u001b[38;5;124m\"\u001b[39m]):\n\u001b[1;32m    113\u001b[0m             \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mS and O are both provided but don\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mt match!\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m--> 114\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
            "File \u001b[0;32m~/anaconda3/envs/mlfold/lib/python3.8/site-packages/chroma/models/graph_design.py:1237\u001b[0m, in \u001b[0;36mBackboneEncoderGNN.forward\u001b[0;34m(self, X, C, node_h_aux, edge_h_aux, edge_idx, mask_ij)\u001b[0m\n\u001b[1;32m   1234\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcheckpoint_gradients \u001b[38;5;129;01mand\u001b[39;00m (\u001b[38;5;129;01mnot\u001b[39;00m X\u001b[38;5;241m.\u001b[39mrequires_grad):\n\u001b[1;32m   1235\u001b[0m     X\u001b[38;5;241m.\u001b[39mrequires_grad \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[0;32m-> 1237\u001b[0m node_h, edge_h, edge_idx, mask_i, mask_ij \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_checkpoint\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m   1238\u001b[0m \u001b[43m    \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfeature_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mX\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mC\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43medge_idx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmask_ij\u001b[49m\n\u001b[1;32m   1239\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1241\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m node_h_aux \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m   1242\u001b[0m     node_h \u001b[38;5;241m=\u001b[39m node_h \u001b[38;5;241m+\u001b[39m mask_i\u001b[38;5;241m.\u001b[39munsqueeze(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m) \u001b[38;5;241m*\u001b[39m node_h_aux\n",
            "File \u001b[0;32m~/anaconda3/envs/mlfold/lib/python3.8/site-packages/chroma/models/graph_design.py:1251\u001b[0m, in \u001b[0;36mBackboneEncoderGNN._checkpoint\u001b[0;34m(self, module, *args)\u001b[0m\n\u001b[1;32m   1249\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_checkpoint\u001b[39m(\u001b[38;5;28mself\u001b[39m, module: nn\u001b[38;5;241m.\u001b[39mModule, \u001b[38;5;241m*\u001b[39margs) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m nn\u001b[38;5;241m.\u001b[39mModule:\n\u001b[1;32m   1250\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcheckpoint_gradients:\n\u001b[0;32m-> 1251\u001b[0m         \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mcheckpoint\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodule\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1252\u001b[0m     \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m   1253\u001b[0m         \u001b[38;5;28;01mreturn\u001b[39;00m module(\u001b[38;5;241m*\u001b[39margs)\n",
            "File \u001b[0;32m~/anaconda3/envs/mlfold/lib/python3.8/site-packages/torch/utils/checkpoint.py:211\u001b[0m, in \u001b[0;36mcheckpoint\u001b[0;34m(function, *args, **kwargs)\u001b[0m\n\u001b[1;32m    208\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m kwargs:\n\u001b[1;32m    209\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mUnexpected keyword arguments: \u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;241m+\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m,\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mjoin(arg \u001b[38;5;28;01mfor\u001b[39;00m arg \u001b[38;5;129;01min\u001b[39;00m kwargs))\n\u001b[0;32m--> 211\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mCheckpointFunction\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mapply\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfunction\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpreserve\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m)\u001b[49m\n",
            "File \u001b[0;32m~/anaconda3/envs/mlfold/lib/python3.8/site-packages/torch/utils/checkpoint.py:90\u001b[0m, in \u001b[0;36mCheckpointFunction.forward\u001b[0;34m(ctx, run_function, preserve_rng_state, *args)\u001b[0m\n\u001b[1;32m     87\u001b[0m ctx\u001b[38;5;241m.\u001b[39msave_for_backward(\u001b[38;5;241m*\u001b[39mtensor_inputs)\n\u001b[1;32m     89\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[0;32m---> 90\u001b[0m     outputs \u001b[38;5;241m=\u001b[39m \u001b[43mrun_function\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     91\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m outputs\n",
            "File \u001b[0;32m~/anaconda3/envs/mlfold/lib/python3.8/site-packages/torch/nn/modules/module.py:1102\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1098\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1099\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1100\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1101\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1102\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1103\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m   1104\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
            "File \u001b[0;32m~/anaconda3/envs/mlfold/lib/python3.8/site-packages/chroma/layers/structure/protein_graph.py:232\u001b[0m, in \u001b[0;36mProteinFeatureGraph.forward\u001b[0;34m(self, X, C, edge_idx, mask_ij, custom_D, custom_mask_2D)\u001b[0m\n\u001b[1;32m    230\u001b[0m edge_h \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m    231\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i, layer \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28menumerate\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39medge_layers):\n\u001b[0;32m--> 232\u001b[0m     edge_h_l \u001b[38;5;241m=\u001b[39m \u001b[43mlayer\u001b[49m\u001b[43m(\u001b[49m\u001b[43mX\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43medge_idx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mC\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    233\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcentered:\n\u001b[1;32m    234\u001b[0m         edge_h_l \u001b[38;5;241m=\u001b[39m edge_h_l \u001b[38;5;241m-\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__getattr__\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124medge_means_\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mi\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n",
            "File \u001b[0;32m~/anaconda3/envs/mlfold/lib/python3.8/site-packages/torch/nn/modules/module.py:1102\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1098\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1099\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1100\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1101\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1102\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1103\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m   1104\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
            "File \u001b[0;32m~/anaconda3/envs/mlfold/lib/python3.8/site-packages/chroma/layers/structure/protein_graph.py:988\u001b[0m, in \u001b[0;36mEdgeDistance2mer.forward\u001b[0;34m(self, X, edge_idx, C)\u001b[0m\n\u001b[1;32m    986\u001b[0m shape_flat \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mlist\u001b[39m(D_ij\u001b[38;5;241m.\u001b[39mshape[:\u001b[38;5;241m3\u001b[39m]) \u001b[38;5;241m+\u001b[39m [\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m]\n\u001b[1;32m    987\u001b[0m D_ij \u001b[38;5;241m=\u001b[39m D_ij\u001b[38;5;241m.\u001b[39mreshape(shape_flat)\n\u001b[0;32m--> 988\u001b[0m feature_ij \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfeaturize\u001b[49m\u001b[43m(\u001b[49m\u001b[43mD_ij\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    990\u001b[0m \u001b[38;5;66;03m# DEBGUG\u001b[39;00m\n\u001b[1;32m    991\u001b[0m \u001b[38;5;66;03m# _debug_plot_edges(edge_idx, feature_ij, unravel=True)\u001b[39;00m\n\u001b[1;32m    992\u001b[0m \u001b[38;5;66;03m# exit(0)\u001b[39;00m\n\u001b[1;32m    993\u001b[0m edge_h \u001b[38;5;241m=\u001b[39m mask_ij \u001b[38;5;241m*\u001b[39m feature_ij\n",
            "File \u001b[0;32m~/anaconda3/envs/mlfold/lib/python3.8/site-packages/chroma/layers/structure/protein_graph.py:973\u001b[0m, in \u001b[0;36mEdgeDistance2mer.featurize\u001b[0;34m(self, D)\u001b[0m\n\u001b[1;32m    971\u001b[0m h_list \u001b[38;5;241m=\u001b[39m []\n\u001b[1;32m    972\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m feature \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mfeatures:\n\u001b[0;32m--> 973\u001b[0m     h \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfeature_funcs\u001b[49m\u001b[43m[\u001b[49m\u001b[43mfeature\u001b[49m\u001b[43m]\u001b[49m\u001b[43m(\u001b[49m\u001b[43mD\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    974\u001b[0m     h_list\u001b[38;5;241m.\u001b[39mappend(h)\n\u001b[1;32m    975\u001b[0m h \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mcat(h_list, \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m)\n",
            "File \u001b[0;32m~/anaconda3/envs/mlfold/lib/python3.8/site-packages/chroma/layers/structure/protein_graph.py:967\u001b[0m, in \u001b[0;36mEdgeDistance2mer.__init__.<locals>.<lambda>\u001b[0;34m(D)\u001b[0m\n\u001b[1;32m    960\u001b[0m \u001b[38;5;66;03m# Public attribute\u001b[39;00m\n\u001b[1;32m    961\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdim_out \u001b[38;5;241m=\u001b[39m \u001b[38;5;28msum\u001b[39m([feature_dims[d] \u001b[38;5;28;01mfor\u001b[39;00m d \u001b[38;5;129;01min\u001b[39;00m features])\n\u001b[1;32m    963\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mfeature_funcs \u001b[38;5;241m=\u001b[39m {\n\u001b[1;32m    964\u001b[0m     \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlog\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28;01mlambda\u001b[39;00m D: torch\u001b[38;5;241m.\u001b[39mlog(D \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdistance_eps),\n\u001b[1;32m    965\u001b[0m     \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124minverse\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28;01mlambda\u001b[39;00m D: \u001b[38;5;241m1.0\u001b[39m \u001b[38;5;241m/\u001b[39m (D \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdistance_eps),\n\u001b[1;32m    966\u001b[0m     \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mraw\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28;01mlambda\u001b[39;00m D: D,\n\u001b[0;32m--> 967\u001b[0m     \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mrbf\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28;01mlambda\u001b[39;00m D: \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrbf_function\u001b[49m\u001b[43m(\u001b[49m\u001b[43mD\u001b[49m\u001b[43m)\u001b[49m,\n\u001b[1;32m    968\u001b[0m }\n",
            "File \u001b[0;32m~/anaconda3/envs/mlfold/lib/python3.8/site-packages/torch/nn/modules/module.py:1102\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1098\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1099\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1100\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1101\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1102\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1103\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m   1104\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
            "File \u001b[0;32m~/anaconda3/envs/mlfold/lib/python3.8/site-packages/chroma/layers/structure/protein_graph.py:1439\u001b[0m, in \u001b[0;36mRBFExpansion.forward\u001b[0;34m(self, h)\u001b[0m\n\u001b[1;32m   1437\u001b[0m shape_ones \u001b[38;5;241m=\u001b[39m [\u001b[38;5;241m1\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m _ \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;28mlen\u001b[39m(shape))] \u001b[38;5;241m+\u001b[39m [\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m]\n\u001b[1;32m   1438\u001b[0m rbf_centers \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mrbf_centers\u001b[38;5;241m.\u001b[39mview(shape_ones)\n\u001b[0;32m-> 1439\u001b[0m h \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mexp(\u001b[38;5;241m-\u001b[39m(((\u001b[43mh\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43munsqueeze\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mrbf_centers\u001b[49m) \u001b[38;5;241m/\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstd) \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39m \u001b[38;5;241m2\u001b[39m))\n\u001b[1;32m   1440\u001b[0m h \u001b[38;5;241m=\u001b[39m h\u001b[38;5;241m.\u001b[39mview(shape[:\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m] \u001b[38;5;241m+\u001b[39m [\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m])\n\u001b[1;32m   1441\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m h\n",
            "\u001b[0;31mRuntimeError\u001b[0m: CUDA out of memory. Tried to allocate 30.00 MiB (GPU 0; 23.69 GiB total capacity; 127.36 MiB already allocated; 13.19 MiB free; 140.00 MiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF"
          ]
        }
      ],
      "source": [
        "# Initialize the Model\n",
        "chroma = Chroma()\n",
        "\n",
        "# Sample a Protein\n",
        "protein = chroma.sample()"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "95f44aa7",
      "metadata": {
        "id": "95f44aa7"
      },
      "source": [
        "The `Protein` object enables one-line inspection, saving, and loading of proteins."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "cd620ec7",
      "metadata": {
        "id": "cd620ec7"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Protein: system\n",
            "> Chain A (100 residues)\n",
            "MKSIEEKLKEIIDKAKELGCDDCANRLKQVLDEIKRNKENKCEAYKKAIDALKSIVDELERRAQELASRDPELGKQAREQVENIKKEIDELIKEIKKSCA\n",
            "\n",
            "\n"
          ]
        }
      ],
      "source": [
        "print(protein) # Inspect the sequence of the protein sample\n",
        "protein.to('chroma_sample.cif') # Save the sample to disk\n",
        "protein = Protein('chroma_sample.cif') # Load a protein from disk"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "93c6ea6b",
      "metadata": {
        "id": "93c6ea6b"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "136cf83c3a85456ab08156c857b5f64e",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "NGLWidget()"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "display(protein)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "ab969fef",
      "metadata": {
        "id": "ab969fef"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "47fc2f8daa9742de98ced4285e0164d9",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Integrating diffusion metrics:   0%|          | 0/50 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "ename": "TypeError",
          "evalue": "unsupported operand type(s) for |: 'dict' and 'dict'",
          "output_type": "error",
          "traceback": [
            "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
            "\u001b[0;31mTypeError\u001b[0m                                 Traceback (most recent call last)",
            "Cell \u001b[0;32mIn[5], line 2\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[38;5;66;03m# Calculate sample scores\u001b[39;00m\n\u001b[0;32m----> 2\u001b[0m elbo \u001b[38;5;241m=\u001b[39m \u001b[43mchroma\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mscore\u001b[49m\u001b[43m(\u001b[49m\u001b[43mprotein\u001b[49m\u001b[43m)\u001b[49m[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124melbo\u001b[39m\u001b[38;5;124m'\u001b[39m]\u001b[38;5;241m.\u001b[39mscore\n\u001b[1;32m      3\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124msample elbo: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00melbo\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m)\n",
            "File \u001b[0;32m~/anaconda3/envs/mlfold/lib/python3.8/site-packages/chroma/models/chroma.py:722\u001b[0m, in \u001b[0;36mChroma.score\u001b[0;34m(self, proteins, num_samples, tspan)\u001b[0m\n\u001b[1;32m    720\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m    721\u001b[0m     sequence_scores[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mt_seq\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m sequence_scores\u001b[38;5;241m.\u001b[39mpop(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mt\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m--> 722\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mbackbone_scores\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m|\u001b[39;49m\u001b[43m \u001b[49m\u001b[43msequence_scores\u001b[49m\n",
            "\u001b[0;31mTypeError\u001b[0m: unsupported operand type(s) for |: 'dict' and 'dict'"
          ]
        }
      ],
      "source": [
        "# Calculate sample scores\n",
        "elbo = chroma.score(protein)['elbo'].score\n",
        "print(f'sample elbo: {elbo}')"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "06688f25",
      "metadata": {
        "id": "06688f25"
      },
      "source": [
        "---"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "1be29dbb",
      "metadata": {
        "id": "1be29dbb",
        "pycharm": {
          "name": "#%% md\n"
        }
      },
      "source": [
        "# Conditioning"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "26d19c04",
      "metadata": {
        "id": "26d19c04"
      },
      "source": [
        "Chroma conditioners allow us to program proteins. In the following examples we will show conditional generation for `Infilling`, `Symmetry`, `Shape`, `Protein Classes`, and `Natural Language`."
      ]
    },
    {
      "cell_type": "markdown",
      "id": "baf8ed65",
      "metadata": {
        "id": "baf8ed65",
        "pycharm": {
          "name": "#%% md\n"
        }
      },
      "source": [
        "## Symmetry\n",
        "\n",
        "Chroma can generate symmetric proteins with the help of the symmetry conditioner. We demonstrate a minimal example of conditioning on the cyclic point group with a 7-fold rotation axis. This point group has 7 asymmetric units arranged in a circle. The subunits are of size 100 in this example. The following parameters can be adjusted below:\n",
        "\n",
        "* `SYMMETRY_GROUP`: symmetry group, choose from {'C_2', 'C_3', ..., \"D_2\", \"D_3\", ..., \"T\", \"O\", \"I\"}\n",
        "* `SUBUNIT_SIZES`: chain lengths for the asymmetric unit, e.g. [100] or [100, 150]; more than one chain is allowed for the asymmetric unit\n",
        "* `KNBR`: number of neighbors to pay attention to during sampling. The maximum allowed is the total number of asymmetric units in the protein complex - 1.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "4e060bd9",
      "metadata": {
        "id": "4e060bd9",
        "pycharm": {
          "name": "#%%\n"
        }
      },
      "outputs": [],
      "source": [
        "SYMMETRY_GROUP = \"C_7\"\n",
        "SUBUNIT_SIZES = [100]\n",
        "KNBR = 2"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "87903f39",
      "metadata": {
        "id": "87903f39",
        "pycharm": {
          "name": "#%%\n"
        }
      },
      "outputs": [],
      "source": [
        "# Draw a Sample\n",
        "torch.manual_seed(0)\n",
        "conditioner = conditioners.SymmetryConditioner(G=SYMMETRY_GROUP, num_chain_neighbors=KNBR)\n",
        "symmetric_protein = chroma.sample(\n",
        "    chain_lengths=SUBUNIT_SIZES,\n",
        "    conditioner=conditioner,\n",
        "    langevin_factor=8,\n",
        "    inverse_temperature=8,\n",
        "    sde_func=\"langevin\",\n",
        "    potts_symmetry_order=conditioner.potts_symmetry_order)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "2f5d7eef",
      "metadata": {
        "id": "2f5d7eef",
        "pycharm": {
          "name": "#%%\n"
        }
      },
      "outputs": [],
      "source": [
        "display(symmetric_protein)"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "20bcee17",
      "metadata": {
        "id": "20bcee17",
        "pycharm": {
          "name": "#%% md\n"
        }
      },
      "source": [
        "## Infilling\n",
        "\n",
        "Many protein design tasks including imputation of missing structural data, redesign of an enzyme scaffold given an active site, or redesign of the CDRs of a known antibody framework require exact specification of the known structural coordinates. The substructure conditioner enables this type of design. By specifying the set of residues that are designable, and a protein to redesign, the user can perform infilling. In this example, a plane split is used which cuts a protein into two portions, a designable portion and a fixed portion. The following parameters can be set by the user:\n",
        "\n",
        "* `MASK_PERCENT`: the fraction of the protein to redesign (e.g. `0.5` for about half of the protein).\n",
        "* `PDB_ID`: the PDB entry to use for infilling. There is also a set of `TESTED_PDBS` that you can use as examples."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "4f2478d5",
      "metadata": {
        "id": "4f2478d5"
      },
      "outputs": [],
      "source": [
        "TESTED_PDBS = ['3bdi', '5sv5','6qaz','2e0q','5xb0','6bde','1a8q','5o0t','1drf','1shg']\n",
        "MASK_PERCENT = 0.5 # Allow about 50% of the Protein to be designed\n",
        "PDB_ID = TESTED_PDBS[0]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "59e2d4b8",
      "metadata": {
        "id": "59e2d4b8"
      },
      "outputs": [],
      "source": [
        "# Configure Substructure Conditioner\n",
        "from chroma.utility.chroma import plane_split_protein\n",
        "protein = Protein(PDB_ID, canonicalize=True, device=device)\n",
        "\n",
        "X, C, _ = protein.to_XCS()\n",
        "residues_to_design = plane_split_protein(X, C, protein, 0.5).nonzero()[:,1].tolist()\n",
        "protein.sys.save_selection(gti=residues_to_design, selname=\"infilling_selection\")\n",
        "\n",
        "conditioner = conditioners.SubstructureConditioner(\n",
        "        protein,\n",
        "        backbone_model=chroma.backbone_network,\n",
        "        selection = 'namesel infilling_selection').to(device)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "c03aeb4b",
      "metadata": {
        "id": "c03aeb4b"
      },
      "outputs": [],
      "source": [
        "# Draw a Sample\n",
        "torch.manual_seed(0)\n",
        "infilled_protein = chroma.sample(\n",
        "             protein_init=protein,\n",
        "             conditioner=conditioner,\n",
        "             langevin_factor=4.0,\n",
        "             langevin_isothermal=True,\n",
        "             inverse_temperature=8.0,\n",
        "             sde_func='langevin',\n",
        "             steps=500)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "f2f8d76f",
      "metadata": {
        "id": "f2f8d76f"
      },
      "outputs": [],
      "source": [
        "display(infilled_protein)"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "269d8f5e",
      "metadata": {
        "id": "269d8f5e",
        "pycharm": {
          "name": "#%% md\n"
        }
      },
      "source": [
        "## Shape\n",
        "\n",
        "The shape conditioner enforces adherence to a predefined volumetric shape as represented by a point cloud. In the example below, we use the Python Imaging Library to render a 3D point cloud from letters, and then we use the ShapeConditioner to sample backbones consistent with this point cloud. The user can set hyperparameters and vary the letter and the number of residues. For faster feedback, the number of steps has been decreased from that used in the manuscript. In this example, both the choice of `LETTER` and the number of protein residues that fill the point cloud can be set:\n",
        " * `LETTER`: a single character string containing the letter that will be made by the conditioner.\n",
        " * `NUM_RESIDUES`: the number of protein residues to fill the point cloud.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "4736bb86",
      "metadata": {
        "id": "4736bb86"
      },
      "outputs": [],
      "source": [
        "LETTER = \"G\"\n",
        "NUM_RESIDUES = 1000"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "2a2e109a",
      "metadata": {
        "id": "2a2e109a",
        "pycharm": {
          "name": "#%%\n"
        }
      },
      "outputs": [],
      "source": [
        "# Configure Shape Conditioner\n",
        "from chroma.utility.chroma import letter_to_point_cloud\n",
        "letter_point_cloud = letter_to_point_cloud(LETTER)\n",
        "\n",
        "conditioner = conditioners.ShapeConditioner(\n",
        "        letter_point_cloud,\n",
        "        chroma.backbone_network.noise_schedule,\n",
        "        autoscale_num_residues=NUM_RESIDUES).to(device)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "adbe3e70",
      "metadata": {
        "id": "adbe3e70",
        "pycharm": {
          "name": "#%%\n"
        }
      },
      "outputs": [],
      "source": [
        "# Draw a Sample\n",
        "torch.manual_seed(0)\n",
        "shaped_protein = chroma.sample(chain_lengths=[NUM_RESIDUES], conditioner=conditioner)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "111c8b74",
      "metadata": {
        "id": "111c8b74"
      },
      "outputs": [],
      "source": [
        "display(shaped_protein)"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "27e1fcce",
      "metadata": {
        "id": "27e1fcce",
        "pycharm": {
          "name": "#%% md\n"
        }
      },
      "source": [
        "## CATH class\n",
        "\n",
        "Proteins can be conditionally generated with specified folds according to CATH class annotations. This conditioner uses the ProClass model. Below we show a minimal example that conditions generation on a CATH annotation from the mostly beta class.\n",
        "\n",
        "The ProClass Conditioner can set CATH class annotations at 3 levels.\n",
        "\n",
        "* `CATH_ANNOTATION`: `X`, e.g. `2` Selects a C level annotation, in this case \"Mostly Beta\"\n",
        "* `CATH_ANNOTATION`: `X.X`, e.g. `2.60` Selects a CA level annotation, in this case \"Sandwich\"\n",
        "* `CATH_ANNOTATION`: `X.X.X` e.g. `2.60.40` Selects a CAT level annotation, in this case \"Immunoglobulin-like\"\n",
        "\n",
        "In general C level annotations are most robust.  CA and CAT level annotations typically require many more samples to get good results. See the paper experiments for details."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "48107807",
      "metadata": {
        "id": "48107807",
        "pycharm": {
          "name": "#%%\n"
        }
      },
      "outputs": [],
      "source": [
        "CATH_ANNOTATION = '2.60.40'"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "24d5e164",
      "metadata": {
        "id": "24d5e164",
        "pycharm": {
          "name": "#%%\n"
        }
      },
      "outputs": [],
      "source": [
        "# Draw a Sample\n",
        "torch.manual_seed(0)\n",
        "conditioner = conditioners.ProClassConditioner('cath', CATH_ANNOTATION)\n",
        "cath_conditioned_protein = chroma.sample(conditioner=conditioner)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "2fa3492b",
      "metadata": {
        "id": "2fa3492b",
        "pycharm": {
          "name": "#%%\n"
        }
      },
      "outputs": [],
      "source": [
        "display(cath_conditioned_protein)"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "4f09cb2e",
      "metadata": {
        "id": "4f09cb2e",
        "pycharm": {
          "name": "#%% md\n"
        }
      },
      "source": [
        "## Natural language"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "af8201d4",
      "metadata": {
        "id": "af8201d4"
      },
      "source": [
        "Here, we demonstrate backbone generation conditioned on natural language prompts. The sampling is guided by the gradients of a structure-to-text model. To condition, we define a `ProCapConditioner` with the following inputs:\n",
        "- a caption\n",
        "- the chain ID, specifying the (1-indexed) chain that the caption refers to; captions corresponding to the entire protein can be indicated with `chain_id = -1`\n",
        "- the weight with which to use the conditioner\n",
        "\n",
        "Training was performed with individual chain captions drawn from UniProt, and complex-level captions taken from the PDB.\n",
        "\n",
        "Below, we demonstrate caption-guided sampling to obtain a single chain backbone corresponding to an SH2 domain."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "8d9dfce9",
      "metadata": {
        "id": "8d9dfce9",
        "pycharm": {
          "name": "#%%\n"
        }
      },
      "outputs": [],
      "source": [
        "CAPTION = \"Crystal structure of SH2 domain\""
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "05e02265",
      "metadata": {
        "id": "05e02265",
        "pycharm": {
          "name": "#%%\n"
        }
      },
      "outputs": [],
      "source": [
        "# Draw a Sample\n",
        "torch.manual_seed(0)\n",
        "conditioner = conditioners.ProCapConditioner(CAPTION, -1)\n",
        "caption_conditioned_protein = chroma.sample(steps=200, chain_lengths=[110], conditioner=conditioner)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "bb43c861",
      "metadata": {
        "id": "bb43c861",
        "pycharm": {
          "name": "#%%\n"
        },
        "scrolled": true
      },
      "outputs": [],
      "source": [
        "display(caption_conditioned_protein)"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "gpuType": "T4",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.0"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}