AlienChen committed
Commit c4b1fea · verified · 1 parent: 9a451d4

Upload 106 files

This view is limited to 50 files because the commit contains too many changes.

Files changed (50)
  1. configs/callbacks/checkpoint_every_n_steps.yaml +8 -0
  2. configs/callbacks/checkpoint_monitor.yaml +10 -0
  3. configs/callbacks/learning_rate_monitor.yaml +3 -0
  4. configs/classifier_model/dimamba-classifier.yaml +14 -0
  5. configs/classifier_model/hyenadna-classifier.yaml +4 -0
  6. configs/classifier_model/small-classifier.yaml +11 -0
  7. configs/classifier_model/tiny-classifier.yaml +11 -0
  8. configs/classifier_model/tiny-dimamba-classifier.yaml +14 -0
  9. configs/config.yaml +129 -0
  10. configs/data/amazon_polarity.yaml +10 -0
  11. configs/data/cifar10.yaml +11 -0
  12. configs/data/lm1b.yaml +8 -0
  13. configs/data/peptide.yaml +8 -0
  14. configs/data/protein.yaml +8 -0
  15. configs/data/qm9.yaml +11 -0
  16. configs/data/ten_species.yaml +11 -0
  17. configs/data/text8.yaml +9 -0
  18. configs/guidance/cbg.yaml +5 -0
  19. configs/guidance/cfg.yaml +3 -0
  20. configs/guidance/fudge.yaml +5 -0
  21. configs/guidance/nos.yaml +6 -0
  22. configs/guidance/pplm.yaml +6 -0
  23. configs/lr_scheduler/constant_warmup.yaml +2 -0
  24. configs/lr_scheduler/cosine_decay_warmup.yaml +7 -0
  25. configs/model/dimamba.yaml +12 -0
  26. configs/model/fudge_predictor.yaml +4 -0
  27. configs/model/hf.yaml +2 -0
  28. configs/model/medium.yaml +10 -0
  29. configs/model/small.yaml +11 -0
  30. configs/model/tiny.yaml +10 -0
  31. configs/model/unet.yaml +19 -0
  32. configs/model/unet_campbell.yaml +19 -0
  33. configs/noise/ar.yaml +2 -0
  34. configs/noise/linear.yaml +3 -0
  35. configs/noise/loglinear.yaml +3 -0
  36. configs/noise/polynomial.yaml +5 -0
  37. configs/strategy/ddp.yaml +2 -0
  38. configs/strategy/fsdp.yaml +3 -0
  39. guidance_eval/__init__.py +0 -0
  40. guidance_eval/amazon_polarity_eval.py +228 -0
  41. guidance_eval/qm9_eval.py +208 -0
  42. guidance_eval/ten_species_eval.py +585 -0
  43. main.py +262 -0
  44. models/__init__.py +4 -0
  45. models/__pycache__/__init__.cpython-310.pyc +0 -0
  46. models/__pycache__/__init__.cpython-39.pyc +0 -0
  47. models/__pycache__/bindevaluator.cpython-310.pyc +0 -0
  48. models/__pycache__/dimamba.cpython-310.pyc +0 -0
  49. models/__pycache__/dimamba.cpython-39.pyc +0 -0
  50. models/__pycache__/dit.cpython-310.pyc +0 -0
configs/callbacks/checkpoint_every_n_steps.yaml ADDED
@@ -0,0 +1,8 @@
+ checkpoint_every_n_steps:
+   _target_: lightning.pytorch.callbacks.ModelCheckpoint
+   save_top_k: -1  # Do not save any "best" models; this callback is used to save a checkpoint every n train steps
+   save_last: True  # save model as ${save_dir}/checkpoints/last.ckpt
+   dirpath: ${checkpointing.save_dir}/checkpoints
+   verbose: True
+   auto_insert_metric_name: False
+   # every_n_train_steps: 500
configs/callbacks/checkpoint_monitor.yaml ADDED
@@ -0,0 +1,10 @@
+ checkpoint_monitor:
+   _target_: lightning.pytorch.callbacks.ModelCheckpoint
+   monitor: val/nll  # name of the logged metric which determines when model is improving
+   mode: min  # can be "max" or "min"
+   save_top_k: 1  # save k best models (determined by above metric)
+   save_last: False  # True = additionally always save model from last epoch
+   dirpath: ${checkpointing.save_dir}/checkpoints
+   filename: best
+   auto_insert_metric_name: False
+   verbose: True
configs/callbacks/learning_rate_monitor.yaml ADDED
@@ -0,0 +1,3 @@
+ learning_rate_monitor:
+   _target_: lightning.pytorch.callbacks.LearningRateMonitor
+   logging_interval: step
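Each of the three callback configs above carries a `_target_` key, the Hydra convention for instantiable objects. A minimal sketch of how such configs are typically turned into Lightning callbacks (illustrative only; `build_callbacks` is a hypothetical helper, not part of this commit):

# Hypothetical illustration of consuming the `_target_`-style callback configs.
import hydra
import lightning as L
from omegaconf import DictConfig

def build_callbacks(config: DictConfig) -> list:
  # Each entry (e.g. `checkpoint_every_n_steps`) holds a `_target_` key,
  # so hydra.utils.instantiate constructs the corresponding callback object.
  return [hydra.utils.instantiate(cb) for cb in config.callbacks.values()]

# trainer = L.Trainer(callbacks=build_callbacks(config), ...)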
configs/classifier_model/dimamba-classifier.yaml ADDED
@@ -0,0 +1,14 @@
+ name: dimamba
+ type: dimamba
+ hidden_size: 256
+ cond_dim: 128
+ length: ${model.length}  # Same length as diffusion model
+ n_blocks: 8
+ scale_by_sigma: True
+ dropout: 0.1
+ tie_word_embeddings: False
+ bidirectional: True
+ bidirectional_strategy: add
+ bidirectional_weight_tie: True
+ num_classes: ${data.num_classes}
+ pooling: mean
configs/classifier_model/hyenadna-classifier.yaml ADDED
@@ -0,0 +1,4 @@
+ name: hyena-32k
+ type: hyenadna
+ hyena_model_name_or_path: ???
+ n_layer: 4
configs/classifier_model/small-classifier.yaml ADDED
@@ -0,0 +1,11 @@
+ name: small
+ type: ddit
+ hidden_size: 768
+ cond_dim: 128
+ length: ${model.length}  # Same length as diffusion model
+ n_blocks: 12
+ n_heads: 12
+ scale_by_sigma: True
+ dropout: 0.1
+ num_classes: ${data.num_classes}
+ pooling: mean
configs/classifier_model/tiny-classifier.yaml ADDED
@@ -0,0 +1,11 @@
+ name: tiny
+ type: ddit
+ hidden_size: 512
+ cond_dim: 128
+ length: ${model.length}  # Same length as diffusion model
+ n_blocks: 8
+ n_heads: 8
+ scale_by_sigma: True
+ dropout: 0.1
+ num_classes: ${data.num_classes}
+ pooling: mean
configs/classifier_model/tiny-dimamba-classifier.yaml ADDED
@@ -0,0 +1,14 @@
+ name: tiny
+ type: dimamba
+ hidden_size: 128
+ cond_dim: 128
+ length: ${model.length}  # Same length as diffusion model
+ n_blocks: 4
+ scale_by_sigma: True
+ dropout: 0.1
+ tie_word_embeddings: False
+ bidirectional: True
+ bidirectional_strategy: add
+ bidirectional_weight_tie: True
+ num_classes: ${data.num_classes}
+ pooling: mean
configs/config.yaml ADDED
@@ -0,0 +1,129 @@
+ defaults:
+   - _self_
+   - /callbacks: [checkpoint_every_n_steps, checkpoint_monitor, learning_rate_monitor]
+   - /data: protein
+   - /model: small
+   - /strategy: ddp
+   - /noise: loglinear
+   - /lr_scheduler: cosine_decay_warmup  # constant_warmup
+   - /classifier_model: null
+   - /guidance: null
+
+ mode: ppl_eval  # train / train_classifier / ppl_eval
+ diffusion: uniform  # absorbing_state / uniform
+ backbone: dit  # dit / dimamba / ar
+ classifier_backbone: null
+ parameterization: d3pm  # subs / d3pm / ar
+ time_conditioning: True  # UDLM is conditioned on time
+ subs_masking: False
+ zero_recon_loss: True  # Use for UDLM
+ T: 0  # 0 (continuous time) / 1000
+
+ is_vision: False
+ seed: 42
+
+ loader:
+   global_batch_size: 512
+   eval_global_batch_size: ${.global_batch_size}
+   # Note: batch_size and eval_batch_size are **per machine**
+   batch_size: ${div_up:${.global_batch_size}, ${eval:${trainer.devices} * ${trainer.num_nodes}}}
+   eval_batch_size: ${div_up:${.eval_global_batch_size}, ${eval:${trainer.devices} * ${trainer.num_nodes}}}
+   num_workers: 0  # ${eval:"len(__import__('os').sched_getaffinity(0))"}
+   pin_memory: True
+   persistent_workers: False  # True
+
+ sampling:
+   use_cache: True
+   steps: 32
+   # Note: batch_size is **per machine**
+   batch_size: 1  # ${loader.eval_batch_size}
+   num_sample_batches: 10  # Total samples: `num_gpus` * `batch_size` * `num_sample_batches`
+   use_float64: False
+
+ eval:
+   checkpoint_path: '/home/tc415/discrete-diffusion-guidance/outputs/peptide/2024.12.31/122818/checkpoints/best.ckpt'  # Used to evaluate a checkpoint after training.
+   # target_sequence: 'MSGIALSRLAQERKAWRKDHPFGFVAVPTKNPDGTMNLMNWECAIPGKKGTPWEGGLFKLRMLFKDDYPSSPPKCKFEPPLFHPNVYPSGTVCLSILEEDKDWRPAITIKQILLGIQELLNEPNIQDPAQAEAYTIYCQNRVEYEKRVRAQAKKFAPS'
+   # target_motifs: '123-127'  # UBC9
+   # target_sequence: 'MAMAEGERTECAEPPRDEPPADGALKRAEELKTQANDYFKAKDYENAIKFYSQAIELNPSNAIYYGNRSLAYLRTECYGYALGDATRAIELDKKYIKGYYRRAASNMALGKFRAALRDYETVVKVKPHDKDAKMKYQECNKIVKQKAFERAIAGDEHKRSVVDSLDIESMTIEDEYSGPKLEDGKVTISFMKELMQWYKDQKKLHRKCAYQILVQVKEVLSKLSTLVETTLKETEKITVCGDTHGQFYDLLNIFELNGLPSETNPYIFNGDFVDRGSFSVEVILTLFGFKLLYPDHFHLLRGNHETDNMNQIYGFEGEVKAKYTAQMYELFSEVFEWLPLAQCINGKVLIMHGGLFSEDGVTLDDIRKIERNRQPPDSGPMCDLLWSDPQPQNGRSISKRGVSCQFGPDVTKAFLEENNLDYIIRSHEVKAEGYEVAHGGRCVTVFSAPNYCDQMGNKASYIHLQGSDLRPQFHQFTAVPHPNVKPMAYANTLLQLGMM'
+   # target_motifs: '94-100'  # PPP5
+   # target_sequence: 'MRHSKRTYCPDWDDKDWDYGKWRSSSSHKRRKRSHSSAQENKRCKYNHSKMCDSHYLESRSINEKDYHSRRYIDEYRNDYTQGCEPGHRQRDHESRYQNHSSKSSGRSGRSSYKSKHRIHHSTSHRRSHGKSHRRKRTRSVEDDEEGHLICQSGDVLSARYEIVDTLGEGAFGKVVECIDHKAGGRHVAVKIVKNVDRYCEAARSEIQVLEHLNTTDPNSTFRCVQMLEWFEHHGHICIVFELLGLSTYDFIKENGFLPFRLDHIRKMAYQICKSVNFLHSNKLTHTDLKPENILFVQSDYTEAYNPKIKRDERTLINPDIKVVDFGSATYDDEHHSTLVSTRHYRAPEVILALGWSQPCDVWSIGCILIEYYLGFTVFPTHDSKEHLAMMERILGPLPKHMIQKTRKRKYFHHDRLDWDEHSSAGRYVSRRCKPLKEFMLSQDVEHERLFDLIQKMLEYDPAKRITLREALKHPFFDLLKKSI'
+   # target_motifs: '336-342'  # CLK1
+   # target_sequence: 'MEYHQPEDPAPGKAGTAEAVIPENHEVLAGPDEHPQDTDARDADGEAREREPADQALLPSQCGDNLESPLPEASSAPPGPTLGTLPEVETIRACSMPQELPQSPRTRQPEPDFYCVKWIPWKGEQTPIITQSTNGPCPLLAIMNILFLQWKVKLPPQKEVITSDELMAHLGNCLLSIKPQEKSEGLQLNFQQNVDDAMTVLPKLATGLDVNVRFTGVSDFEYTPECSVFDLLGIPLYHGWLVDPQSPEAVRAVGKLSYNQLVERIITCKHSSDTNLVTEGLIAEQFLETTAAQLTYHGLCELTAAAKEGELSVFFRNNHFSTMTKHKSHLYLLVTDQGFLQEEQVVWESLHNVDGDSCFCDSDFHLSHSLGKGPGAEGGSGSPETQLQVDQDYLIALSLQQQQPRGPLGLTDLELAQQLQQEEYQQQQAAQPVRMRTRVLSLQGRGATSGRPAGERRQRPKHESDCILL'
+   # target_motifs: '202-210'  # MINDY1
+   # target_sequence: 'MTGNAGEWCLMESDPGVFTELIKGFGCRGAQVEEIWSLEPENFEKLKPVHGLIFLFKWQPGEEPAGSVVQDSRLDTIFFAKQVINNACATQAIVSVLLNCTHQDVHLGETLSEFKEFSQSFDAAMKGLALSNSDVIRQVHNSFARQQMFEFDTKTSAKEEDAFHFVSYVPVNGRLYELDGLREGPIDLGACNQDDWISAVRPVIEKRIQKYSEGEIRFNLMAIVSDRKMIYEQKIAELQRQLAEEEPMDTDQGNSMLSAIQSEVAKNQMLIEEEVQKLKRYKIENIRRKHNYLPFIMELLKTLAEHQQLIPLVEKAKEKQNAKKAQETK'
+   # target_motifs: '152-157'  # UCHL5
+   # target_sequence: 'MSSGCQKTTTSKSIPTRWVTINDATHMPHDYSTTPGGTPFIITPGGTRIIYDRQFLLECRTSPLARTPPYSLPDIPGVTSPPSKHIINVKAHNGEPLNNNIAAPADKSTGDDAQFEMDI'
+   # target_motifs: '40-50'  # 4E-BP2
+   # target_sequence: 'MASTDYSTYSQAAAQQGYSAYTAQPTQGYAQTTQAYGQQSYGTYGQPTDVSYTQAQTTATYGQTAYATSYGQPPTGYTTPTAPQAYSQPVQGYGTGAYDTTTATVTTTQASYAAQSAYGTQPAYPAYGQQPAATAPTRPQDGNKPTETSQPQSSTGGYNQPSLGYGQSNYSYPQVPGSYPMQPVTAPPSYPPTSYSSTQPTSYDQSSYSQQNTYGQPSSYGQQSSYGQQSSYGQQPPTSYPPQTGSYSQAPSQYSQQSSSYGQQNPSYDSVRRGAWGNNMNSGLNKSPPLGGAQTISKNTEQRPQPDPYQILGPTSSRLANPGSGQIQLWQFLLELLSDSANASCITWEGTNGEFKMTDPDEVARRWGERKSKPNMNYDKLSRALRYYYDKNIMTKVHGKRYAYKFDFHGIAQALQPHPTESSMYKYPSDISYMPSYHAHQQKVNFVPPHPSSMPVTSSSFFGAASQYWTSPTGGIYPNPNVPRHPNTHVPSHLGSYY'
+   # target_motifs: '323-330'  # EWS::FLI1
+   target_sequence: 'MLQTKDLIWTLFFLGTAVSLQVDIVPSQGEISVGESKFFLCQVAGDAKDKDISWFSPNGEKLTPNQQRISVVWNDDSSSTLTIYNANIDDAGIYKCVVTGEDGSESEATVNVKIFQKLMFKNAPTPQEFREGEDAVIVCDVVSSLPPTIIWKHKGRDVILKKDVRFIVLSNNYLQIRGIKKTDEGTYRCEGRILARGEINFKDIQVIVNVPPTIQARQNIVNATANLGQSVTLVCDAEGFPEPTMSWTKDGEQIEQEEDDEKYIFSDDSSQLTIKKVDKNDEAEYICIAENKAGEQDATIHLKVFAKPKITYVENQTAMELEEQVTLTCEASGDPIPSITWRTSTRNISSEEKASWTRPEKQETLDGHMVVRSHARVSSLTLKSIQYTDAGEYICTASNTIGQDSQSMYLEVQYAPKLQGPVAVYTWEGNQVNITCEVFAYPSATISWFRDGQLLPSSNYSNIKIYNTPSASYLEVTPDSENDFGNYNCTAVNRIGQESL'
+   target_motifs: '415-430'  # NCAM1_IG
+   # target_sequence: 'TPSSPSIDQVEPYSSTAQVQFDEPEATGGVPILKYKAEWRAVGEEVWHSKWYDAKEASMEGIVTIVGLKPETTYAVRLAALNGKGLGEISAASEFKTQPVQGEPSAPKLEGQMGEDGNSIKVNLIKQDDGGSPIRHYLVRYRALSSEWKPEIRLPSGSDHVMLKSLDWNAEYEVYVVAENQQGKSKAAHFVFRTSAQP'
+   # target_motifs: '98-108'  # NCAM1_FN3
+
+   disable_ema: False
+   generate_samples: True
+   generated_samples_path: ''
+   max_samples: 50_000
+
+ training:
+   ema: 0.9999
+   antithetic_sampling: True
+   importance_sampling: False
+   sampling_eps: 1e-3
+   change_of_variables: False
+   compute_loss_on_pad_tokens: True
+   use_simple_ce_loss: False  # Ignore ELBO; just use CE
+   guidance: null  # Can turn off with `training.guidance: null`
+   # cond_dropout: 0.0
+
+ optim:
+   weight_decay: 1e-4
+   lr: 1e-5
+   beta1: 0.9
+   beta2: 0.999
+   eps: 1e-8
+
+ trainer:
+   _target_: lightning.Trainer
+   accelerator: cuda
+   num_nodes: 1
+   devices: 2  # ${device_count:}
+   accumulate_grad_batches: 1  # ${div_up:${loader.global_batch_size}, ${eval:${trainer.devices} * ${loader.batch_size} * ${trainer.num_nodes}}}
+   gradient_clip_val: 1.0
+   precision: 'bf16-mixed'
+   num_sanity_val_steps: 2
+   # max_epochs: 10
+   max_steps: 1652000
+   log_every_n_steps: 100
+   limit_train_batches: 1.0  # train on full dataset, can be used to toggle quick run
+   limit_val_batches: 1.0  # validate on full dataset, can be used to toggle quick run
+   val_check_interval: 16520  # 2545
+
+ wandb:
+   project: moPPIt-v2
+   job_type: model-training
+   name: protein_medium_100epochs_lr1e-5_gradclip1_wd1e-4_dropout0.1  # epochs10_lr3e-4_bsz8_64-true_all-params_gradclip1_beta-one0.9_beta-two0.999
+   id: ${.name}
+
+ hydra:
+   run:
+     dir: ./outputs/${wandb.name}  # ./outputs/${data.train}/${now:%Y.%m.%d}/${now:%H%M%S}
+   job:
+     chdir: true
+
+ checkpointing:
+   # Use custom `save_dir` if, e.g., saving to S3 bucket, otherwise leave this parameter as is
+   save_dir: ${cwd:}
+   # Note: `checkpoints` path should correspond to `checkpoint_every_n_steps.dirpath`
+   resume_from_ckpt: False
+   resume_ckpt_path: ${.save_dir}/checkpoints/last.ckpt
+
+
+ # target_sequence: 'MEEPQSDPSVEPPLSQETFSDLWKLLPENNVLSPLPSQAMDDLMLSPDDIEQWFTEDPGPDEAPRMPEAAPPVAPAPAAPTPAAPAPAPSWPLSSSVPSQKTYQGSYGFRLGFLHSGTAKSVTCTYSPALNKMFCQLAKTCPVQLWVDSTPPPGTRVRAMAIYKQSQHMTEVVRRCPHHERCSDSDGLAPPQHLIRVEGNLRVEYLDDRNTFRHSVVVPYEPPEVGSDCTTIHYNYMCNSSCMGGMNRRPILTIITLEDSSGNLLGRNSFEVRVCACPGRDRRTEEENLRKKGEPHHELPPGSTKRALPNNTSSSPQPKKKPLDGEYFTLQIRGRERFEMFRELNEALELKDAQAGKEPGGSRAHSSHLKSKKGQSTSRHKKLMFKTEGPDSD'
+ # target_motifs: '305-313'  # P53_1
+ # target_motifs: '371-382'  # P53_2
+ # target_motifs: '351-393'  # P53_3
+ # target_motifs: '210-230'  # P53_4
+ # target_sequence: 'MLQTKDLIWTLFFLGTAVSLQVDIVPSQGEISVGESKFFLCQVAGDAKDKDISWFSPNGEKLTPNQQRISVVWNDDSSSTLTIYNANIDDAGIYKCVVTGEDGSESEATVNVKIFQKLMFKNAPTPQEFREGEDAVIVCDVVSSLPPTIIWKHKGRDVILKKDVRFIVLSNNYLQIRGIKKTDEGTYRCEGRILARGEINFKDIQVIVNVPPTIQARQNIVNATANLGQSVTLVCDAEGFPEPTMSWTKDGEQIEQEEDDEKYIFSDDSSQLTIKKVDKNDEAEYICIAENKAGEQDATIHLKVFAKPKITYVENQTAMELEEQVTLTCEASGDPIPSITWRTSTRNISSEEKTLDGHMVVRSHARVSSLTLKSIQYTDAGEYICTASNTIGQDSQSMYLEVQYAPKLQGPVAVYTWEGNQVNITCEVFAYPSATISWFRDGQLLPSSNYSNIKIYNTPSASYLEVTPDSENDFGNYNCTAVNRIGQESLEFILVQADTPSSPSIDQVEPYSSTAQVQFDEPEATGGVPILKYKAEWRAVGEEVWHSKWYDAKEASMEGIVTIVGLKPETTYAVRLAALNGKGLGEISAASEFKTQPVHSPPP'
+ # target_motifs: '28-39'  # NCAM1_ECD
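One detail worth spelling out in this config: `loader.batch_size` is derived from the global batch size via the custom `div_up` and `eval` resolvers (registered in the eval scripts below). A worked example under this commit's defaults:

# How loader.batch_size resolves with global_batch_size=512, devices=2, num_nodes=1.
# div_up is registered as lambda x, y: (x + y - 1) // y (ceiling division).
global_batch_size = 512
devices, num_nodes = 2, 1
batch_size = (global_batch_size + devices * num_nodes - 1) // (devices * num_nodes)
print(batch_size)  # 256 sequences per device per step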
configs/data/amazon_polarity.yaml ADDED
@@ -0,0 +1,10 @@
+ train: amazon_polarity
+ valid: amazon_polarity
+ tokenizer_name_or_path: bert-base-uncased
+ cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
+ wrap: False
+ streaming: False
+ override_cache: False
+ add_special_tokens: True
+ label_col: label
+ num_classes: 2
configs/data/cifar10.yaml ADDED
@@ -0,0 +1,11 @@
+ train: ???  # (Local) Path to CIFAR-10 training data
+ valid: ???  # (Local) Path to CIFAR-10 validation data
+ label_col: labels
+ num_classes: 10
+ streaming: False
+ size: 1024
+ length: 3072
+ add_special_tokens: True
+ add_mask_token: True
+ tokenizer_name_or_path: raw_pixels
+
configs/data/lm1b.yaml ADDED
@@ -0,0 +1,8 @@
+ train: lm1b
+ valid: lm1b
+ tokenizer_name_or_path: bert-base-uncased
+ cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
+ wrap: False
+ streaming: False
+ override_cache: False
+ add_special_tokens: True
configs/data/peptide.yaml ADDED
@@ -0,0 +1,8 @@
+ train: peptide
+ valid: peptide
+ tokenizer_name_or_path: facebook/esm2_t33_650M_UR50D
+ cache_dir: /home/tc415/discrete-diffusion-guidance/dataset
+ wrap: False
+ streaming: False
+ override_cache: False
+ add_special_tokens: True
configs/data/protein.yaml ADDED
@@ -0,0 +1,8 @@
+ train: protein_400k
+ valid: protein_400k
+ tokenizer_name_or_path: facebook/esm2_t33_650M_UR50D
+ cache_dir: /home/tc415/discrete-diffusion-guidance/dataset
+ wrap: False
+ streaming: False
+ override_cache: False
+ add_special_tokens: True
configs/data/qm9.yaml ADDED
@@ -0,0 +1,11 @@
+ train: qm9
+ valid: qm9
+ tokenizer_name_or_path: yairschiff/qm9-tokenizer
+ cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
+ wrap: False
+ streaming: False
+ override_cache: False
+ add_special_tokens: True
+ label_col: qed
+ label_col_pctile: 90
+ num_classes: 2
configs/data/ten_species.yaml ADDED
@@ -0,0 +1,11 @@
+ train: ten_species
+ valid: ten_species
+ tokenizer_name_or_path: kuleshov-group/caduceus-ps_seqlen-131k_d_model-256_n_layer-16
+ cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
+ wrap: False
+ streaming: False
+ override_cache: False
+ add_special_tokens: False
+ label_col: species_label
+ num_classes: 10
+ rc_aug: False
configs/data/text8.yaml ADDED
@@ -0,0 +1,9 @@
+ # TODO: When using this dataset, set model.length = 256 to match D3PM setup
+ train: text8
+ valid: text8
+ tokenizer_name_or_path: text8
+ cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
+ wrap: True
+ streaming: False
+ override_cache: False
+ add_special_tokens: False
configs/guidance/cbg.yaml ADDED
@@ -0,0 +1,5 @@
+ method: cbg
+ condition: 0
+ classifier_checkpoint_path: '/home/tc415/discrete-diffusion-guidance/model_path/finetune_bindevaluator_0/model-epoch=30-val_mcc=0.60-val_loss=0.51.ckpt'
+ gamma: 2.0
+ use_approx: False  # use first-order approximation
configs/guidance/cfg.yaml ADDED
@@ -0,0 +1,3 @@
+ method: cfg
+ condition: 0
+ gamma: 1.0
configs/guidance/fudge.yaml ADDED
@@ -0,0 +1,5 @@
+ method: fudge
+ condition: 0
+ classifier_checkpoint_path: ''
+ topk: 20
+ gamma: 1.0
configs/guidance/nos.yaml ADDED
@@ -0,0 +1,6 @@
+ method: nos
+ condition: 0
+ classifier_checkpoint_path: ''
+ num_nos_steps: 1
+ nos_step_size: 0.1
+ nos_stability_coef: 0.01
configs/guidance/pplm.yaml ADDED
@@ -0,0 +1,6 @@
+ method: pplm
+ condition: 0
+ classifier_checkpoint_path: ''
+ num_pplm_steps: 1
+ pplm_step_size: 0.1
+ pplm_stability_coef: 0.01
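All five guidance configs expose a `condition` (the target class) plus a strength knob: `gamma` for cbg/cfg/fudge, step sizes for nos/pplm. As a rough schematic of how a `gamma`-style knob typically enters classifier-free vs. classifier-based guidance (an illustration of the general technique, not this repo's implementation):

# Schematic only: how a guidance strength `gamma` commonly re-weights logits.
# Function and variable names are hypothetical, not from this commit.
import torch

def cfg_logits(cond_logits, uncond_logits, gamma: float):
  # Classifier-free guidance: extrapolate from the unconditional toward
  # the conditional denoiser output; gamma=1 recovers the conditional model.
  return uncond_logits + gamma * (cond_logits - uncond_logits)

def cbg_logits(diffusion_log_probs, classifier_log_probs, gamma: float):
  # Classifier-based guidance: tilt the diffusion distribution by the
  # classifier likelihood raised to the power gamma, then renormalize.
  return torch.log_softmax(
      diffusion_log_probs + gamma * classifier_log_probs, dim=-1)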
configs/lr_scheduler/constant_warmup.yaml ADDED
@@ -0,0 +1,2 @@
+ _target_: transformers.get_constant_schedule_with_warmup
+ num_warmup_steps: 2500
configs/lr_scheduler/cosine_decay_warmup.yaml ADDED
@@ -0,0 +1,7 @@
+ _target_: utils.CosineDecayWarmupLRScheduler
+ t_in_epochs: False
+ t_initial: ${eval:${trainer.max_steps}-${.warmup_t}}
+ warmup_prefix: True
+ warmup_lr_init: 1e-7
+ warmup_t: ${eval:0.1*${trainer.max_steps}}
+ lr_min: 1e-7
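Given `trainer.max_steps: 1652000` in configs/config.yaml, the two `eval`-resolved fields above work out to:

# Resolution of the cosine_decay_warmup fields under this commit's trainer config.
max_steps = 1_652_000             # trainer.max_steps
warmup_t = 0.1 * max_steps        # ${eval:0.1*${trainer.max_steps}} -> 165,200 warmup steps
t_initial = max_steps - warmup_t  # ${eval:...-${.warmup_t}} -> 1,486,800 steps of cosine decay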
configs/model/dimamba.yaml ADDED
@@ -0,0 +1,12 @@
+ name: dimamba
+ type: dimamba
+ hidden_size: 256
+ cond_dim: 128
+ length: 32768
+ n_blocks: 8
+ scale_by_sigma: True
+ dropout: 0.1
+ tie_word_embeddings: False
+ bidirectional: True
+ bidirectional_strategy: add
+ bidirectional_weight_tie: True
configs/model/fudge_predictor.yaml ADDED
@@ -0,0 +1,4 @@
+ name: fudge_predictor
+ type: lstm
+ hidden_dim: 300
+ length: 1024
configs/model/hf.yaml ADDED
@@ -0,0 +1,2 @@
+ pretrained_model_name_or_path: null
+ length: 128
configs/model/medium.yaml ADDED
@@ -0,0 +1,10 @@
+ name: medium
+ type: ddit
+ hidden_size: 1024
+ cond_dim: 128
+ length: 4096
+ n_blocks: 24
+ n_heads: 16
+ scale_by_sigma: True
+ dropout: 0.1
+ tie_word_embeddings: False
configs/model/small.yaml ADDED
@@ -0,0 +1,11 @@
+ name: small
+ type: ddit
+ hidden_size: 768
+ cond_dim: 128
+ length: null
+ length_range: '25,27,28,31,35,43-49'
+ n_blocks: 12
+ n_heads: 12
+ scale_by_sigma: True
+ dropout: 0.1
+ tie_word_embeddings: False
configs/model/tiny.yaml ADDED
@@ -0,0 +1,10 @@
+ name: tiny
+ type: ddit
+ hidden_size: 512
+ cond_dim: 128
+ length: 1024
+ n_blocks: 8
+ n_heads: 8
+ scale_by_sigma: True
+ dropout: 0.1
+ tie_word_embeddings: False
configs/model/unet.yaml ADDED
@@ -0,0 +1,19 @@
+ name: unet
+ type: unet
+ ch: 128
+ num_res_blocks: 2
+ num_scales: 4
+ ch_mult: [1, 2, 2, 2]
+ input_channels: 3
+ output_channels: -1  # determined by vocab_size
+ scale_count_to_put_attn: 1  # at 16 res
+ data_min_max: [0, 255]  # No need currently
+ dropout: 0.1
+ skip_rescale: True
+ time_conditioning: True  # Whether to add in time embeddings
+ time_scale_factor: 1000
+ time_embed_dim: ${.ch}
+ fix_logistic: False
+ size: ${data.size}
+ cond_dim: ${.ch}
+ length: ${data.length}
configs/model/unet_campbell.yaml ADDED
@@ -0,0 +1,19 @@
+ name: unet
+ type: unet
+ ch: 128
+ num_res_blocks: 2
+ num_scales: 4
+ ch_mult: [1, 2, 2, 2]
+ input_channels: 3
+ output_channels: -1  # determined by input_channels * 2
+ scale_count_to_put_attn: 1  # at 16 res
+ data_min_max: [0, 255]  # No need currently, determined by [0, vocab_size]
+ dropout: 0.1
+ skip_rescale: True
+ time_conditioning: True  # Whether to add in time embeddings
+ time_scale_factor: 1000
+ time_embed_dim: ${.ch}
+ fix_logistic: False
+ size: ${data.size}
+ cond_dim: ${.ch}
+ length: ${data.length}
configs/noise/ar.yaml ADDED
@@ -0,0 +1,2 @@
+ type: ar
+ scale: 6.0
configs/noise/linear.yaml ADDED
@@ -0,0 +1,3 @@
+ type: linear
+ sigma_min: 1e-3
+ sigma_max: 7.0
configs/noise/loglinear.yaml ADDED
@@ -0,0 +1,3 @@
+ type: loglinear
+ sigma_min: 1e-4
+ sigma_max: 20
configs/noise/polynomial.yaml ADDED
@@ -0,0 +1,5 @@
+ type: polynomial
+ a: -3
+ b: 5
+ c: -4
+ eps: 1e-3
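For reference, a minimal sketch of what these schedule types commonly look like in SEDD/MDLM-style discrete-diffusion codebases. The repo's actual `noise` module is not among the 50 files shown, so treat both functions as assumptions:

# Illustrative sketch only; names and exact formulas are assumptions.
import torch

def loglinear_total_noise(t, sigma_min=1e-4):
  # sigma(t) = -log(1 - (1 - sigma_min) * t): ~0 noise at t=0,
  # growing without bound as t -> 1.
  return -torch.log1p(-(1 - sigma_min) * t)

def geometric_total_noise(t, sigma_min=1e-3, sigma_max=7.0):
  # "Linear" in log space: interpolates sigma geometrically between endpoints.
  return sigma_min ** (1 - t) * sigma_max ** t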
configs/strategy/ddp.yaml ADDED
@@ -0,0 +1,2 @@
+ _target_: lightning.pytorch.strategies.DDPStrategy
+ find_unused_parameters: false
configs/strategy/fsdp.yaml ADDED
@@ -0,0 +1,3 @@
+ # TODO(yair): Currently not compatible with grad clipping
+ _target_: lightning.pytorch.strategies.FSDPStrategy
+ sharding_strategy: SHARD_GRAD_OP
guidance_eval/__init__.py ADDED
File without changes
guidance_eval/amazon_polarity_eval.py ADDED
@@ -0,0 +1,228 @@
+ import collections
+ import json
+ import os
+
+ import hydra
+ import lightning as L
+ import omegaconf
+ import pandas as pd
+ import rdkit
+ import rich.syntax
+ import rich.tree
+ import spacy
+ import torch
+ import transformers
+ # from evaluate import load
+ from nltk.util import ngrams
+ from tqdm.auto import tqdm
+
+ import dataloader
+ import diffusion
+ import eval_utils
+
+ rdkit.rdBase.DisableLog('rdApp.error')
+
+ omegaconf.OmegaConf.register_new_resolver(
+   'cwd', os.getcwd)
+ omegaconf.OmegaConf.register_new_resolver(
+   'device_count', torch.cuda.device_count)
+ omegaconf.OmegaConf.register_new_resolver(
+   'eval', eval)
+ omegaconf.OmegaConf.register_new_resolver(
+   'div_up', lambda x, y: (x + y - 1) // y)
+ omegaconf.OmegaConf.register_new_resolver(
+   'if_then_else',
+   lambda condition, x, y: x if condition else y
+ )
+
+
+ def _print_config(
+     config: omegaconf.DictConfig,
+     resolve: bool = True) -> None:
+   """Prints content of DictConfig using Rich library and its tree structure.
+
+   Args:
+     config (DictConfig): Configuration composed by Hydra.
+     resolve (bool): Whether to resolve reference fields of DictConfig.
+   """
+
+   style = 'dim'
+   tree = rich.tree.Tree('CONFIG', style=style,
+                         guide_style=style)
+
+   fields = config.keys()
+   for field in fields:
+     branch = tree.add(field, style=style, guide_style=style)
+
+     config_section = config.get(field)
+     branch_content = str(config_section)
+     if isinstance(config_section, omegaconf.DictConfig):
+       branch_content = omegaconf.OmegaConf.to_yaml(
+         config_section, resolve=resolve)
+
+     branch.add(rich.syntax.Syntax(branch_content, 'yaml'))
+   rich.print(tree)
+
+ def compute_diversity(sentences):
+   # compute diversity
+   ngram_range = [2, 3, 4]
+
+   tokenizer = spacy.load("en_core_web_sm").tokenizer
+   token_list = []
+   for sentence in sentences:
+     token_list.append(
+       [str(token) for token in tokenizer(sentence)])
+   ngram_sets = {}
+   ngram_counts = collections.defaultdict(int)
+   n_gram_repetition = {}
+
+   for n in ngram_range:
+     ngram_sets[n] = set()
+     for tokens in token_list:
+       ngram_sets[n].update(ngrams(tokens, n))
+       ngram_counts[n] += len(list(ngrams(tokens, n)))
+     n_gram_repetition[f"{n}gram_repetition"] = (
+       1 - len(ngram_sets[n]) / ngram_counts[n])
+   diversity = 1
+   for val in n_gram_repetition.values():
+     diversity *= (1 - val)
+   return diversity
+
+
+ def compute_sentiment_classifier_score(sentences, eval_model_name_or_path):
+   tokenizer = transformers.AutoTokenizer.from_pretrained(eval_model_name_or_path)
+   eval_model = transformers.AutoModelForSequenceClassification.from_pretrained(
+     eval_model_name_or_path).to('cuda')
+   eval_model.eval()
+
+   total_pos = 0
+   total_neg = 0
+   pbar = tqdm(sentences, desc='Classifier eval')
+   for sen in pbar:
+     # Tokenize the input text
+     inputs = tokenizer(
+       sen,
+       return_tensors="pt",
+       truncation=True,
+       padding=True).to('cuda')
+
+     # Get the model predictions
+     with torch.no_grad():
+       outputs = eval_model(**inputs)
+
+     # Convert logits to probabilities
+     probs = torch.nn.functional.softmax(
+       outputs.logits, dim=-1)
+
+     # Get the predicted class
+     predicted_class = torch.argmax(probs, dim=1).item()
+     if predicted_class == 1:
+       total_pos += 1
+     else:
+       total_neg += 1
+     pbar.set_postfix(accuracy=total_pos / (total_pos + total_neg))
+   return total_pos / (total_pos + total_neg)
+
+
+ # def compute_mauve(config, tokenizer, sentences):
+ #   os.environ["TOKENIZERS_PARALLELISM"] = "false"
+ #   # compute mauve
+ #   torch.cuda.empty_cache()
+ #   mauve = load("mauve")
+ #   human_references = []
+ #
+ #   valid_loader = dataloader.get_dataloaders(
+ #     config, tokenizer, valid_seed=config.seed)
+ #
+ #   # construct reference
+ #   for batch_id in range(config.sampling.num_sample_batches):
+ #     batch = next(iter(valid_loader))
+ #     input_ids = batch['input_ids']
+ #     for i in range(config.sampling.batch_size):
+ #       idx = (
+ #         input_ids[i] == tokenizer.eos_token_id).nonzero(
+ #         as_tuple=True)
+ #       if idx[0].numel() > 0:
+ #         idx = idx[0][0].item()
+ #         input_ids[i, (idx + 1):] = 0
+ #     human_references.extend(
+ #       tokenizer.batch_decode(
+ #         input_ids, skip_special_tokens=True))
+ #
+ #   assert len(sentences) == len(human_references)
+ #
+ #   results = mauve.compute(predictions=sentences,
+ #                           references=human_references,
+ #                           featurize_model_name=config.data.mauve_model,
+ #                           max_text_length=256, device_id=0)
+ #   return results.mauve
+
+
+
+ @hydra.main(version_base=None, config_path='../configs',
+             config_name='config')
+ def main(config: omegaconf.DictConfig) -> None:
+   # Reproducibility
+   L.seed_everything(config.seed)
+   os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
+   torch.use_deterministic_algorithms(True)
+   torch.backends.cudnn.benchmark = False
+
+   _print_config(config, resolve=True)
+   print(f"Checkpoint: {config.eval.checkpoint_path}")
+
+   tokenizer = dataloader.get_tokenizer(config)
+   pretrained = diffusion.Diffusion.load_from_checkpoint(
+     config.eval.checkpoint_path,
+     tokenizer=tokenizer,
+     config=config, logger=False)
+   pretrained.eval()
+   result_dicts = []
+   samples = []
+   for _ in tqdm(
+       range(config.sampling.num_sample_batches),
+       desc='Gen. batches', leave=False):
+     sample = pretrained.sample()
+     samples.extend(
+       pretrained.tokenizer.batch_decode(sample))
+   samples = [
+     s.replace('[CLS]', '').replace('[SEP]', '').replace('[PAD]', '').replace('[MASK]', '').strip()
+     for s in samples
+   ]
+   del pretrained  # free up space for eval
+
+   diversity_score = compute_diversity(samples)
+   classifier_accuracy = compute_sentiment_classifier_score(
+     samples, eval_model_name_or_path=config.eval.classifier_model_name_or_path)
+
+   generative_ppl = eval_utils.compute_generative_ppl(
+     samples,
+     eval_model_name_or_path=config.eval.generative_ppl_model_name_or_path,
+     gen_ppl_eval_batch_size=8,
+     max_length=config.model.length)
+
+   result_dicts.append({
+     'Seed': config.seed,
+     'T': config.sampling.steps,
+     'Num Samples': config.sampling.batch_size * config.sampling.num_sample_batches,
+     'Diversity': diversity_score,
+     'Accuracy': classifier_accuracy,
+     'Gen. PPL': generative_ppl,
+   } | {k.capitalize(): v for k, v in config.guidance.items()})
+   print("Guidance:", ", ".join([f"{k.capitalize()} - {v}" for k, v in config.guidance.items()]))
+   print(f"\tDiversity: {diversity_score:0.3f} ",
+         f"Accuracy: {classifier_accuracy:0.3f} ",
+         f"Gen. PPL: {generative_ppl:0.3f}")
+   print(f"Generated {len(samples)} sentences.")
+   with open(config.eval.generated_samples_path, 'w') as f:
+     json.dump(
+       {
+         'generated_seqs': samples,
+       },
+       f, indent=4)  # type: ignore
+   results_df = pd.DataFrame.from_records(result_dicts)
+   results_df.to_csv(config.eval.results_csv_path)
+
+
+ if __name__ == '__main__':
+   main()
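For intuition, `compute_diversity` multiplies the distinct-n-gram ratios for n = 2, 3, 4, so a corpus that repeats itself scores near 0 and a fully distinct one scores 1. A hypothetical call (assumes the script is importable as a module and that spaCy's en_core_web_sm model is installed):

# Hypothetical usage of compute_diversity from the script above.
# Requires: python -m spacy download en_core_web_sm
from guidance_eval.amazon_polarity_eval import compute_diversity

sentences = ["great product, works well",
             "terrible battery, would not buy again"]
print(compute_diversity(sentences))  # 1.0 here: no 2/3/4-gram repeats across sentences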
guidance_eval/qm9_eval.py ADDED
@@ -0,0 +1,208 @@
+ import json
+ import os
+ import time
+ import typing
+
+ import datasets
+ import hydra
+ import lightning as L
+ import numpy as np
+ import omegaconf
+ import pandas as pd
+ import rdkit
+ import rich.syntax
+ import rich.tree
+ import torch
+ from rdkit import Chem as rdChem
+ from rdkit.Chem import QED
+ from tqdm.auto import tqdm
+
+ import dataloader
+ import diffusion
+
+ rdkit.rdBase.DisableLog('rdApp.error')
+
+ omegaconf.OmegaConf.register_new_resolver(
+   'cwd', os.getcwd)
+ omegaconf.OmegaConf.register_new_resolver(
+   'device_count', torch.cuda.device_count)
+ omegaconf.OmegaConf.register_new_resolver(
+   'eval', eval)
+ omegaconf.OmegaConf.register_new_resolver(
+   'div_up', lambda x, y: (x + y - 1) // y)
+ omegaconf.OmegaConf.register_new_resolver(
+   'if_then_else',
+   lambda condition, x, y: x if condition else y
+ )
+
+
+ def _print_config(
+     config: omegaconf.DictConfig,
+     resolve: bool = True) -> None:
+   """Prints content of DictConfig using Rich library and its tree structure.
+
+   Args:
+     config (DictConfig): Configuration composed by Hydra.
+     resolve (bool): Whether to resolve reference fields of DictConfig.
+   """
+
+   style = 'dim'
+   tree = rich.tree.Tree('CONFIG', style=style,
+                         guide_style=style)
+
+   fields = config.keys()
+   for field in fields:
+     branch = tree.add(field, style=style, guide_style=style)
+
+     config_section = config.get(field)
+     branch_content = str(config_section)
+     if isinstance(config_section, omegaconf.DictConfig):
+       branch_content = omegaconf.OmegaConf.to_yaml(
+         config_section, resolve=resolve)
+
+     branch.add(rich.syntax.Syntax(branch_content, 'yaml'))
+   rich.print(tree)
+
+
+ def get_mol_property_fn(
+     prop: str
+ ) -> typing.Callable[[rdChem.Mol], typing.Union[int, float]]:
+   if prop == 'qed':
+     return QED.qed
+   if prop == 'ring_count':
+     return lambda x_mol: len(rdChem.GetSymmSSSR(x_mol))
+   raise NotImplementedError(
+     f"Property function for {prop} not implemented")
+
+
+ @hydra.main(version_base=None, config_path='../configs',
+             config_name='config')
+ def main(config: omegaconf.DictConfig) -> None:
+   # Reproducibility
+   L.seed_everything(config.seed)
+   os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
+   torch.use_deterministic_algorithms(True)
+   torch.backends.cudnn.benchmark = False
+
+   _print_config(config, resolve=True)
+   print(f"Checkpoint: {config.eval.checkpoint_path}")
+
+   qm9_dataset = datasets.load_dataset(
+     'yairschiff/qm9', trust_remote_code=True,
+     split='train')
+   tokenizer = dataloader.get_tokenizer(config)
+   pretrained = diffusion.Diffusion.load_from_checkpoint(
+     config.eval.checkpoint_path,
+     tokenizer=tokenizer,
+     config=config, logger=False)
+   pretrained.eval()
+   label_col = config.data.label_col
+   pctile_threshold = config.data.label_col_pctile
+   pctile_threshold_value = np.percentile(
+     qm9_dataset[label_col], q=pctile_threshold)
+   # Wrap the column in np.array before comparing; a plain Python list
+   # cannot be compared element-wise against a scalar.
+   above_threshold = np.array(qm9_dataset[label_col])[
+     np.array(qm9_dataset[label_col]) >= pctile_threshold_value]
+   below_threshold = np.array(qm9_dataset[label_col])[
+     np.array(qm9_dataset[label_col]) < pctile_threshold_value]
+   result_dicts = []
+   mol_property_fn = get_mol_property_fn(label_col)
+
+   print(
+     f"All - {label_col.upper()} Mean: {np.mean(qm9_dataset[label_col]):0.3f}, {label_col.upper()} Median: {np.median(qm9_dataset[label_col]):0.3f}")
+   print(
+     f"Below {pctile_threshold}%ile - {label_col.upper()} Mean: {np.mean(below_threshold):0.3f}, {label_col.upper()} Median: {np.median(below_threshold):0.3f}")
+   print(
+     f"Above {pctile_threshold}%ile - {label_col.upper()} Mean: {np.mean(above_threshold):0.3f}, {label_col.upper()} Median: {np.median(above_threshold):0.3f}")
+   result_dicts.append({
+     'Seed': -1,
+     'T': -1,
+     'Num Samples': len(qm9_dataset),
+     'Valid': 1.0,
+     'Unique': 1.0,
+     'Novel': 1.0,
+     f'{label_col.upper()} Mean': np.mean(qm9_dataset[label_col]),
+     f'{label_col.upper()} 25%ile': np.percentile(qm9_dataset[label_col], q=25),
+     f'{label_col.upper()} Median': np.median(qm9_dataset[label_col]),
+     f'{label_col.upper()} 75%ile': np.percentile(qm9_dataset[label_col], q=75),
+     f'Novel {label_col.upper()} Mean': np.mean(qm9_dataset[label_col]),
+     f'Novel {label_col.upper()} 25%ile': np.percentile(qm9_dataset[label_col], q=25),
+     f'Novel {label_col.upper()} Median': np.median(qm9_dataset[label_col]),
+     f'Novel {label_col.upper()} 75%ile': np.percentile(qm9_dataset[label_col], q=75),
+   } | {k.capitalize(): -1 for k, v in config.guidance.items()})
+
+   samples = []
+   for _ in tqdm(
+       range(config.sampling.num_sample_batches),
+       desc='Gen. batches', leave=False):
+     start = time.time()
+     sample = pretrained.sample()
+     # print(f"Batch took {time.time() - start:.2f} seconds.")
+     samples.extend(
+       pretrained.tokenizer.batch_decode(sample))
+   invalids = []
+   valids = []
+   mol_property = []
+   for t in samples:
+     t = t.replace('<bos>', '').replace('<eos>', '').replace('<pad>', '')
+     try:
+       mol = rdChem.MolFromSmiles(t)
+       if mol is None or len(t) == 0:
+         invalids.append(t)
+       else:
+         valids.append(t)
+         mol_property.append(mol_property_fn(mol))
+     except rdkit.Chem.rdchem.KekulizeException as e:
+       print(e)
+       invalids.append(t)
+   valid = len(valids)
+   valid_pct = len(valids) / len(samples)
+   unique = len(set(valids))
+   novel = len(set(valids) - set(qm9_dataset['canonical_smiles']))
+   try:
+     unique_pct = unique / valid
+     novel_pct = novel / valid
+   except ZeroDivisionError:
+     unique_pct, novel_pct = 0., 0.
+   mol_property_novel = [
+     mol_property_fn(rdChem.MolFromSmiles(s))
+     for s in set(valids) - set(qm9_dataset['canonical_smiles'])
+   ]
+   result_dicts.append({
+     'Seed': config.seed,
+     'T': config.sampling.steps,
+     'Num Samples': config.sampling.batch_size * config.sampling.num_sample_batches,
+     'Valid': valid_pct,
+     'Unique': unique_pct,
+     'Novel': novel_pct,
+     f'{label_col.upper()} Mean': np.mean(mol_property) if len(mol_property) > 0 else 0.,
+     f'{label_col.upper()} 25%ile': np.percentile(mol_property, q=25) if len(mol_property) > 0 else 0.,
+     f'{label_col.upper()} Median': np.median(mol_property) if len(mol_property) > 0 else 0.,
+     f'{label_col.upper()} 75%ile': np.percentile(mol_property, q=75) if len(mol_property) > 0 else 0.,
+     f'Novel {label_col.upper()} Mean': np.mean(mol_property_novel) if len(mol_property_novel) > 0 else 0.,
+     f'Novel {label_col.upper()} 25%ile': np.percentile(mol_property_novel, q=25) if len(mol_property_novel) > 0 else 0.,
+     f'Novel {label_col.upper()} Median': np.median(mol_property_novel) if len(mol_property_novel) > 0 else 0.,
+     f'Novel {label_col.upper()} 75%ile': np.percentile(mol_property_novel, q=75) if len(mol_property_novel) > 0 else 0.,
+   } | {k.capitalize(): v for k, v in config.guidance.items()})
+   print("Guidance:", ", ".join([f"{k.capitalize()} - {v}" for k, v in config.guidance.items()]))
+   print(f"\tValid: {valid:,d} / {len(samples):,d} ({100 * valid_pct:0.2f}%) ",
+         f"Unique (of valid): {unique:,d} / {valid:,d} ({100 * unique_pct:0.2f}%) ",
+         f"Novel (of valid): {novel:,d} / {valid:,d} ({100 * novel_pct:0.2f}%)\n",
+         f"\t{label_col.upper()} Mean: {np.mean(mol_property) if len(mol_property) else 0.:0.3f}, {label_col.upper()} Median: {np.median(mol_property) if len(mol_property) else 0.:0.3f}\n",
+         f"\tNovel {label_col.upper()} Mean: {np.mean(mol_property_novel) if len(mol_property_novel) else 0.:0.3f}, Novel {label_col.upper()} Median: {np.median(mol_property_novel) if len(mol_property_novel) else 0.:0.3f}"
+         )
+   print(f"Generated {len(samples)} molecules.")
+   with open(config.eval.generated_samples_path, 'w') as f:
+     json.dump(
+       {
+         'valid': valids,
+         'novel': list(set(valids) - set(qm9_dataset['canonical_smiles'])),
+         f"{label_col}_valid": mol_property,
+         f"{label_col}_novel": mol_property_novel,
+       },
+       f, indent=4)  # type: ignore
+   results_df = pd.DataFrame.from_records(result_dicts)
+   results_df.to_csv(config.eval.results_csv_path)
+
+
+ if __name__ == '__main__':
+   main()
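As a quick illustration of the property functions dispatched by `get_mol_property_fn`, using standard RDKit calls (the SMILES string here is just an example, not from the QM9 data):

# Hypothetical usage of the qed / ring_count property functions.
from rdkit import Chem
from rdkit.Chem import QED

mol = Chem.MolFromSmiles('CCO')    # ethanol, an example molecule
print(QED.qed(mol))                # drug-likeness score in (0, 1)
print(len(Chem.GetSymmSSSR(mol)))  # ring count: 0 for ethanol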
guidance_eval/ten_species_eval.py ADDED
@@ -0,0 +1,585 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import itertools
2
+ import json
3
+ import os
4
+ import typing
5
+
6
+ import datasets
7
+ import hydra
8
+ import lightning as L
9
+ import numpy as np
10
+ import omegaconf
11
+ import pandas as pd
12
+ import rdkit
13
+ import rich.syntax
14
+ import rich.tree
15
+ import scipy
16
+ import torch
17
+ import transformers
18
+ from sklearn.metrics import (
19
+ f1_score,
20
+ matthews_corrcoef,
21
+ precision_score,
22
+ recall_score,
23
+ roc_auc_score
24
+ )
25
+ from tqdm.auto import tqdm
26
+
27
+ import classifier
28
+ import custom_datasets
29
+ import dataloader
30
+ import diffusion
31
+
32
+ rdkit.rdBase.DisableLog('rdApp.error')
33
+
34
+ omegaconf.OmegaConf.register_new_resolver(
35
+ 'cwd', os.getcwd)
36
+ omegaconf.OmegaConf.register_new_resolver(
37
+ 'device_count', torch.cuda.device_count)
38
+ omegaconf.OmegaConf.register_new_resolver(
39
+ 'eval', eval)
40
+ omegaconf.OmegaConf.register_new_resolver(
41
+ 'div_up', lambda x, y: (x + y - 1) // y)
42
+ omegaconf.OmegaConf.register_new_resolver(
43
+ 'if_then_else',
44
+ lambda condition, x, y: x if condition else y
45
+ )
46
+
47
+
48
+ def _print_config(
49
+ config: omegaconf.DictConfig,
50
+ resolve: bool = True) -> None:
51
+ """Prints content of DictConfig using Rich library and its tree structure.
52
+
53
+ Args:
54
+ config (DictConfig): Configuration composed by Hydra.
55
+ resolve (bool): Whether to resolve reference fields of DictConfig.
56
+ """
57
+
58
+ style = 'dim'
59
+ tree = rich.tree.Tree('CONFIG', style=style,
60
+ guide_style=style)
61
+
62
+ fields = config.keys()
63
+ for field in fields:
64
+ branch = tree.add(field, style=style, guide_style=style)
65
+
66
+ config_section = config.get(field)
67
+ branch_content = str(config_section)
68
+ if isinstance(config_section, omegaconf.DictConfig):
69
+ branch_content = omegaconf.OmegaConf.to_yaml(
70
+ config_section, resolve=resolve)
71
+
72
+ branch.add(rich.syntax.Syntax(branch_content, 'yaml'))
73
+ rich.print(tree)
74
+
75
+
76
+ def generate_ordered_kmers(
77
+ kmer_length: int
78
+ ) -> typing.List[str]:
79
+ """
80
+ Function that generates all kmers of a given length and orders them by their index
81
+ defined by the kmer_to_index function.
82
+
83
+ Args:
84
+ kmer_length (int): The length of the kmers to generate
85
+
86
+ Returns:
87
+ List[str]: A list of all kmers of the given length ordered by their index
88
+ """
89
+ characters = ["A", "C", "G", "T"]
90
+
91
+ kmers = ["".join(kmer) for kmer in
92
+ itertools.product(characters,
93
+ repeat=kmer_length)]
94
+ ordered_kmers = sorted(kmers, key=kmer_to_index)
95
+
96
+ return ordered_kmers
97
+
98
+
99
+ def kmer_to_index(kmer: str) -> int:
100
+ """
101
+ Function that converts a given kmer to a unique value
102
+ system.
103
+
104
+ Args:
105
+ kmer (str): The given kmer
106
+
107
+ Returns:
108
+ int: The associated unique value
109
+
110
+ Example:
111
+ >>> kmer_to_index("AAC")
112
+ 1
113
+
114
+ """
115
+ mapping = {"A": 0, "C": 1, "G": 2, "T": 3}
116
+ index = 0
117
+ for char in kmer:
118
+ index = index * 4 + mapping[char]
119
+ return index
120
+
121
+
122
+ def compute_kmer_frequencies(
123
+ seqs: typing.List[str], kmer_length: int
124
+ ) -> typing.Tuple[typing.List[float], typing.List[str]]:
125
+ """
126
+ Computes the kmer frequencies in a list of sequences.
127
+ Each element of the output array is the frequency of a given kmer over the whole
128
+ set of sequences.
129
+
130
+ Args:
131
+ seqs (List[str]): List of nucleotide sequences
132
+ kmer_length (int): Length of the kmers
133
+
134
+ Returns:
135
+ List[float]: Kmer frequencies
136
+ List[str]: The kmers
137
+
138
+ Example:
139
+ >>> sequences = ["AGCT", "AAAA"]
140
+ >>> compute_kmer_frequencies(seqs, kmer_length=1)
141
+ ([0.625, 0.125, 0.125, 0.125], ['A', 'C', 'G', 'T'])
142
+ """
143
+
144
+ kmer_counts: typing.Dict[str, int] = {}
145
+ count_kmers_occurrences = 0
146
+ for seq in seqs:
147
+ for i in range(len(seq) - kmer_length + 1):
148
+ kmer = seq[i: i + kmer_length]
149
+ if kmer in kmer_counts:
150
+ kmer_counts[kmer] += 1
151
+ else:
152
+ kmer_counts[kmer] = 1
153
+ count_kmers_occurrences += 1
154
+
155
+ kmer_list = generate_ordered_kmers(kmer_length)
156
+ kmer_frequencies = []
157
+ for kmer in kmer_list:
158
+ try:
159
+ kmer_frequencies.append(
160
+ kmer_counts[kmer] / count_kmers_occurrences)
161
+ except KeyError:
162
+ kmer_frequencies.append(0)
163
+
164
+ return kmer_frequencies, kmer_list
165
+
166
+
167
+ def run_eval_pipeline(
168
+ seqs: typing.Dict[int, typing.List[str]],
169
+ num_samples_per_class: int,
170
+ train_weights_path: str,
171
+ val_weights_path: str,
172
+ eval_classifier_checkpoint_path: str,
173
+ kmer_freqs_path: str
174
+ ):
175
+ # Eval pipeline
176
+ L.seed_everything(42)
177
+
178
+ # Load classifier
179
+ with hydra.initialize(version_base=None,
180
+ config_path='../configs/'):
181
+ classifier_config = hydra.compose(
182
+ config_name='config',
183
+ overrides=[
184
+ 'hydra.output_subdir=null',
185
+ 'hydra.job.chdir=False',
186
+ 'hydra/job_logging=disabled',
187
+ 'hydra/hydra_logging=disabled',
188
+ '+is_eval_classifier=True',
189
+ 'mode=train_classifier',
190
+ 'loader.global_batch_size=32',
191
+ 'loader.eval_global_batch_size=64',
192
+ 'loader.batch_size=2',
193
+ 'loader.eval_batch_size=4',
194
+ 'data=ten_species',
195
+ 'classifier_model=hyenadna-classifier',
196
+ 'classifier_model.hyena_model_name_or_path=LongSafari/hyenadna-small-32k-seqlen-hf',
197
+ 'classifier_backbone=hyenadna',
198
+ 'classifier_model.n_layer=8',
199
+ 'model.length=32768',
200
+ 'diffusion=null',
201
+ 'T=null',
202
+ f"eval.checkpoint_path={eval_classifier_checkpoint_path}"
203
+ ]
204
+ )
205
+ classifier_config = omegaconf.OmegaConf.create(
206
+ classifier_config)
207
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
208
+ classifier_config.data.tokenizer_name_or_path,
209
+ trust_remote_code=True)
210
+ pretrained_classifier = classifier.Classifier.load_from_checkpoint(
211
+ classifier_config.eval.checkpoint_path,
212
+ tokenizer=tokenizer,
213
+ config=classifier_config, logger=False)
214
+ pretrained_classifier.eval()
215
+
216
+ tokenizer = dataloader.get_tokenizer(classifier_config)
217
+ _, val_dl = dataloader.get_dataloaders(
218
+ classifier_config, tokenizer, skip_train=True,
219
+ valid_seed=classifier_config.seed)
220
+
221
+ dataset = datasets.load_dataset(
222
+ 'yairschiff/ten_species',
223
+ split='train',
224
+ # original dataset only has `train` split
225
+ chunk_length=classifier_config.model.length,
226
+ overlap=0,
227
+ trust_remote_code=True)
228
+ dataset = dataset.train_test_split(
229
+ test_size=0.05, seed=42)
230
+ train_dataset = dataset['train']
231
+ val_dataset = dataset['test']
232
+
233
+
234
+ print(f"Len of train set {len(train_dataset) * (2 ** 15):,d}")
235
+ print(f"Len of val set {len(val_dataset) * (2 ** 15):,d}")
236
+
237
+ int_to_species = ['Homo_sapiens', 'Mus_musculus',
238
+ 'Drosophila_melanogaster',
239
+ 'Danio_rerio',
240
+ 'Caenorhabditis_elegans',
241
+ 'Gallus_gallus', 'Gorilla_gorilla',
242
+ 'Felis_catus',
243
+ 'Salmo_trutta', 'Arabidopsis_thaliana']
244
+
245
+ if os.path.exists(train_weights_path):
246
+ train_weights = torch.load(train_weights_path)
247
+ else:
248
+ train_weights = {k: 0 for k in range(10)}
249
+ for i in tqdm(train_dataset, leave=False):
250
+ train_weights[i['species_label']] += 1
251
+ train_weights = {
252
+ k: v / np.sum(list(train_weights.values())) for k, v
253
+ in train_weights.items()}
254
+ torch.save(train_weights, train_weights_path)
255
+ print('Train weights:')
256
+ for k, v in train_weights.items():
257
+ print("\t", int_to_species[k], f"{100 * v:0.2f}")
258
+
259
+ if os.path.exists(val_weights_path):
260
+ val_weights = torch.load(val_weights_path)
261
+ else:
262
+ val_weights = {k: 0 for k in range(10)}
263
+ for i in tqdm(val_dataset, leave=False):
264
+ val_weights[i['species_label']] += 1
265
+ val_weights = {k: v / np.sum(list(val_weights.values()))
266
+ for k, v in val_weights.items()}
267
+ torch.save(val_weights, val_weights_path)
268
+ print('\nVal weights:')
269
+ for k, v in val_weights.items():
270
+ print("\t", int_to_species[k], f"{100 * v:0.2f}")
271
+
272
+
273
+ result_dict = {}
274
+ test_data = []
275
+
276
+ for k, v in seqs.items():
277
+ test_data.extend(
278
+ [
279
+ {
280
+ 'sequence': s.replace('[CLS]', '').replace(
281
+ '[BOS]', '').replace('[MASK]', '').replace(
282
+ '[SEP]', '').replace('[PAD]', '').replace(
283
+ '[UNK]', ''),
284
+ 'species_label': k
285
+ }
286
+ for s in v
287
+ ]
288
+ )
289
+ test_dataset = custom_datasets.ten_species_dataset.TenSpeciesDataset(
290
+ split='test',
291
+ tokenizer=tokenizer,
292
+ max_length=classifier_config.model.length,
293
+ rc_aug=False,
294
+ add_special_tokens=classifier_config.data.add_special_tokens,
295
+ dataset=test_data
296
+ )
297
+
298
+ ## CLASSIFIER ACCURACY
299
+ test_preds = [
300
+ pretrained_classifier.forward(
301
+ test_dataset[i]['input_ids'][None, ...].to(
302
+ 'cuda')).argmax(dim=-1).detach().item()
303
+ for i in
304
+ tqdm(range(len(test_dataset)), desc='Testing')
305
+ ]
306
+ test_preds = np.array(test_preds)
307
+
308
+ test_labels = []
309
+ for k, v in seqs.items():
310
+ test_labels.extend([int(k)] * len(v))
311
+ test_labels = np.array(test_labels)
312
+
313
+ overall_accuracy_score = (test_preds == test_labels).sum() / test_preds.size
314
+ overall_f1_score = f1_score(y_pred=test_preds,
315
+ y_true=test_labels,
316
+ average="macro",
317
+ labels=list(range(classifier_config.data.num_classes)))
318
+ overall_mcc_score = matthews_corrcoef(y_pred=test_preds, y_true=test_labels)
319
+
320
+ print(f"Overall Acc: {overall_accuracy_score:0.2f}")
321
+ print(f"Overall F1: {overall_f1_score:0.2f}")
322
+ print(f"Overall MCC: {overall_mcc_score:0.2f}")
323
+ result_dict['F1'] = overall_f1_score
324
+
325
+ f1_scores = f1_score(
326
+ y_pred=test_preds,
327
+ y_true=test_labels,
328
+ average=None,
329
+ labels=list(range(classifier_config.data.num_classes)))
330
+ precision_scores = precision_score(
331
+ y_pred=test_preds,
332
+ y_true=test_labels,
333
+ average=None,
334
+ labels=list(range(classifier_config.data.num_classes)))
335
+ recall_scores = recall_score(
336
+ y_pred=test_preds,
337
+ y_true=test_labels,
338
+ average=None,
339
+ labels=list(range(classifier_config.data.num_classes)))
340
+
341
+ species_list = ['Homo_sapiens', 'Mus_musculus',
342
+ 'Drosophila_melanogaster',
343
+ 'Danio_rerio',
344
+ 'Caenorhabditis_elegans',
345
+ 'Gallus_gallus', 'Gorilla_gorilla',
346
+ 'Felis_catus',
347
+ 'Salmo_trutta',
348
+ 'Arabidopsis_thaliana']
349
+ for s in range(classifier_config.data.num_classes):
350
+ print(f"Class {s} - {species_list[s]}:")
351
+ print(f" F1: {f1_scores[s]:0.3f}")
352
+ print(f" Precision: {precision_scores[s]:0.3f}")
353
+ print(f" Recall: {recall_scores[s]:0.3f}")
354
+
355
+ ## KMER SPECTRUM
356
+ kmer_lengths = [3, 6]
357
+ kmer_results = {k: [] for k in kmer_lengths}
358
+ if os.path.exists(kmer_freqs_path):
359
+ kmer_freqs = torch.load(kmer_freqs_path)
360
+ else:
361
+ kmer_freqs = {s: {
362
+ kmer_length: {'frequencies': None,
363
+ 'kmers': None} for kmer_length in
364
+ kmer_lengths} for s in range(10)}
365
+ for s in range(10):
366
+ filter_ds = val_dataset.filter(
367
+ lambda x: x['species_label'] == s,
368
+ num_proc=len(os.sched_getaffinity(0)))
369
+ print(f"Computing kmer frequencies for species class {s}")
370
+ for kmer_length in kmer_lengths:
371
+ kmer_frequencies_gt, kmer_list = compute_kmer_frequencies(
372
+ seqs=filter_ds['sequence'],
373
+ kmer_length=kmer_length
374
+ )
375
+ kmer_freqs[s][kmer_length]['frequencies'] = kmer_frequencies_gt
376
+ kmer_freqs[s][kmer_length]['kmers'] = kmer_list
377
+ torch.save(kmer_freqs, kmer_freqs_path)
378
+ for s in range(10):
379
+ print(f"Species class {s}")
380
+ mean_js_divergence = 0
381
+ for kmer_length in kmer_lengths:
382
+ kmer_frequencies_gt = kmer_freqs[s][kmer_length]['frequencies']
383
+ kmer_frequencies_generated, kmer_list = compute_kmer_frequencies(
384
+ seqs=[i['sequence'] for i in test_data if
385
+ i['species_label'] == s],
386
+ kmer_length=kmer_length
387
+ )
388
+
389
+ js_divergence = np.sum(
390
+ scipy.spatial.distance.jensenshannon(
391
+ kmer_frequencies_gt,
392
+ kmer_frequencies_generated)
393
+ )
394
+ kmer_results[kmer_length].append(js_divergence)
395
+ mean_js_divergence += js_divergence
396
+ print(
397
+ f"\tJS divergence with k={kmer_length} : {js_divergence}")
398
+ print(
399
+ f"\tMean JS divergence : {mean_js_divergence / len(kmer_lengths):0.2f}")
400
+
401
+ for k, v in kmer_results.items():
402
+ weighted_kmer_js = (np.array(v) * np.array(
403
+ list(val_weights.values()))).sum()
404
+ print(
405
+ f"Weighted mean JS divergence across classes with k={k}: {weighted_kmer_js:0.2f}")
406
+ result_dict[f"{k}mer JS"] = weighted_kmer_js
407
+
408
+ ## DISCRIMINATOR AUROC
409
+ # Hyperparams
410
+ d_model = 128
411
+ n_layer = 2
412
+
413
+ batch_size = 8
414
+ lr = 1e-4
415
+ epochs = 5
416
+
+   disc_data = [
+     {'sequence': i['sequence'], 'species_label': 0}
+     for i in test_data]
+   for s in range(10):
+     filter_val_ds = val_dataset.filter(
+       lambda x: x['species_label'] == s,
+       num_proc=len(os.sched_getaffinity(0)))
+     indices = np.random.permutation(
+       np.arange(len(filter_val_ds)))[:num_samples_per_class]
+     disc_data.extend(
+       [{'sequence': i['sequence'], 'species_label': 1}
+        for i in filter_val_ds.select(indices)]
+     )
+   print(f"Size of discriminator dataset: {len(disc_data)}")
+   disc_dataset_hf = datasets.Dataset.from_list(
+     disc_data)
+   disc_dataset_hf = disc_dataset_hf.train_test_split(
+     test_size=0.1, seed=42)
+
+   disc_dataset_train = custom_datasets.ten_species_dataset.TenSpeciesDataset(
+     split='train',
+     tokenizer=tokenizer,
+     max_length=classifier_config.model.length,
+     rc_aug=False,
+     add_special_tokens=classifier_config.data.add_special_tokens,
+     dataset=disc_dataset_hf['train']
+   )
+
+   disc_dataset_val = custom_datasets.ten_species_dataset.TenSpeciesDataset(
+     split='test',
+     tokenizer=tokenizer,
+     max_length=classifier_config.model.length,
+     rc_aug=False,
+     add_special_tokens=classifier_config.data.add_special_tokens,
+     dataset=disc_dataset_hf['test']
+   )
+
+   disc_train_dl = torch.utils.data.DataLoader(
+     disc_dataset_train,
+     batch_size=batch_size,
+     num_workers=0,
+     pin_memory=True,
+     shuffle=True)
+
+   disc_val_dl = torch.utils.data.DataLoader(
+     disc_dataset_val,
+     batch_size=batch_size,
+     num_workers=0,
+     pin_memory=True,
+     shuffle=False)
+
+   hyena_config = transformers.AutoConfig.from_pretrained(
+     'LongSafari/hyenadna-small-32k-seqlen-hf',
+     d_model=d_model,
+     n_layer=n_layer,
+     trust_remote_code=True)
+   # Only the config is reused; the discriminator itself is trained
+   # from random initialization (from_config loads no weights).
+   disc_model = transformers.AutoModelForSequenceClassification.from_config(
+     hyena_config,
+     pretrained=False,
+     num_labels=2,
+     problem_type='single_label_classification',
+     trust_remote_code=True)
+
+   optimizer = torch.optim.AdamW(
+     disc_model.parameters(), lr=lr, weight_decay=0,
+     betas=(0.9, 0.999), eps=1e-8)
+
+   disc_model.to('cuda')
+   losses = []
+   auroc_list = []
+   for ep in tqdm(range(epochs), desc='Epochs'):
+     # Train loop:
+     disc_model.train()
+     train_pbar = tqdm(disc_train_dl, desc='Train',
+                       leave=False)
+     for batch in train_pbar:
+       labels = batch['species_label'].to('cuda')
+       logits = disc_model(
+         batch['input_ids'].to('cuda')).logits
+       loss = torch.nn.functional.cross_entropy(
+         logits.view(-1, logits.size(-1)),
+         labels,
+         ignore_index=-100,
+         reduction='mean')
+       optimizer.zero_grad()
+       loss.backward()
+       optimizer.step()
+       train_pbar.set_postfix({'loss': loss.item()})
+       losses.append(loss.item())
+     # Val loop:
+     disc_model.eval()
+     disc_labels = []
+     disc_preds = []
+     with torch.no_grad():  # no gradients needed during validation
+       for batch in disc_val_dl:
+         disc_labels.append(
+           batch['species_label'].numpy())
+         disc_preds.append(
+           disc_model(
+             batch['input_ids'].to('cuda')
+           ).logits[..., 1].to('cpu').numpy()
+         )
+     disc_labels = np.concatenate(disc_labels)
+     disc_preds = np.concatenate(disc_preds)
+     auroc = roc_auc_score(y_true=disc_labels, y_score=disc_preds)
+     auroc_list.append(auroc)
+     print(f"Ep {ep} - AUROC score {auroc}")
+   # AUROC near 0.5 means generated and real sequences are
+   # indistinguishable to the discriminator (lower is better here).
+   result_dict["Disc AUROC"] = auroc_list[-1]
+   del disc_model
+   print('*****************************')
+   return result_dict
+
+
+ @hydra.main(version_base=None, config_path='../configs',
+             config_name='config')
+ def main(config: omegaconf.DictConfig) -> None:
+   # Reproducibility
+   L.seed_everything(config.seed)
+   os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
+   torch.use_deterministic_algorithms(True)
+   torch.backends.cudnn.benchmark = False
+
+   _print_config(config, resolve=True)
+   print(f"Checkpoint: {config.eval.checkpoint_path}")
+
+   tokenizer = dataloader.get_tokenizer(config)
+   pretrained = diffusion.Diffusion.load_from_checkpoint(
+     config.eval.checkpoint_path,
+     tokenizer=tokenizer,
+     config=config, logger=False)
+   pretrained.eval()
+
+   # Generate samples
+   if not os.path.exists(config.eval.generated_samples_path):
+     samples_per_class = {}
+     classes = range(config.data.num_classes)
+     for species in classes:
+       config.guidance.condition = species
+       print("Guidance:", ", ".join(
+         [f"{k.capitalize()} - {v}" for k, v in config.guidance.items()]))
+       samples = []
+       for _ in tqdm(
+           range(config.sampling.num_sample_batches),
+           desc='Gen. batches', leave=False):
+         sample = pretrained.sample()
+         samples.extend(pretrained.tokenizer.batch_decode(sample))
+       samples_per_class[species] = samples
+     with open(config.eval.generated_samples_path, 'w') as f:
+       json.dump(samples_per_class, f, indent=4)  # type: ignore
+   else:
+     with open(config.eval.generated_samples_path, 'r') as f:
+       samples_per_class = json.load(f)
+     # JSON keys are strings; restore the integer class labels.
+     samples_per_class = {int(k): v for k, v in samples_per_class.items()}
+
+   # Run eval pipeline
+   hydra.core.global_hydra.GlobalHydra.instance().clear()
+   result_dict = run_eval_pipeline(
+     samples_per_class,
+     num_samples_per_class=(config.sampling.num_sample_batches
+                            * config.sampling.batch_size),
+     train_weights_path=config.eval.train_weights_path,
+     val_weights_path=config.eval.val_weights_path,
+     eval_classifier_checkpoint_path=config.eval.eval_classifier_checkpoint_path,
+     kmer_freqs_path=config.eval.kmer_freqs_path)
+   result_dict['Seed'] = config.seed
+   result_dict['T'] = config.sampling.steps
+   result_dict = result_dict | {
+     k.capitalize(): v for k, v in config.guidance.items()}
+   result_dict['Num Samples'] = sum(
+     len(v) for v in samples_per_class.values())
+   results_df = pd.DataFrame.from_records([result_dict])
+   results_df.to_csv(config.eval.results_csv_path)
+
+
+ if __name__ == '__main__':
+   main()
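+ # Example invocation (hypothetical paths; overrides use Hydra's
+ # key=value syntax against the config keys referenced above):
+ #   python guidance_eval/ten_species_eval.py \
+ #     eval.checkpoint_path=/path/to/last.ckpt \
+ #     sampling.steps=128 seed=1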
main.py ADDED
@@ -0,0 +1,262 @@
+ import json
+ import os
+
+ import fsspec
+ import hydra
+ import lightning as L
+ import omegaconf
+ import rich.syntax
+ import rich.tree
+ import torch
+ from tqdm import tqdm
+ from datasets import load_from_disk
+ import pdb
+
+ import classifier
+ import dataloader
+ import diffusion
+ import eval_utils
+ import utils
+
+ omegaconf.OmegaConf.register_new_resolver(
+   'cwd', os.getcwd)
+ omegaconf.OmegaConf.register_new_resolver(
+   'device_count', torch.cuda.device_count)
+ omegaconf.OmegaConf.register_new_resolver(
+   'eval', eval)
+ omegaconf.OmegaConf.register_new_resolver(
+   'div_up', lambda x, y: (x + y - 1) // y)
+ omegaconf.OmegaConf.register_new_resolver(
+   'if_then_else',
+   lambda condition, x, y: x if condition else y
+ )
+
+
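+ # The resolvers registered above can be referenced from the YAML
+ # configs; illustrative usage (keys below are placeholders, not taken
+ # from the shipped configs):
+ #   devices: ${device_count:}
+ #   grad_accum: ${div_up:${global_batch_size},${loader.batch_size}}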
+ def _load_from_checkpoint(config, tokenizer):
+   if 'hf' in config.backbone:
+     return diffusion.Diffusion(
+       config, tokenizer=tokenizer).to('cuda')
+
+   return diffusion.Diffusion.load_from_checkpoint(
+     config.eval.checkpoint_path,
+     tokenizer=tokenizer,
+     config=config, logger=False).to('cuda')
+
+
+ @L.pytorch.utilities.rank_zero_only
+ def _print_config(
+     config: omegaconf.DictConfig,
+     resolve: bool = True,
+     save_cfg: bool = True) -> None:
+   """Prints the content of a DictConfig as a tree using the Rich library.
+
+   Args:
+     config (DictConfig): Configuration composed by Hydra.
+     resolve (bool): Whether to resolve reference fields of DictConfig.
+     save_cfg (bool): Whether to save the configuration tree to a file.
+   """
+
+   style = 'dim'
+   tree = rich.tree.Tree('CONFIG', style=style, guide_style=style)
+
+   fields = config.keys()
+   for field in fields:
+     branch = tree.add(field, style=style, guide_style=style)
+
+     config_section = config.get(field)
+     branch_content = str(config_section)
+     if isinstance(config_section, omegaconf.DictConfig):
+       branch_content = omegaconf.OmegaConf.to_yaml(
+         config_section, resolve=resolve)
+
+     branch.add(rich.syntax.Syntax(branch_content, 'yaml'))
+   rich.print(tree)
+   if save_cfg:
+     with fsspec.open(
+         '{}/config_tree.txt'.format(
+           config.checkpointing.save_dir), 'w') as fp:
+       rich.print(tree, file=fp)
+
+
+ @L.pytorch.utilities.rank_zero_only
+ def _print_batch(train_ds, valid_ds, tokenizer, k=64):
+   for dl_type, dl in [
+       ('train', train_ds), ('valid', valid_ds)]:
+     print(f'Printing {dl_type} dataloader batch.')
+     batch = next(iter(dl))
+     print('Batch input_ids.shape', batch['input_ids'].shape)
+     first = batch['input_ids'][0, :k]
+     last = batch['input_ids'][0, -k:]
+     print(f'First {k} tokens:', tokenizer.decode(first))
+     print('ids:', first)
+     print(f'Last {k} tokens:', tokenizer.decode(last))
+     print('ids:', last)
+
+
+ def _train(config, logger, tokenizer,
+            train_classifier=False):
+   logger.info('Starting Training.')
+   wandb_logger = None
+   if config.get('wandb', None) is not None:
+     wandb_logger = L.pytorch.loggers.WandbLogger(
+       config=omegaconf.OmegaConf.to_object(config),
+       **config.wandb)
+
+   if (config.checkpointing.resume_from_ckpt
+       and config.checkpointing.resume_ckpt_path is not None
+       and utils.fsspec_exists(
+         config.checkpointing.resume_ckpt_path)):
+     ckpt_path = config.checkpointing.resume_ckpt_path
+   else:
+     ckpt_path = None
+
+   # Lightning callbacks
+   callbacks = []
+   if 'callbacks' in config:
+     for _, callback in config.callbacks.items():
+       callbacks.append(hydra.utils.instantiate(callback))
+
+   # train_ds, valid_ds = dataloader.get_dataloaders(
+   #   config, tokenizer)
+   train_dataset = load_from_disk(
+     '/home/tc415/discrete-diffusion-guidance/dataset/3000_400k/train')
+   val_dataset = load_from_disk(
+     '/home/tc415/discrete-diffusion-guidance/dataset/3000_400k/val')
+   test_dataset = load_from_disk(
+     '/home/tc415/discrete-diffusion-guidance/dataset/3000_400k/test')
+
+   data_module = dataloader.CustomDataModule(
+     train_dataset, val_dataset, test_dataset,
+     tokenizer, config, batch_size=config.loader.batch_size)
+   train_ds = data_module.train_dataloader()
+   valid_ds = data_module.val_dataloader()
+
+   if not config.is_vision:
+     _print_batch(train_ds, valid_ds, tokenizer)
+
+   if train_classifier:
+     # This param indicates the classifier will be used for
+     # PPLM / NOS-style guidance
+     # (see: https://arxiv.org/abs/2305.20009).
+     if getattr(config, 'is_pplm_classifier', False):
+       pretrained_model = _load_from_checkpoint(
+         config, tokenizer)
+       if (getattr(config.classifier_model, 'use_encoder_ema', True)
+           and pretrained_model.ema):
+         pretrained_model.load_ema_params()
+       pretrained_backbone = pretrained_model.backbone
+       # Remove the last layer for the classifier
+       if hasattr(pretrained_backbone, 'output_layer'):  # DiT
+         delattr(pretrained_backbone, 'output_layer')
+       # hasattr/delattr do not accept dotted attribute paths, so reach
+       # into the submodule to drop DiMamba's lm_head.
+       if (hasattr(pretrained_backbone, 'model')
+           and hasattr(pretrained_backbone.model, 'lm_head')):  # DiMamba
+         delattr(pretrained_backbone.model, 'lm_head')
+       if getattr(config.classifier_model, 'freeze_encoder', True):
+         for param in pretrained_backbone.parameters():
+           param.requires_grad = False
+     else:
+       pretrained_backbone = None
+
+     # Pass the tokenizer directly, matching the diffusion branch below.
+     model = classifier.Classifier(
+       config,
+       tokenizer=tokenizer,
+       pretrained_backbone=pretrained_backbone)
+   else:
+     model = diffusion.Diffusion(
+       config, tokenizer=tokenizer)
+     # model = diffusion.Diffusion(
+     #   config, tokenizer=valid_ds.tokenizer)
+
+   trainer = hydra.utils.instantiate(
+     config.trainer,
+     default_root_dir=os.getcwd(),
+     callbacks=callbacks,
+     strategy=hydra.utils.instantiate(config.strategy),
+     logger=wandb_logger)
+   trainer.fit(model, train_ds, valid_ds, ckpt_path=ckpt_path)
+
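+ # Minimal sketch of the frozen-encoder classifier pattern used in
+ # _train above (illustrative only; `encoder` stands in for the
+ # diffusion backbone with its output layer removed):
+ #
+ #   class GuidanceClassifier(torch.nn.Module):
+ #     def __init__(self, encoder, d_model, num_classes):
+ #       super().__init__()
+ #       self.encoder = encoder  # frozen diffusion backbone
+ #       self.head = torch.nn.Linear(d_model, num_classes)
+ #
+ #     def forward(self, x, sigma):
+ #       h = self.encoder(x, sigma)       # (batch, length, d_model)
+ #       return self.head(h.mean(dim=1))  # pool over sequence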
+ def _gen_ppl_eval(config, tokenizer):
+   pretrained = _load_from_checkpoint(
+     config=config, tokenizer=tokenizer)
+   pretrained.eval()
+   samples = []
+   for _ in tqdm(range(config.sampling.num_sample_batches),
+                 desc='Gen. batches', leave=False):
+     sample = pretrained.sample()
+     samples.extend(
+       pretrained.tokenizer.batch_decode(sample))
+
+   # Remove padding and mask tokens, then ensure each sample starts
+   # with a BOS token (falling back to CLS when no BOS is defined).
+   tok_bos_token = (tokenizer.bos_token
+                    if tokenizer.bos_token is not None
+                    else tokenizer.cls_token)
+   samples = [
+     s.replace('[PAD]', '').replace('[MASK]', '').strip()
+     for s in samples
+   ]
+   samples = [
+     s if s.startswith(tok_bos_token) else f"{tok_bos_token} {s}"
+     for s in samples
+   ]
+   del pretrained  # free up space for eval
+   print(f"Generated {len(samples)} samples.")
+
+   generative_ppl = eval_utils.compute_generative_ppl(
+     samples,
+     eval_model_name_or_path=config.eval.generative_ppl_model_name_or_path,
+     gen_ppl_eval_batch_size=8,
+     max_length=config.model.length)
+   tokens = tokenizer.batch_encode_plus(
+     samples,
+     return_tensors='pt',
+     add_special_tokens=False,
+     max_length=config.model.length,
+     padding='max_length',
+     truncation=True)['input_ids']
+   # `tokens` is already a tensor (return_tensors='pt'), so it can be
+   # passed to torch.unique directly.
+   _, counts = torch.unique(
+     tokens, return_counts=True, sorted=False)
+   entropy = torch.special.entr(
+     counts.float() / counts.sum()).sum().item()
+   with open(config.eval.generated_samples_path, 'w') as f:
+     json.dump({
+       'generative_ppl': generative_ppl,
+       'entropy': entropy,
+       'generated_seqs': samples,
+     }, f, indent=4)  # type: ignore
+   print(f"Entropy: {entropy:0.3f}")
+   print(f"Gen. PPL: {generative_ppl:0.3f}")
+
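+ # torch.special.entr(p) computes -p * ln(p), so `entropy` above is the
+ # Shannon entropy (in nats) of the empirical token distribution.
+ # Tiny worked example (assumed inputs, not from this repo):
+ #   >>> ids = torch.tensor([0, 0, 1, 2])
+ #   >>> _, counts = torch.unique(ids, return_counts=True)
+ #   >>> torch.special.entr(counts.float() / counts.sum()).sum()
+ #   tensor(1.0397)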
+
+ def _ppl_eval(config, tokenizer):
+   print(f"Evaluating perplexity on {config.data.valid}.")
+   pretrained = _load_from_checkpoint(
+     config=config, tokenizer=tokenizer)
+   pretrained.eval()
+   if not config.eval.disable_ema:
+     pretrained.load_ema_params()
+
+   _, valid_ds = dataloader.get_dataloaders(
+     config, tokenizer, skip_train=True, valid_seed=config.seed)
+   ppl = eval_utils.compute_ppl(pretrained, valid_ds)
+   print(f"PPL: {ppl:0.3f}")
+
+
+ @hydra.main(version_base=None, config_path='configs',
+             config_name='config')
+ def main(config):
+   """Main entry point for training and evaluation."""
+   L.seed_everything(config.seed)
+   _print_config(config, resolve=True, save_cfg=True)
+
+   logger = utils.get_logger(__name__)
+   tokenizer = dataloader.get_tokenizer(config)
+
+   if config.mode == 'gen_ppl_eval':
+     _gen_ppl_eval(config, tokenizer)
+   elif config.mode == 'ppl_eval':
+     _ppl_eval(config, tokenizer)
+   elif 'train' in config.mode:
+     _train(config, logger, tokenizer,
+            train_classifier='classifier' in config.mode)
+   else:
+     raise NotImplementedError(f"Mode {config.mode} not implemented.")
+
+
+ if __name__ == '__main__':
+   main()
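+ # Example invocations (hypothetical overrides; `mode` selects the
+ # branch in main above, other names refer to config groups in configs/):
+ #   python main.py mode=train model=small data=ten_species
+ #   python main.py mode=ppl_eval eval.checkpoint_path=/path/to/best.ckpt
+ #   python main.py mode=gen_ppl_eval eval.checkpoint_path=/path/to/best.ckpt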
models/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from . import dit
+ from . import dimamba
+ from . import ema
+ from . import unet
models/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (262 Bytes). View file
 
models/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (260 Bytes). View file
 
models/__pycache__/bindevaluator.cpython-310.pyc ADDED
Binary file (2.63 kB). View file
 
models/__pycache__/dimamba.cpython-310.pyc ADDED
Binary file (27.9 kB). View file
 
models/__pycache__/dimamba.cpython-39.pyc ADDED
Binary file (27.6 kB). View file
 
models/__pycache__/dit.cpython-310.pyc ADDED
Binary file (14.9 kB). View file