import json
import os
import shutil
import random
import sys
import time
from typing import List, Tuple, Optional

import Bio.PDB
import Bio.SeqUtils
from Bio import PDB
from Bio import pairwise2
import pandas as pd
import numpy as np
import requests
from rdkit import Chem
from rdkit.Chem import AllChem


BASE_FOLDER = "/tmp/"

OUTPUT_FOLDER = f"{BASE_FOLDER}/processed"
# https://storage.googleapis.com/plinder/2024-06/v2/index/annotation_table.parquet
PLINDER_ANNOTATIONS = f'{BASE_FOLDER}/raw_data/2024-06_v2_index_annotation_table.parquet'
# https://storage.googleapis.com/plinder/2024-06/v2/splits/split.parquet
PLINDER_SPLITS = f'{BASE_FOLDER}/raw_data/2024-06_v2_splits_split.parquet'

# https://console.cloud.google.com/storage/browser/_details/plinder/2024-06/v2/links/kind%3Dapo/links.parquet
PLINDER_LINKED_APO_MAP = f"{BASE_FOLDER}/raw_data/2024-06_v2_links_kind=apo_links.parquet"
# https://console.cloud.google.com/storage/browser/_details/plinder/2024-06/v2/links/kind%3Dpred/links.parquet
PLINDER_LINKED_PRED_MAP = f"{BASE_FOLDER}/raw_data/2024-06_v2_links_kind=pred_links.parquet"
# https://storage.googleapis.com/plinder/2024-06/v2/linked_structures/apo.zip
PLINDER_LINKED_APO_STRUCTURES = f"{BASE_FOLDER}/raw_data/2024-06_v2_linked_structures_apo"
# https://storage.googleapis.com/plinder/2024-06/v2/linked_structures/pred.zip
PLINDER_LINKED_PRED_STRUCTURES = f"{BASE_FOLDER}/raw_data/2024-06_v2_linked_structures_pred"
GSUTIL_PATH = f"{BASE_FOLDER}/google-cloud-sdk/bin/gsutil"
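# NOTE: the parquet/zip inputs above must be downloaded from the listed URLs into
# BASE_FOLDER/raw_data before running; the apo/pred linked-structure zips are
# presumably unpacked into the *_STRUCTURES folders above.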



def get_cached_systems_to_train(recompute=False):
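    """Load the PLINDER annotation, split and linked apo/pred tables, merge them into one
    per-system DataFrame (with linked_apo_id/linked_pred_id, cluster_size and a _bucket_id
    column) and cache it as a pickle under OUTPUT_FOLDER. The block string below records
    the console output of a full run."""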
    output_path = os.path.join(OUTPUT_FOLDER, "to_train.pickle")
    if os.path.exists(output_path) and not recompute:
        return pd.read_pickle(output_path)

    """
    full:
loaded 1357906 409726 163816 433865
loaded 990260 409726 125818 106411
joined splits 409726
Has splits 311008
unique systems 311008
split
train    309140
test       1036
val         832
Name: count, dtype: int64
Has affinity 36856
Has affinity by splits split
train    36598
test       142
val        116
Name: count, dtype: int64
Total systems before pred 311008
Total systems after pred 311008
Has pred 83487
Has apo 75127
Has both 51506
Has either 107108
columns Index(['system_id', 'entry_pdb_id', 'ligand_binding_affinity',
       'entry_release_date', 'system_pocket_UniProt',
       'system_num_protein_chains', 'system_num_ligand_chains', 'uniqueness',
       'split', 'cluster', 'cluster_for_val_split',
       'system_pass_validation_criteria', 'system_pass_statistics_criteria',
       'system_proper_num_ligand_chains', 'system_proper_pocket_num_residues',
       'system_proper_num_interactions',
       'system_proper_ligand_max_molecular_weight',
       'system_has_binding_affinity', 'system_has_apo_or_pred', '_bucket_id',
       'linked_pred_id', 'linked_apo_id'],
      dtype='object')
total systems 311008
    """

    systems = pd.read_parquet(PLINDER_ANNOTATIONS,
                              columns=['system_id', 'entry_pdb_id', 'ligand_binding_affinity',
                                       'entry_release_date', 'system_pocket_UniProt', 'entry_resolution',
                                       'system_num_protein_chains', 'system_num_ligand_chains'])
    splits = pd.read_parquet(PLINDER_SPLITS)
    linked_pred = pd.read_parquet(PLINDER_LINKED_PRED_MAP)
    linked_apo = pd.read_parquet(PLINDER_LINKED_APO_MAP)

    print("loaded", len(systems), len(splits), len(linked_pred), len(linked_apo))

    # remove duplicated
    systems = systems.drop_duplicates(subset=['system_id'])
    splits = splits.drop_duplicates(subset=['system_id'])
    linked_pred = linked_pred.drop_duplicates(subset=['reference_system_id'])
    linked_apo = linked_apo.drop_duplicates(subset=['reference_system_id'])
    print("loaded", len(systems), len(splits), len(linked_pred), len(linked_apo))

    # join splits
    systems = pd.merge(systems, splits, on='system_id', how='inner')
    print("joined splits", len(systems))

    systems['_bucket_id'] = systems['entry_pdb_id'].str[1:3]

    # leave only with train/val/test splits
    systems = systems[systems['split'].isin(['train', 'val', 'test'])]

    print("Has splits", len(systems))
    print("unique systems", systems['system_id'].nunique())
    print(systems["split"].value_counts())

    print("Has affinity", len(systems[systems['ligand_binding_affinity'].notna()]))

    # print has affinity by splits
    print("Has affinity by splits", systems[systems['ligand_binding_affinity'].notna()]['split'].value_counts())

    print("Total systems before pred", len(systems))
    # join linked structures - allow to not have linked structures
    systems = pd.merge(systems, linked_pred[['reference_system_id', 'id']],
                       left_on='system_id', right_on='reference_system_id',
                       how='left')
    print("Total systems after pred", len(systems))

    # Rename the 'id' column from linked_pred to 'linked_pred_id'
    systems.rename(columns={'id': 'linked_pred_id'}, inplace=True)

    # Merge the result with linked_apo on the same condition
    systems = pd.merge(systems, linked_apo[['reference_system_id', 'id']],
                       left_on='system_id', right_on='reference_system_id',
                       how='left')

    # Rename the 'id' column from linked_apo to 'linked_apo_id'
    systems.rename(columns={'id': 'linked_apo_id'}, inplace=True)

    # Drop the reference_system_id columns that were added during the merge
    systems.drop(columns=['reference_system_id_x', 'reference_system_id_y'], inplace=True)

    cluster_sizes = systems["cluster"].value_counts()
    systems["cluster_size"] = systems["cluster"].map(cluster_sizes)
    # print(systems[['system_id', 'cluster', 'cluster_size']])

    print("Has pred", systems['linked_pred_id'].notna().sum())
    print("Has apo", systems['linked_apo_id'].notna().sum())
    print("Has both", (systems['linked_pred_id'].notna() & systems['linked_apo_id'].notna()).sum())
    print("Has either", (systems['linked_pred_id'].notna() | systems['linked_apo_id'].notna()).sum())

    print("columns", systems.columns)

    systems.to_pickle(output_path)
    return systems


def create_conformers(smiles, output_path, num_conformers=100, multiplier_samples=1):
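    """Embed ETKDGv3 conformers for `smiles` (hydrogens added for embedding, removed before
    writing) and save at most `num_conformers` of them to an SDF file at `output_path`."""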
    target_mol = Chem.MolFromSmiles(smiles)
    target_mol = Chem.AddHs(target_mol)

    params = AllChem.ETKDGv3()
    params.numThreads = 0  # Use all available threads
    params.pruneRmsThresh = 0.1  # Pruning threshold for RMSD
    conformer_ids = AllChem.EmbedMultipleConfs(target_mol, numConfs=num_conformers * multiplier_samples, params=params)

    # Optional: Optimize each conformer using MMFF94 force field
    # for conf_id in conformer_ids:
    #     AllChem.UFFOptimizeMolecule(target_mol, confId=conf_id)

    # remove hydrogen atoms
    target_mol = Chem.RemoveHs(target_mol)

    # Save aligned conformers to a file (optional)
    w = Chem.SDWriter(output_path)
    for i, conf_id in enumerate(conformer_ids):
        if i >= num_conformers:
            break
        w.write(target_mol, confId=conf_id)
    w.close()


def do_robust_chain_object_renumber(chain: Bio.PDB.Chain.Chain, new_chain_id: str) -> Optional[Bio.PDB.Chain.Chain]:
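    """Return a copy of `chain` keeping only standard residues that have a CA atom, with
    residue ids shifted so numbering starts at 1 and insertion-code collisions resolved;
    return None if the chain has no such residues."""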
    all_residues = [res for res in chain.get_residues()
                    if "CA" in res and Bio.SeqUtils.seq1(res.get_resname()) not in ("X", "", " ")]
    if not all_residues:
        return None

    res_and_res_id = [(res, res.get_id()[1]) for res in all_residues]

    min_res_id = min([i[1] for i in res_and_res_id])
    if min_res_id < 1:
        print("Negative res id", chain, min_res_id)
        factor = -1 * min_res_id + 1
        res_and_res_id = [(res, res_id + factor) for res, res_id in res_and_res_id]

    res_and_res_id_no_collisions = []
    for res, res_id in res_and_res_id[::-1]:
        if res_and_res_id_no_collisions and res_and_res_id_no_collisions[-1][1] == res_id:
            # there is a collision, usually an insertion residue
            res_and_res_id_no_collisions = [(i, j + 1) for i, j in res_and_res_id_no_collisions]
        res_and_res_id_no_collisions.append((res, res_id))

    first_res_id = min([i[1] for i in res_and_res_id_no_collisions])
    factor = 1 - first_res_id  # start from 1
    new_chain = Bio.PDB.Chain.Chain(new_chain_id)

    res_and_res_id_no_collisions.sort(key=lambda x: x[1])

    for res, res_id in res_and_res_id_no_collisions:
        chain.detach_child(res.id)
        res.id = (" ", res_id + factor, " ")
        new_chain.add(res)

    return new_chain


def robust_renumber_protein(pdb_path: str, output_path: str):
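    """Parse a single-model .pdb or .cif file, renumber every chain with
    do_robust_chain_object_renumber, relabel chains A, B, C, ... and write a PDB file."""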
    if pdb_path.endswith(".pdb"):
        pdb_parser = Bio.PDB.PDBParser(QUIET=True)
        pdb_struct = pdb_parser.get_structure("original_pdb", pdb_path)
    elif pdb_path.endswith(".cif"):
        pdb_struct = Bio.PDB.MMCIFParser().get_structure("original_pdb", pdb_path)
    else:
        raise ValueError("Unknown file type", pdb_path)
    assert len(list(pdb_struct)) == 1, "can't extract if more than one model"
    model = next(iter(pdb_struct))
    chains = list(model.get_chains())
    new_model = Bio.PDB.Model.Model(0)
    chain_ids = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
    for chain, chain_id in zip(chains, chain_ids):
        new_chain = do_robust_chain_object_renumber(chain, chain_id)
        if new_chain is None:
            continue
        new_model.add(new_chain)
    new_struct = Bio.PDB.Structure.Structure("renumbered_pdb")
    new_struct.add(new_model)
    io = Bio.PDB.PDBIO()
    io.set_structure(new_struct)
    io.save(output_path)


def _get_extra(extra_to_save: int, res_before: List[int], res_after: List[int]) -> set:
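    """Randomly split `extra_to_save` residue slots between the stretch just before and just
    after the pocket (clamped to what is available) and return the chosen residue ids."""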
    take_from_before = random.randint(0, extra_to_save)
    take_from_after = extra_to_save - take_from_before
    if take_from_before > len(res_before):
        take_from_after = extra_to_save - len(res_before)
        take_from_before = len(res_before)
    if take_from_after > len(res_after):
        take_from_before = extra_to_save - len(res_after)
        take_from_after = len(res_after)

    extra_to_add = set()
    if take_from_before > 0:
        extra_to_add.update(res_before[-take_from_before:])
    extra_to_add.update(res_after[:take_from_after])

    return extra_to_add


def crop_protein_cont(gt_pdb_path: str, ligand_pos: np.ndarray, output_path: str, max_length: int,
                      distance_threshold: float):
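    """Crop the protein to a contiguous pocket: per chain, keep every residue between the
    first and last residue within `distance_threshold` of any ligand atom, then pad with
    flanking residues up to `max_length` minus the ligand size. Writes the cropped model to
    `output_path`; returns the number of residues kept, or -1 if no pocket is found or the
    contiguous pocket exceeds the budget."""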
    protein = Chem.MolFromPDBFile(gt_pdb_path, sanitize=False)
    ligand_size = ligand_pos.shape[0]

    pdb_parser = Bio.PDB.PDBParser(QUIET=True)
    gt_model = next(iter(pdb_parser.get_structure("gt_pdb", gt_pdb_path)))

    all_res_ids_by_chain = {chain.id: sorted([res.id[1] for res in chain.get_residues() if "CA" in res])
                            for chain in gt_model.get_chains()}

    protein_conf = protein.GetConformer()
    protein_pos = protein_conf.GetPositions()
    protein_atoms = list(protein.GetAtoms())
    assert len(protein_pos) == len(protein_atoms), f"Positions and atoms mismatch in {gt_pdb_path}"

    inter_dists = ligand_pos[:, np.newaxis, :] - protein_pos[np.newaxis, :, :]
    inter_dists = np.sqrt((inter_dists ** 2).sum(-1))
    min_inter_dist_per_protein_atom = inter_dists.min(axis=0)

    res_to_save_count = max_length - ligand_size

    used_protein_idx = np.where(min_inter_dist_per_protein_atom < distance_threshold)[0]
    pocket_residues_by_chain = {}
    for idx in used_protein_idx:
        res = protein_atoms[idx].GetPDBResidueInfo()
        if res.GetIsHeteroAtom():
            continue
        if res.GetChainId() not in pocket_residues_by_chain:
            pocket_residues_by_chain[res.GetChainId()] = set()
        # get residue chain
        pocket_residues_by_chain[res.GetChainId()].add(res.GetResidueNumber())

    if not pocket_residues_by_chain:
        print("No pocket residues found")
        return -1

    # print("pocket_residues_by_chain", pocket_residues_by_chain)

    complete_pocket = []
    extended_pocket_per_chain = {}
    for chain_id, pocket_residues in pocket_residues_by_chain.items():
        max_pocket_res = max(pocket_residues)
        min_pocket_res = min(pocket_residues)

        extended_pocket_per_chain[chain_id] = {res_id for res_id in all_res_ids_by_chain[chain_id]
                                               if min_pocket_res <= res_id <= max_pocket_res}
        for res_id in extended_pocket_per_chain[chain_id]:
            complete_pocket.append((chain_id, res_id))

    # print("extended_pocket_per_chain", pocket_residues_by_chain)

    if len(complete_pocket) > res_to_save_count:
        total_res_ids = sum([len(res_ids) for res_ids in all_res_ids_by_chain.values()])
        total_pocket_res = sum([len(res_ids) for res_ids in pocket_residues_by_chain.values()])
        print(f"Too many residues all: {total_res_ids} pocket:{total_pocket_res} {len(complete_pocket)} "
              f"(ligand size: {ligand_size})")
        return -1

    extra_to_save = res_to_save_count - len(complete_pocket)

    # divide extra_to_save between chains
    for chain_id, pocket_residues in extended_pocket_per_chain.items():
        extra_to_save_per_chain = extra_to_save // len(extended_pocket_per_chain)
        res_before = [res_id for res_id in all_res_ids_by_chain[chain_id] if res_id < min(pocket_residues)]
        res_after = [res_id for res_id in all_res_ids_by_chain[chain_id] if res_id > max(pocket_residues)]
        extra_to_add = _get_extra(extra_to_save_per_chain, res_before, res_after)
        for res_id in extra_to_add:
            complete_pocket.append((chain_id, res_id))

    total_res_ids = sum([len(res_ids) for res_ids in all_res_ids_by_chain.values()])
    total_pocket_res = sum([len(res_ids) for res_ids in pocket_residues_by_chain.values()])
    total_extended_res = sum([len(res_ids) for res_ids in extended_pocket_per_chain.values()])
    print(f"Found valid pocket all: {total_res_ids} pocket:{total_pocket_res} {total_extended_res} "
          f"{len(complete_pocket)} (ligand size: {ligand_size}) extra: {extra_to_save}")
    # print("all_res_ids_by_chain", all_res_ids_by_chain)
    # print("complete_pocket", sorted(complete_pocket))

    res_to_remove = []
    for res in gt_model.get_residues():
        if (res.parent.id, res.id[1]) not in complete_pocket or res.id[0].strip() != "" or res.id[2].strip() != "":
            res_to_remove.append(res)
    for res in res_to_remove:
        gt_model[res.parent.id].detach_child(res.id)

    io = Bio.PDB.PDBIO()
    io.set_structure(gt_model)
    io.save(output_path)

    return len(complete_pocket)


def crop_protein_simple(gt_pdb_path: str, ligand_pos: np.ndarray, output_path: str, max_length: int):
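    """Fallback crop: keep the residues whose atoms are closest to the ligand (not
    necessarily contiguous) up to `max_length` minus the ligand size; write them to
    `output_path` and return the number of residues kept, or -1 if none were found."""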
    protein = Chem.MolFromPDBFile(gt_pdb_path, sanitize=False)
    ligand_size = ligand_pos.shape[0]
    res_to_save_count = max_length - ligand_size

    pdb_parser = Bio.PDB.PDBParser(QUIET=True)
    gt_model = next(iter(pdb_parser.get_structure("gt_pdb", gt_pdb_path)))

    protein_conf = protein.GetConformer()
    protein_pos = protein_conf.GetPositions()
    protein_atoms = list(protein.GetAtoms())
    assert len(protein_pos) == len(protein_atoms), f"Positions and atoms mismatch in {gt_pdb_path}"

    inter_dists = ligand_pos[:, np.newaxis, :] - protein_pos[np.newaxis, :, :]
    inter_dists = np.sqrt((inter_dists ** 2).sum(-1))
    min_inter_dist_per_protein_atom = inter_dists.min(axis=0)

    protein_idx_by_dist = np.argsort(min_inter_dist_per_protein_atom)
    pocket_residues_by_chain = {}
    total_found = 0
    for idx in protein_idx_by_dist:
        res = protein_atoms[idx].GetPDBResidueInfo()
        if res.GetIsHeteroAtom():
            continue

        if res.GetChainId() not in pocket_residues_by_chain:
            pocket_residues_by_chain[res.GetChainId()] = set()
        # get residue chain
        pocket_residues_by_chain[res.GetChainId()].add(res.GetResidueNumber())
        total_found = sum([len(res_ids) for res_ids in pocket_residues_by_chain.values()])
        if total_found >= res_to_save_count:
            break
    print("saved with simple", total_found)

    if not pocket_residues_by_chain:
        print("No pocket residues found")
        return -1

    res_to_remove = []
    for res in gt_model.get_residues():
        if res.id[1] not in pocket_residues_by_chain.get(res.parent.id, set()) \
                or res.id[0].strip() != "" or res.id[2].strip() != "":
            res_to_remove.append(res)
    for res in res_to_remove:
        gt_model[res.parent.id].detach_child(res.id)

    io = Bio.PDB.PDBIO()
    io.set_structure(gt_model)
    io.save(output_path)

    return total_found


def cif_to_pdb(cif_path: str, pdb_path: str):
    protein = Bio.PDB.MMCIFParser().get_structure("s_cif", cif_path)
    io = Bio.PDB.PDBIO()
    io.set_structure(protein)
    io.save(pdb_path)


def get_chain_object_to_seq(chain: Bio.PDB.Chain.Chain) -> str:
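    """Return the chain sequence as one-letter codes indexed by residue number starting at 1,
    filling numbering gaps with "X"; return an empty string if the chain has no CA residues."""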
    res_id_to_res = {res.get_id()[1]: res for res in chain.get_residues() if "CA" in res}

    if len(res_id_to_res) == 0:
        print("skipping empty chain", chain.get_id())
        return ""
    seq = ""
    for i in range(1, max(res_id_to_res) + 1):
        if i in res_id_to_res:
            seq += Bio.SeqUtils.seq1(res_id_to_res[i].get_resname())
        else:
            seq += "X"
    return seq


def get_sequence_from_pdb(pdb_path: str) -> Tuple[str, List[int]]:
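    """Return all chain sequences joined by a 20-residue "X" linker (the format sent to
    ESMFold) together with the list of per-chain sequence lengths."""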
    pdb_parser = Bio.PDB.PDBParser(QUIET=True)
    pdb_struct = pdb_parser.get_structure("original_pdb", pdb_path)
    # chain_to_seq = {chain.id: get_chain_object_to_seq(chain) for chain in pdb_struct.get_chains()}
    all_chain_seqs = [ get_chain_object_to_seq(chain) for chain in pdb_struct.get_chains()]
    chain_lengths = [len(seq) for seq in all_chain_seqs]
    return ("X" * 20).join(all_chain_seqs), chain_lengths



def extract_sequence(chain):
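    """Return the one-letter sequence of a chain (unknown residues skipped) and the parallel
    list of residue objects."""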
    seq = ''
    residues = []
    for res in chain.get_residues():
        seq_res = Bio.SeqUtils.seq1(res.get_resname())
        if seq_res in ('X', "", " "):
            continue
        seq += seq_res
        residues.append(res)
    return seq, residues


def map_residues(alignment, residues_gt, residues_pred):
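    """Walk a pairwise2 alignment and return (gt_residue, pred_residue) pairs for every
    position where both sequences are aligned (no gap on either side)."""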
    idx_gt = 0
    idx_pred = 0
    mapping = []
    for i in range(len(alignment.seqA)):
        aa_gt = alignment.seqA[i]
        aa_pred = alignment.seqB[i]
        res_gt = None
        res_pred = None
        if aa_gt != '-':
            res_gt = residues_gt[idx_gt]
            idx_gt += 1
        if aa_pred != '-':
            res_pred = residues_pred[idx_pred]
            idx_pred +=1
        if res_gt and res_pred:
            mapping.append((res_gt, res_pred))
    return mapping


class ResidueSelect(PDB.Select):
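    """Bio.PDB Select that keeps only the residues passed to the constructor when saving."""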
    def __init__(self, residues_to_select):
        self.residues_to_select = set(residues_to_select)

    def accept_residue(self, residue):
        return residue in self.residues_to_select


def align_gt_and_input(gt_pdb_path, input_pdb_path, output_gt_path, output_input_path):
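    """Greedily match each ground-truth chain to its best-scoring predicted chain by global
    sequence alignment, then write both structures keeping only the aligned residue pairs, so
    the two output files contain the same residues in the same order."""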
    parser = PDB.PDBParser(QUIET=True)
    gt_structure = parser.get_structure('gt', gt_pdb_path)
    pred_structure = parser.get_structure('pred', input_pdb_path)
    matched_residues_gt = []
    matched_residues_pred = []

    used_chain_pred = []
    total_mapping_size = 0
    for chain_gt in gt_structure.get_chains():
        seq_gt, residues_gt = extract_sequence(chain_gt)
        best_alignment = None
        best_chain_pred = None
        best_score = -1
        best_residues_pred = None
        # Find the best matching chain in pred
        for chain_pred in pred_structure.get_chains():
            print("checking", chain_pred.get_id(), chain_gt.get_id())
            if chain_pred in used_chain_pred:
                continue
            seq_pred, residues_pred = extract_sequence(chain_pred)
            print(seq_gt)
            print(seq_pred)
            alignments = pairwise2.align.globalxx(seq_gt, seq_pred, one_alignment_only=True)
            if not alignments:
                continue
            print("checking2", chain_pred.get_id(), chain_gt.get_id())

            alignment = alignments[0]
            score = alignment.score
            if score > best_score:
                best_score = score
                best_alignment = alignment
                best_chain_pred = chain_pred
                best_residues_pred = residues_pred
        if best_alignment:
            mapping = map_residues(best_alignment, residues_gt, best_residues_pred)
            total_mapping_size += len(mapping)
            used_chain_pred.append(best_chain_pred)
            for res_gt, res_pred in mapping:
                matched_residues_gt.append(res_gt)
                matched_residues_pred.append(res_pred)
        else:
            print(f"No matching chain found for chain {chain_gt.get_id()}")
    print(f"Total mapping size: {total_mapping_size}")

    # Write new PDB files with only matched residues
    io = PDB.PDBIO()
    io.set_structure(gt_structure)
    io.save(output_gt_path, ResidueSelect(matched_residues_gt))
    io.set_structure(pred_structure)
    io.save(output_input_path, ResidueSelect(matched_residues_pred))


def validate_matching_input_gt(gt_pdb_path, input_pdb_path):
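    """Check that the two PDB files contain the same number of residues with matching names;
    return the residue count, or -1 on any mismatch."""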
    gt_residues = [res for res in PDB.PDBParser().get_structure('gt', gt_pdb_path).get_residues()]
    input_residues = [res for res in PDB.PDBParser().get_structure('input', input_pdb_path).get_residues()]

    if len(gt_residues) != len(input_residues):
        print(f"Residue count mismatch: {len(gt_residues)} vs {len(input_residues)}")
        return -1

    for res_gt, res_input in zip(gt_residues, input_residues):
        if res_gt.get_resname() != res_input.get_resname():
            print(f"Residue name mismatch: {res_gt.get_resname()} vs {res_input.get_resname()}")
            return -1
    return len(input_residues)


def prepare_system(row, system_folder, output_models_folder, output_jsons_folder, should_overwrite=False):
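    """Build one training example for a PLINDER system: renumber and crop the ground-truth
    receptor, pick an input structure (linked apo, then linked pred, then an ESMFold
    prediction, then the ground truth itself), align it to the cropped ground truth, write
    ground-truth ligand SDFs plus reference conformers, and dump a JSON descriptor into
    `output_jsons_folder`. Returns a short status string."""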
    output_json_path = os.path.join(output_jsons_folder, f"{row['system_id']}.json")
    if os.path.exists(output_json_path) and not should_overwrite:
        return "Already exists"

    plinder_gt_pdb_path = os.path.join(system_folder, f"receptor.pdb")
    plinder_gt_ligand_paths = []
    plinder_gt_ligands_folder = os.path.join(system_folder, "ligand_files")

    gt_output_path = os.path.join(output_models_folder, f"{row['system_id']}_gt.pdb")
    gt_output_relative_path = "plinder_models/" + f"{row['system_id']}_gt.pdb"

    tmp_input_path = os.path.join(output_models_folder, f"tmp_{row['system_id']}_input.pdb")
    protein_input_path = os.path.join(output_models_folder, f"{row['system_id']}_input.pdb")
    protein_input_relative_path = "plinder_models/" + f"{row['system_id']}_input.pdb"

    print("Copying ground truth files")
    if not os.path.exists(plinder_gt_pdb_path):
        print("no receptor", plinder_gt_pdb_path)
        return "No receptor"

    tmp_gt_pdb_path = os.path.join(output_models_folder, f"tmp_{row['system_id']}_gt.pdb")
    robust_renumber_protein(plinder_gt_pdb_path, tmp_gt_pdb_path)

    ligand_pos_list = []
    for ligand_file in os.listdir(plinder_gt_ligands_folder):
        if not ligand_file.endswith(".sdf"):
            continue
        plinder_gt_ligand_paths.append(os.path.join(plinder_gt_ligands_folder, ligand_file))
        loaded_ligand = Chem.MolFromMolFile(os.path.join(plinder_gt_ligands_folder, ligand_file))
        if loaded_ligand is None:
            print("failed to load", plinder_gt_ligand_paths[-1])
            return "Failed to load ligand"
        ligand_pos_list.append(loaded_ligand.GetConformer().GetPositions())

    # Crop ground truth protein, also removes insertion codes
    ligand_pos = np.concatenate(ligand_pos_list, axis=0)

    res_count_in_protein = crop_protein_cont(tmp_gt_pdb_path, ligand_pos, gt_output_path, max_length=350,
                                             distance_threshold=5)
    if res_count_in_protein == -1:
        print("Failed to crop protein continously, using simple crop")
        crop_protein_simple(tmp_gt_pdb_path, ligand_pos, gt_output_path, max_length=350)

    os.remove(tmp_gt_pdb_path)

    # Generate input protein structure
    input_protein_source = None
    if pd.notna(row["linked_apo_id"]):
        apo_pdb_path = os.path.join(PLINDER_LINKED_APO_STRUCTURES, f"{row['linked_apo_id']}.cif")
        try:
            robust_renumber_protein(apo_pdb_path, tmp_input_path)
            input_protein_source = "apo"
            print("Using input apo", row['linked_apo_id'])
        except Exception as e:
            print("Problem with apo", e, row["linked_apo_id"], apo_pdb_path)
    if not os.path.exists(tmp_input_path) and pd.notna(row["linked_pred_id"]):
        pred_pdb_path = os.path.join(PLINDER_LINKED_PRED_STRUCTURES, f"{row['linked_pred_id']}.cif")
        try:
            # cif_to_pdb(pred_pdb_path, tmp_input_path)
            robust_renumber_protein(pred_pdb_path, tmp_input_path)
            input_protein_source = "pred"
            print("Using input  pred", row['linked_pred_id'])
        except:
            print("Problem with pred")
    if not os.path.exists(tmp_input_path):
        print("No linked structure found, running ESM")
        url = "https://api.esmatlas.com/foldSequence/v1/pdb/"
        sequence, chain_lengths = get_sequence_from_pdb(gt_output_path)
        if len(sequence) <= 400:
            try:
                response = requests.post(url, data=sequence)
                response.raise_for_status()
                pdb_text = response.text
                with open(tmp_input_path, "w") as f:
                    f.write(pdb_text)

                # divide to chains
                if len(chain_lengths) > 1:
                    pdb_parser = Bio.PDB.PDBParser(QUIET=True)
                    pdb_struct = pdb_parser.get_structure("original_pdb", tmp_input_path)
                    pdb_model = next(iter(pdb_struct))
                    chain_ids = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"[:len(chain_lengths)]
                    start_ind = 1
                    esm_chain = next(pdb_model.get_chains())
                    new_model = Bio.PDB.Model.Model(0)
                    for chain_length, chain_id in zip(chain_lengths, chain_ids):
                        end_ind = start_ind + chain_length
                        new_chain = Bio.PDB.Chain.Chain(chain_id)
                        for res in esm_chain.get_residues():
                            if start_ind <= res.id[1] <= end_ind:
                                new_chain.add(res)
                        new_model.add(new_chain)
                        start_ind = end_ind + 20  # 20 is the gap in esm
                    io = Bio.PDB.PDBIO()
                    io.set_structure(new_model)
                    io.save(tmp_input_path)

                input_protein_source = "esm"
                print("Using input ESM")
            except requests.exceptions.RequestException as e:
                print(f"An error occurred in ESM: {e}")
                # return "No linked structure found"
        else:
            print("Sequence too long for ESM")
    if not os.path.exists(tmp_input_path):
        print("Using input GT")
        shutil.copyfile(gt_output_path, tmp_input_path)
        input_protein_source = "gt"

    align_gt_and_input(gt_output_path, tmp_input_path, gt_output_path, protein_input_path)
    protein_size = validate_matching_input_gt(gt_output_path, protein_input_path)
    assert protein_size > -1, "Failed to validate matching input and gt"
    os.remove(tmp_input_path)

    rel_gt_lig_paths = []
    rel_ref_lig_paths = []
    input_smiles = []
    for i, ligand_path in enumerate(sorted(plinder_gt_ligand_paths)):
        gt_ligand_output_path = os.path.join(output_models_folder, f"{row['system_id']}_ligand_gt_{i}.sdf")
        # rel_gt_lig_paths.append(f"plinder_models/{row['system_id']}_ref_ligand_{i}.sdf")
        rel_gt_lig_paths.append(f"plinder_models/{row['system_id']}_ligand_gt_{i}.sdf")
        shutil.copyfile(ligand_path, gt_ligand_output_path)

        loaded_ligand = Chem.MolFromMolFile(gt_ligand_output_path)
        input_smiles.append(Chem.MolToSmiles(loaded_ligand))

        ref_ligand_output_path = os.path.join(output_models_folder, f"{row['system_id']}_ligand_ref_{i}.sdf")
        rel_ref_lig_paths.append(f"plinder_models/{row['system_id']}_ligand_ref_{i}.sdf")
        create_conformers(input_smiles[-1], ref_ligand_output_path, num_conformers=1)
        # check if file is empty
        if os.path.getsize(ref_ligand_output_path) == 0:
            print("Empty ref ligand, copying from gt", ref_ligand_output_path)
            shutil.copyfile(gt_ligand_output_path, ref_ligand_output_path)

    affinity = row["ligand_binding_affinity"]
    if not pd.notna(affinity):
        affinity = None

    json_data = {
        "input_structure": protein_input_relative_path,
        "gt_structure": gt_output_relative_path,
        "gt_sdf_list": rel_gt_lig_paths,
        "input_smiles_list": input_smiles,
        "resolution": row.fillna(99)["entry_resolution"],
        "release_year": row["entry_release_date"],
        "affinity": affinity,
        "protein_seq_len": protein_size,
        "uniprot": row["system_pocket_UniProt"],
        "ligand_num_atoms": ligand_pos.shape[0],
        "cluster": row["cluster"],
        "cluster_size": row["cluster_size"],
        "input_protein_source": input_protein_source,
        "ref_sdf_list": rel_ref_lig_paths,
        "pdb_id": row["system_id"],
    }
    with open(output_json_path, "w") as f:
        f.write(json.dumps(json_data, indent=4))

    return "success"

    # use linked structures
    # input_structure_to_use = None
    # apo_linked_structure = os.path.join(linked_structures_folder, "apo", system_id)
    # pred_linked_structure = os.path.join(linked_structures_folder, "pred", system_id)
    # if os.path.exists(apo_linked_structure):
    #     for folder in os.listdir(apo_linked_structure):
    #         if not os.path.isdir(os.path.join(pred_linked_structure, folder)):
    #             continue
    #         for filename in os.listdir(os.path.join(apo_linked_structure, folder)):
    #             if filename.endswith(".cif"):
    #                 input_structure_to_use = os.path.join(apo_linked_structure, folder, filename)
    #                 break
    #         if input_structure_to_use:
    #             break
    #     print(system_id, "found apo", input_structure_to_use)
    # elif os.path.exists(pred_linked_structure):
    #     for folder in os.listdir(pred_linked_structure):
    #         if not os.path.isdir(os.path.join(pred_linked_structure, folder)):
    #             continue
    #         for filename in os.listdir(os.path.join(pred_linked_structure, folder)):
    #             if filename.endswith(".cif"):
    #                 input_structure_to_use = os.path.join(pred_linked_structure, folder, filename)
    #                 break
    #         if input_structure_to_use:
    #             break
    #     print(system_id, "found pred", input_structure_to_use)
    # else:
    #     print(system_id, "no linked structure found")
    #     return "No linked structure found"


def main(prefix_bucket_id: str = "*"):
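    """Iterate over PDB two-character buckets, download and unzip each bucket's systems from
    the PLINDER GCS bucket, prepare every system into the train/val/test JSON folders and log
    a per-system status line to a CSV. `prefix_bucket_id` restricts which buckets are run."""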
    os.makedirs(OUTPUT_FOLDER, exist_ok=True)
    systems = get_cached_systems_to_train()
    print("total systems", len(systems))

    print("clusters", systems["cluster"].value_counts())

    # systems = systems[systems["system_num_protein_chains"] > 1]
    # return

    print("splits", systems["split"].value_counts())
    val_or_test = systems[(systems["split"] == "val") | (systems["split"] == "test")]
    print("validation or test", len(val_or_test))

    output_models_folder = os.path.join(OUTPUT_FOLDER, "plinder_models")
    output_train_jsons_folder = os.path.join(OUTPUT_FOLDER, "plinder_jsons_train")
    output_val_jsons_folder = os.path.join(OUTPUT_FOLDER, "plinder_jsons_val")
    output_test_jsons_folder = os.path.join(OUTPUT_FOLDER, "plinder_jsons_test")
    output_info = os.path.join(OUTPUT_FOLDER, "plinder_generation_info.csv")
    if prefix_bucket_id != "*":
        output_info = os.path.join(OUTPUT_FOLDER, f"plinder_generation_info_{prefix_bucket_id}.csv")

    os.makedirs(output_models_folder, exist_ok=True)
    os.makedirs(output_train_jsons_folder, exist_ok=True)
    os.makedirs(output_val_jsons_folder, exist_ok=True)
    os.makedirs(output_test_jsons_folder, exist_ok=True)

    split_to_folder = {
        "train": output_train_jsons_folder,
        "val": output_val_jsons_folder,
        "test": output_test_jsons_folder
    }

    output_info_file = open(output_info, "a+")

    for bucket_id, bucket_systems in systems.groupby('_bucket_id', sort=True):
        if prefix_bucket_id != "*" and not str(bucket_id).startswith(prefix_bucket_id):
            continue
        # if bucket_id != "z2":
        #     continue
        # systems_folder = "{BASE_FOLDER}/processed/tmp_z2/systems"

        print("Starting bucket", bucket_id, len(bucket_systems))
        print(len(bucket_systems), bucket_systems["system_num_ligand_chains"].value_counts())

        tmp_output_models_folder = os.path.join(OUTPUT_FOLDER, f"tmp_{bucket_id}")
        os.makedirs(tmp_output_models_folder, exist_ok=True)
        os.system(f'{GSUTIL_PATH} -m cp -r "gs://plinder/2024-06/v2/systems/{bucket_id}.zip" {tmp_output_models_folder}')
        systems_folder = os.path.join(tmp_output_models_folder, "systems")
        os.system(f'unzip -o {os.path.join(tmp_output_models_folder, f"{bucket_id}.zip")} -d {systems_folder}')

        for i, row in bucket_systems.iterrows():
            # if not str(row['system_id']).startswith("4z22__1__1.A__1.C"):
            #     continue
            print("doing", row['system_id'], row["system_num_protein_chains"], row["system_num_ligand_chains"])
            system_folder = os.path.join(systems_folder, row['system_id'])
            try:
                success = prepare_system(row, system_folder, output_models_folder, split_to_folder[row["split"]])
                print("done", row['system_id'], success)
                output_info_file.write(f"{bucket_id},{row['system_id']},{success}\n")
            except Exception as e:
                print("Failed", row['system_id'], e)
                output_info_file.write(f"{bucket_id},{row['system_id']},Failed\n")
            output_info_file.flush()

        shutil.rmtree(tmp_output_models_folder)

    output_info_file.close()


if __name__ == '__main__':
    prefix_bucket_id = "*"
    if len(sys.argv) > 1:
        prefix_bucket_id = sys.argv[1]
    main(prefix_bucket_id)
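
# Example usage (a sketch, assuming the PLINDER parquet/zip inputs above have already been
# downloaded and that gsutil is available at GSUTIL_PATH):
#   python <this script>          # process every bucket
#   python <this script> z2       # process only buckets whose id starts with "z2"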