Sukanyaaa committed (verified)
Commit b38c7b5 · 1 Parent(s): d7f69ca

Upload 36 files

Files changed (36)
  1. src/__init__.py +0 -0
  2. src/__pycache__/__init__.cpython-310.pyc +0 -0
  3. src/data/__init__.py +0 -0
  4. src/data/__pycache__/__init__.cpython-310.pyc +0 -0
  5. src/data/__pycache__/pinder_datamodule.cpython-310.pyc +0 -0
  6. src/data/components/__init__.py +0 -0
  7. src/data/components/__pycache__/__init__.cpython-310.pyc +0 -0
  8. src/data/components/__pycache__/pinder_dataset.cpython-310.pyc +0 -0
  9. src/data/components/__pycache__/prepare_data.cpython-310.pyc +0 -0
  10. src/data/components/pinder_dataset.py +64 -0
  11. src/data/components/prepare_data.py +175 -0
  12. src/data/pinder_datamodule.py +167 -0
  13. src/eval.py +99 -0
  14. src/models/__init__.py +0 -0
  15. src/models/__pycache__/__init__.cpython-310.pyc +0 -0
  16. src/models/__pycache__/pinder_module.cpython-310.pyc +0 -0
  17. src/models/components/__init__.py +0 -0
  18. src/models/components/__pycache__/__init__.cpython-310.pyc +0 -0
  19. src/models/components/__pycache__/equivariant_mpnn.cpython-310.pyc +0 -0
  20. src/models/components/__pycache__/utils.cpython-310.pyc +0 -0
  21. src/models/components/equivariant_mpnn.py +231 -0
  22. src/models/components/utils.py +100 -0
  23. src/models/pinder_module.py +297 -0
  24. src/train.py +133 -0
  25. src/utils/__init__.py +5 -0
  26. src/utils/__pycache__/__init__.cpython-310.pyc +0 -0
  27. src/utils/__pycache__/instantiators.cpython-310.pyc +0 -0
  28. src/utils/__pycache__/logging_utils.cpython-310.pyc +0 -0
  29. src/utils/__pycache__/pylogger.cpython-310.pyc +0 -0
  30. src/utils/__pycache__/rich_utils.cpython-310.pyc +0 -0
  31. src/utils/__pycache__/utils.cpython-310.pyc +0 -0
  32. src/utils/instantiators.py +56 -0
  33. src/utils/logging_utils.py +57 -0
  34. src/utils/pylogger.py +51 -0
  35. src/utils/rich_utils.py +103 -0
  36. src/utils/utils.py +119 -0
src/__init__.py ADDED
File without changes
src/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (138 Bytes).
 
src/data/__init__.py ADDED
File without changes
src/data/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (143 Bytes).
 
src/data/__pycache__/pinder_datamodule.cpython-310.pyc ADDED
Binary file (6.15 kB).
 
src/data/components/__init__.py ADDED
File without changes
src/data/components/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (154 Bytes).
 
src/data/components/__pycache__/pinder_dataset.cpython-310.pyc ADDED
Binary file (2.09 kB).
 
src/data/components/__pycache__/prepare_data.cpython-310.pyc ADDED
Binary file (5.29 kB).
 
src/data/components/pinder_dataset.py ADDED
@@ -0,0 +1,64 @@
from typing import List

import __main__
import rootutils
import torch
from torch_geometric.data import Dataset

# setup root dir and pythonpath
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
from src.data.components.prepare_data import CropPairedPDB

# graphs were pickled while CropPairedPDB lived in __main__, so expose it there for torch.load
setattr(__main__, "CropPairedPDB", CropPairedPDB)


class PinderDataset(Dataset):
    """PINDER dataset of preprocessed graph files, built on the PyTorch Geometric `Dataset`."""

    def __init__(self, file_paths: List[str]) -> None:
        """Initialize the PinderDataset.

        Args:
            file_paths: List of paths to preprocessed `.pt` graph files.
        """
        super().__init__()
        self.file_paths = file_paths

    @property
    def processed_file_names(self) -> List[str]:
        """Return the processed file names.

        Returns:
            List[str]: List of processed file paths.
        """
        return self.file_paths

    def len(self) -> int:
        """Return the length of the dataset.

        Returns:
            int: Length of the dataset.
        """
        return len(self.processed_file_names)

    def get(self, idx) -> CropPairedPDB:
        """Get the data at the given index.

        Args:
            idx: Index of the data.

        Returns:
            CropPairedPDB: CropPairedPDB object.
        """
        data = torch.load(self.processed_file_names[idx], weights_only=False)
        return data


if __name__ == "__main__":
    file_paths = ["./data/processed/apo/test/1a19__A1_P11540--1a19__B1_P11540.pt"]
    dataset = PinderDataset(file_paths=file_paths)
    print(dataset[0])
src/data/components/prepare_data.py ADDED
@@ -0,0 +1,175 @@
import multiprocessing
import os
from argparse import ArgumentParser
from pathlib import Path
from typing import Optional

import rootutils
import torch
from loguru import logger
from pinder.core import PinderSystem, get_index
from pinder.core.loader.geodata import PairedPDB, structure2tensor
from pinder.core.loader.structure import Structure
from tqdm.auto import tqdm

# setup root dir and pythonpath
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)

try:
    from torch_cluster import knn_graph

    torch_cluster_installed = True
except ImportError:
    logger.warning(
        "torch-cluster is not installed! "
        "Please install the appropriate library for your pytorch installation. "
        "See https://github.com/rusty1s/pytorch_cluster/issues/185 for background."
    )
    torch_cluster_installed = False


def create_lr_files(system_id: str, apo_complex_path: Path, save_path: str):
    """Split a native complex PDB into separate receptor (chain R) and ligand (chain L) files."""
    apo_r_path = os.path.join(save_path, f"apo_r_{system_id}.pdb")
    apo_l_path = os.path.join(save_path, f"apo_l_{system_id}.pdb")
    native_path = apo_complex_path.with_name(apo_complex_path.stem + f"{system_id}.pdb")
    with open(native_path) as infile, open(apo_r_path, "w") as output_r, open(
        apo_l_path, "w"
    ) as output_l:
        for line in infile:
            # Check if the line is an ATOM or HETATM line and has a chain ID at position 21
            if line.startswith("ATOM") or line.startswith("HETATM"):
                chain_id = line[21]
                if chain_id == "R":
                    output_r.write(line)
                elif chain_id == "L":
                    output_l.write(line)
            else:
                # Write other lines (e.g., HEADER, REMARK) to both files
                output_r.write(line)
                output_l.write(line)
    return apo_r_path, apo_l_path


class CropPairedPDB(PairedPDB):
    @classmethod
    def from_crop_system(
        cls,
        system_id: str,
        root: str = "./data/",
        k: int = 10,
        add_edges: bool = True,
        predicted_structures: bool = True,
        split: str = "train",
    ) -> None:
        system = PinderSystem(system_id)
        # Create directories if they do not exist
        for subdir in ["apo", "holo", "predicted"]:
            os.makedirs(Path(root) / "raw" / subdir / split, exist_ok=True)

        try:
            holo_complex, apo_complex, pred_complex = system.create_masked_bound_unbound_complexes(
                renumber_residues=True
            )
            for complex_type, complex_obj in zip(
                ["apo", "holo", "predicted"], [apo_complex, holo_complex, pred_complex]
            ):
                complex_obj.to_pdb(
                    Path(root) / "raw" / complex_type / split / f"{system_id}_complex.pdb"
                )
        except Exception as e:
            logger.error(f"Error in writing PDB files: {e}, {system_id}")
            return None

        if predicted_structures:
            apo_complex = pred_complex
            save_path = os.path.join(root, "processed", "predicted", split)
        else:
            save_path = os.path.join(root, "processed", "apo", split)

        # create the directory if it does not exist
        os.makedirs(save_path, exist_ok=True)

        graph = cls.from_structure_pair(
            holo_complex=holo_complex,
            apo_complex=apo_complex,
            add_edges=add_edges,
            k=k,
        )
        torch.save(graph, os.path.join(save_path, f"{system_id}.pt"))

    @classmethod
    def from_structure_pair(
        cls,
        holo_complex: Structure,
        apo_complex: Structure,
        add_edges: bool = True,
        k: int = 10,
    ) -> PairedPDB:
        def get_structure_props(structure: Structure, start: int, end: Optional[int]):
            calpha = structure.filter("atom_name", mask=["CA"])
            return structure2tensor(
                atom_coordinates=structure.coords[start:end],
                atom_types=structure.atom_array.atom_name[start:end],
                element_types=structure.atom_array.element[start:end],
                residue_coordinates=calpha.coords[start:end],
                residue_types=calpha.atom_array.res_name[start:end],
                residue_ids=calpha.atom_array.res_id[start:end],
            )

        graph = cls()
        r_h = (holo_complex.dataframe["chain_id"] == "R").sum()
        r_a = (apo_complex.dataframe["chain_id"] == "R").sum()

        holo_r_props = get_structure_props(holo_complex, 0, r_h)
        holo_l_props = get_structure_props(holo_complex, r_h, None)
        apo_r_props = get_structure_props(apo_complex, 0, r_a)
        apo_l_props = get_structure_props(apo_complex, r_a, None)

        # apo (unbound) structures are the model inputs, holo (bound) coordinates are the targets
        graph["ligand"].x = apo_l_props["atom_types"]
        graph["ligand"].pos = apo_l_props["atom_coordinates"]
        graph["receptor"].x = apo_r_props["atom_types"]
        graph["receptor"].pos = apo_r_props["atom_coordinates"]
        graph["ligand"].y = holo_l_props["atom_coordinates"]
        graph["receptor"].y = holo_r_props["atom_coordinates"]

        if add_edges and torch_cluster_installed:
            graph["ligand", "ligand"].edge_index = knn_graph(graph["ligand"].pos, k=k)
            graph["receptor", "receptor"].edge_index = knn_graph(graph["receptor"].pos, k=k)

        return graph


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--n_jobs", type=int, default=20)
    parser.add_argument("--k", type=int, default=10)
    parser.add_argument("--predicted_structures", action="store_true")
    parser.add_argument("--split", type=str, default="train")
    args = parser.parse_args()

    predicted_structures = args.predicted_structures

    # get indices for train, validation, and test splits
    indices = get_index()

    if predicted_structures:
        query = '(split == "{split}") and ((apo_R == False and apo_L == False) and (predicted_R==True and predicted_L==True))'
    else:
        query = '(split == "{split}") and (apo_R == True and apo_L == True)'

    system_idx = indices.query(query.format(split=args.split)).reset_index(drop=True)

    system_ids = system_idx.id.tolist()

    def process_system_id(system_id: str):
        graph = CropPairedPDB.from_crop_system(
            system_id,
            predicted_structures=predicted_structures,
            k=args.k,
            split=args.split,
        )
        return graph

    with multiprocessing.Pool(args.n_jobs) as pool:
        results = list(tqdm(pool.imap(process_system_id, system_ids), total=len(system_ids)))
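For a quick single-system run without the argparse/multiprocessing driver above, the class method can be called directly. A minimal sketch (illustrative only; it assumes the `pinder` index and structures for this system id are accessible, and the id is simply the one used by the test file paths elsewhere in this commit):

from src.data.components.prepare_data import CropPairedPDB

# Writes raw PDBs under ./data/raw/{apo,holo,predicted}/test/ and the graph
# file ./data/processed/apo/test/<system_id>.pt
system_id = "1a19__A1_P11540--1a19__B1_P11540"
CropPairedPDB.from_crop_system(
    system_id,
    root="./data/",
    k=10,
    add_edges=True,
    predicted_structures=False,  # use apo structures rather than predicted ones
    split="test",
)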
src/data/pinder_datamodule.py ADDED
@@ -0,0 +1,167 @@
import os
from typing import Any, Dict, Optional

import pandas as pd
import rootutils
from lightning import LightningDataModule
from torch_geometric.data import Dataset
from torch_geometric.loader import DataLoader

rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)

from src.data.components.pinder_dataset import PinderDataset


class PINDERDataModule(LightningDataModule):
    """`LightningDataModule` for the PINDER dataset."""

    def __init__(
        self,
        data_dir: str = "data/processed",
        predicted_structures: bool = False,
        high_quality: bool = False,
        batch_size: int = 1,
        num_workers: int = 0,
        pin_memory: bool = True,
    ) -> None:
        """Initialize the `PINDERDataModule`.

        Args:
            data_dir: Directory containing the processed PINDER data. Defaults to "data/processed".
            predicted_structures: Whether to also include predicted structures. Defaults to False.
            high_quality: Whether to restrict to high-quality systems (not used in this module). Defaults to False.
            batch_size: Batch size. Defaults to 1.
            num_workers: Number of workers for parallel processing. Defaults to 0.
            pin_memory: Whether to pin memory. Defaults to True.
        """
        super().__init__()

        # this line allows to access init params with 'self.hparams' attribute
        # also ensures init params will be stored in ckpt
        self.save_hyperparameters(logger=False)

        # get metadata
        metadata = pd.read_csv(os.path.join(self.hparams.data_dir, "metadata.csv"))

        def get_files(split: str, complex_types: list) -> list:
            file_df = metadata[
                (metadata["split"] == split) & (metadata["complex"].isin(complex_types))
            ]
            file_df["file_paths"] = file_df.apply(
                lambda row: os.path.join(
                    "./data/processed", row["complex"], row["split"], row["file_paths"]
                ),
                axis=1,
            )
            return file_df["file_paths"].tolist()

        complex_types = ["apo", "predicted"] if self.hparams.predicted_structures else ["apo"]
        self.train_files = get_files("train", complex_types)
        self.val_files = get_files("val", complex_types)
        self.test_files = get_files("test", complex_types)

        self.data_train: Optional[Dataset] = None
        self.data_val: Optional[Dataset] = None
        self.data_test: Optional[Dataset] = None

        self.batch_size_per_device = batch_size

    def setup(self, stage: Optional[str] = None) -> None:
        """Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`.

        This method is called by Lightning before `trainer.fit()`, `trainer.validate()`, `trainer.test()`, and
        `trainer.predict()`, so be careful not to execute things like random split twice! Also, it is called after
        `self.prepare_data()` and there is a barrier in between which ensures that all the processes proceed to
        `self.setup()` once the data is prepared and available for use.

        :param stage: The stage to setup. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. Defaults to ``None``.
        """
        # Divide batch size by the number of devices.
        if self.trainer is not None:
            if self.hparams.batch_size % self.trainer.world_size != 0:
                raise RuntimeError(
                    f"Batch size ({self.hparams.batch_size}) is not divisible by the number of devices ({self.trainer.world_size})."
                )
            self.batch_size_per_device = self.hparams.batch_size // self.trainer.world_size

        # load and split datasets only if not loaded already
        if not self.data_train and not self.data_val and not self.data_test:
            self.data_train = PinderDataset(self.train_files)
            self.data_val = PinderDataset(self.val_files)
            self.data_test = PinderDataset(self.test_files)

    def train_dataloader(self) -> DataLoader:
        """Create and return the train dataloader.

        :return: The train dataloader.
        """
        return DataLoader(
            dataset=self.data_train,
            batch_size=self.batch_size_per_device,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            shuffle=True,
            drop_last=True,
        )

    def val_dataloader(self) -> DataLoader:
        """Create and return the validation dataloader.

        :return: The validation dataloader.
        """
        return DataLoader(
            dataset=self.data_val,
            batch_size=self.batch_size_per_device,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            shuffle=False,
        )

    def test_dataloader(self) -> DataLoader:
        """Create and return the test dataloader.

        :return: The test dataloader.
        """
        return DataLoader(
            dataset=self.data_test,
            batch_size=self.batch_size_per_device,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            shuffle=False,
        )

    def teardown(self, stage: Optional[str] = None) -> None:
        """Lightning hook for cleaning up after `trainer.fit()`, `trainer.validate()`,
        `trainer.test()`, and `trainer.predict()`.

        :param stage: The stage being torn down. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`.
            Defaults to ``None``.
        """
        pass

    def state_dict(self) -> Dict[Any, Any]:
        """Called when saving a checkpoint. Implement to generate and save the datamodule state.

        :return: A dictionary containing the datamodule state that you want to save.
        """
        return {}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Called when loading a checkpoint. Implement to reload datamodule state given datamodule
        `state_dict()`.

        :param state_dict: The datamodule state returned by `self.state_dict()`.
        """
        pass


if __name__ == "__main__":
    datamodule = PINDERDataModule()
    datamodule.setup()
    # print(datamodule.train_files[64])
    train_loader = datamodule.train_dataloader()
    val_loader = datamodule.val_dataloader()
    test_loader = datamodule.test_dataloader()
    print(f"Number of training batches: {len(train_loader)}")
    print(f"Number of validation batches: {len(val_loader)}")
    print(f"Number of test batches: {len(test_loader)}")
    print(next(iter(train_loader)))
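The constructor above assumes a `metadata.csv` in `data_dir` with one row per processed `.pt` file and at least `split`, `complex`, and `file_paths` columns; paths are re-rooted under `./data/processed/<complex>/<split>/`. A minimal sketch of that layout (the first two file names are hypothetical placeholders; the third is the test file referenced elsewhere in this commit):

import os

import pandas as pd

os.makedirs("data/processed", exist_ok=True)
metadata = pd.DataFrame(
    {
        "split": ["train", "val", "test"],
        "complex": ["apo", "apo", "apo"],
        "file_paths": [
            "1abc__A1_P00001--1abc__B1_P00002.pt",  # hypothetical example entry
            "2def__A1_P00003--2def__B1_P00004.pt",  # hypothetical example entry
            "1a19__A1_P11540--1a19__B1_P11540.pt",
        ],
    }
)
metadata.to_csv("data/processed/metadata.csv", index=False)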
src/eval.py ADDED
@@ -0,0 +1,99 @@
from typing import Any, Dict, List, Tuple

import hydra
import rootutils
from lightning import LightningDataModule, LightningModule, Trainer
from lightning.pytorch.loggers import Logger
from omegaconf import DictConfig

rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
# ------------------------------------------------------------------------------------ #
# the setup_root above is equivalent to:
# - adding project root dir to PYTHONPATH
#       (so you don't need to force user to install project as a package)
#       (necessary before importing any local modules e.g. `from src import utils`)
# - setting up PROJECT_ROOT environment variable
#       (which is used as a base for paths in "configs/paths/default.yaml")
#       (this way all filepaths are the same no matter where you run the code)
# - loading environment variables from ".env" in root dir
#
# you can remove it if you:
# 1. either install project as a package or move entry files to project root dir
# 2. set `root_dir` to "." in "configs/paths/default.yaml"
#
# more info: https://github.com/ashleve/rootutils
# ------------------------------------------------------------------------------------ #

from src.utils import (
    RankedLogger,
    extras,
    instantiate_loggers,
    log_hyperparameters,
    task_wrapper,
)

log = RankedLogger(__name__, rank_zero_only=True)


@task_wrapper
def evaluate(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Evaluates given checkpoint on a datamodule testset.

    This method is wrapped in optional @task_wrapper decorator, that controls the behavior during
    failure. Useful for multiruns, saving info about the crash, etc.

    :param cfg: DictConfig configuration composed by Hydra.
    :return: Tuple[dict, dict] with metrics and dict with all instantiated objects.
    """
    assert cfg.ckpt_path

    log.info(f"Instantiating datamodule <{cfg.data._target_}>")
    datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)

    log.info(f"Instantiating model <{cfg.model._target_}>")
    model: LightningModule = hydra.utils.instantiate(cfg.model)

    log.info("Instantiating loggers...")
    logger: List[Logger] = instantiate_loggers(cfg.get("logger"))

    log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
    trainer: Trainer = hydra.utils.instantiate(cfg.trainer, logger=logger)

    object_dict = {
        "cfg": cfg,
        "datamodule": datamodule,
        "model": model,
        "logger": logger,
        "trainer": trainer,
    }

    if logger:
        log.info("Logging hyperparameters!")
        log_hyperparameters(object_dict)

    log.info("Starting testing!")
    trainer.test(model=model, datamodule=datamodule, ckpt_path=cfg.ckpt_path)

    # for predictions use trainer.predict(...)
    # predictions = trainer.predict(model=model, dataloaders=dataloaders, ckpt_path=cfg.ckpt_path)

    metric_dict = trainer.callback_metrics

    return metric_dict, object_dict


@hydra.main(version_base="1.3", config_path="../configs", config_name="eval.yaml")
def main(cfg: DictConfig) -> None:
    """Main entry point for evaluation.

    :param cfg: DictConfig configuration composed by Hydra.
    """
    # apply extra utilities
    # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.)
    extras(cfg)

    evaluate(cfg)


if __name__ == "__main__":
    main()
src/models/__init__.py ADDED
File without changes
src/models/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (145 Bytes).
 
src/models/__pycache__/pinder_module.cpython-310.pyc ADDED
Binary file (8.44 kB).
 
src/models/components/__init__.py ADDED
File without changes
src/models/components/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (156 Bytes).
 
src/models/components/__pycache__/equivariant_mpnn.cpython-310.pyc ADDED
Binary file (6.84 kB).
 
src/models/components/__pycache__/utils.cpython-310.pyc ADDED
Binary file (2.74 kB).
 
src/models/components/equivariant_mpnn.py ADDED
@@ -0,0 +1,231 @@
import rootutils
import torch
from torch import nn
from torch.nn import BatchNorm1d, Linear, Module, ReLU, Sequential
from torch_geometric.loader import DataLoader
from torch_geometric.nn import MessagePassing
from torch_scatter import scatter

# setup root dir and pythonpath
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)

from src.data.components.pinder_dataset import PinderDataset
from src.models.components.utils import (
    compute_euler_angles_from_rotation_matrices,
    compute_rotation_matrix_from_ortho6d,
)


class EquivariantMPNNLayer(MessagePassing):
    def __init__(self, emb_dim=64, out_dim=128, aggr="add"):
        r"""Message Passing Neural Network layer.

        This layer is equivariant to 3D rotations and translations.

        Args:
            emb_dim: (int) - hidden dimension d
            out_dim: (int) - output node feature dimension
            aggr: (str) - aggregation function \oplus (sum/mean/max)
        """
        # Set the aggregation function
        super().__init__(aggr=aggr)

        self.emb_dim = emb_dim

        # MLP that builds messages from (h_i, h_j, ||pos_i - pos_j||)
        self.mlp_msg = Sequential(
            Linear(2 * emb_dim + 1, emb_dim),
            BatchNorm1d(emb_dim),
            ReLU(),
            Linear(emb_dim, emb_dim),
            BatchNorm1d(emb_dim),
            ReLU(),
        )

        # MLP \psi: maps a message to a scalar weight for the coordinate update
        self.mlp_pos = Sequential(
            Linear(emb_dim, emb_dim), BatchNorm1d(emb_dim), ReLU(), Linear(emb_dim, 1)
        )
        # MLP \phi: updates node features from (h_i, aggregated messages)
        self.mlp_upd = Sequential(
            Linear(2 * emb_dim, emb_dim),
            BatchNorm1d(emb_dim),
            ReLU(),
            Linear(emb_dim, emb_dim),
            BatchNorm1d(emb_dim),
            ReLU(),
        )

        self.lin_out = Linear(emb_dim, out_dim)

    def forward(self, data):
        """Update node features and coordinates via one round of message passing.

        Args:
            data: Tuple `(h, pos, edge_index)` with
                h: (n, d) - initial node features
                pos: (n, 3) - initial node coordinates
                edge_index: (2, e) - pairs of edges (i, j)

        Returns:
            out: Tuple `(h_out, pos_out, edge_index)` - updated node features and coordinates
        """
        h, pos, edge_index = data
        h_out, pos_out = self.propagate(edge_index=edge_index, h=h, pos=pos)
        h_out = self.lin_out(h_out)
        return h_out, pos_out, edge_index

    def message(self, h_i, h_j, pos_i, pos_j):
        # Compute distance between nodes i and j (Euclidean distance)
        pos_diff = pos_i - pos_j
        dists = torch.norm(pos_diff, dim=-1).unsqueeze(1)  # (e, 1)

        # Concatenate node features and distance, then embed the message
        msg = torch.cat([h_i, h_j, dists], dim=-1)  # (e, 2d + 1)
        msg = self.mlp_msg(msg)  # (e, d)
        # Scale the relative position by a learned scalar weight
        pos_diff = pos_diff * self.mlp_pos(msg)  # (e, 3)

        return msg, pos_diff

    def aggregate(self, inputs, index):
        """The aggregate function aggregates the messages from neighboring nodes,
        according to the chosen aggregation function ('sum' by default).

        Args:
            inputs: (e, d) - messages m_ij from destination to source nodes
            index: (e, 1) - list of source nodes for each edge/message in input

        Returns:
            aggr_out: (n, d) - aggregated messages m_i
        """
        msgs, pos_diffs = inputs

        msg_aggr = scatter(msgs, index, dim=self.node_dim, reduce=self.aggr)

        pos_aggr = scatter(pos_diffs, index, dim=self.node_dim, reduce="mean")

        return msg_aggr, pos_aggr

    def update(self, aggr_out, h, pos):
        msg_aggr, pos_aggr = aggr_out

        upd_out = self.mlp_upd(torch.cat((h, msg_aggr), dim=-1))

        upd_pos = pos + pos_aggr

        return upd_out, upd_pos

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(emb_dim={self.emb_dim}, aggr={self.aggr})"


class PinderMPNNModel(Module):
    def __init__(self, input_dim=1, emb_dim=64, num_heads=4):
        """Message Passing Neural Network model for receptor-ligand pose prediction.

        This model uses both node features and coordinates as inputs, and
        is invariant to 3D rotations and translations (the constituent MPNN layers
        are equivariant to 3D rotations and translations).

        Args:
            input_dim: (int) - initial node feature dimension d_n
            emb_dim: (int) - hidden dimension d
            num_heads: (int) - number of cross-attention heads; must evenly divide the
                256-dim attention embedding
        """
        super().__init__()

        # Linear projection for initial node features
        self.lin_in_rec = Linear(input_dim, emb_dim)
        self.lin_in_lig = Linear(input_dim, emb_dim)

        # Stack of MPNN layers
        self.receptor_mpnn = Sequential(
            EquivariantMPNNLayer(emb_dim, 128, aggr="mean"),
            EquivariantMPNNLayer(128, 256, aggr="mean"),
            # EquivariantMPNNLayer(256, 512, aggr="mean"),
            # EquivariantMPNNLayer(512, 512, aggr="mean"),
        )
        self.ligand_mpnn = Sequential(
            EquivariantMPNNLayer(64, 128, aggr="mean"),
            EquivariantMPNNLayer(128, 256, aggr="mean"),
            # EquivariantMPNNLayer(256, 512, aggr="mean"),
            # EquivariantMPNNLayer(512, 512, aggr="mean"),
        )

        # Cross-attention layers
        self.rec_cross_attention = nn.MultiheadAttention(256, num_heads, batch_first=True)
        self.lig_cross_attention = nn.MultiheadAttention(256, num_heads, batch_first=True)

        # MLPs for translation prediction
        self.fc_translation_rec = nn.Linear(256 + 3, 3)
        self.fc_translation_lig = nn.Linear(256 + 3, 3)

    def forward(self, batch):
        """The main forward pass of the model.

        Args:
            batch: Heterogeneous PyG batch with "receptor" and "ligand" node stores
                (features `x`, coordinates `pos`) and intra-chain k-NN edge indices.

        Returns:
            receptor_coords, ligand_coords: Tensors of shape (num_atoms, 3) with the
                predicted coordinates after applying the predicted rotation and translation.
        """
        h_receptor = self.lin_in_rec(batch["receptor"].x)
        h_ligand = self.lin_in_lig(batch["ligand"].x)

        pos_receptor = batch["receptor"].pos
        pos_ligand = batch["ligand"].pos

        h_receptor, pos_receptor, _ = self.receptor_mpnn(
            (h_receptor, pos_receptor, batch["receptor", "receptor"].edge_index)
        )

        h_ligand, pos_ligand, _ = self.ligand_mpnn(
            (h_ligand, pos_ligand, batch["ligand", "ligand"].edge_index)
        )

        attn_output_rec, _ = self.rec_cross_attention(h_receptor, h_ligand, h_ligand)

        attn_output_lig, _ = self.lig_cross_attention(h_ligand, h_receptor, h_receptor)

        emb_features_receptor = torch.cat((attn_output_rec, pos_receptor), dim=-1)
        emb_features_ligand = torch.cat((attn_output_lig, pos_ligand), dim=-1)

        translation_vector_r = self.fc_translation_rec(emb_features_receptor)
        translation_vector_l = self.fc_translation_lig(emb_features_ligand)

        ortho_6d_rec = compute_rotation_matrix_from_ortho6d(attn_output_rec)
        ortho_6d_lig = compute_rotation_matrix_from_ortho6d(attn_output_lig)

        receptor_coords = (
            compute_euler_angles_from_rotation_matrices(ortho_6d_rec) * 180 / torch.pi
        )
        ligand_coords = compute_euler_angles_from_rotation_matrices(ortho_6d_lig) * 180 / torch.pi

        receptor_coords = receptor_coords + translation_vector_r
        ligand_coords = ligand_coords + translation_vector_l

        return receptor_coords, ligand_coords


if __name__ == "__main__":
    file_paths = ["./data/processed/apo/test/1a19__A1_P11540--1a19__B1_P11540.pt"]
    dataset = PinderDataset(file_paths=file_paths * 3)
    loader = DataLoader(dataset, batch_size=3, shuffle=False)
    batch = next(iter(loader))
    model = PinderMPNNModel()
    print("Number of parameters:", sum(p.numel() for p in model.parameters()))
    receptor_coords, ligand_coords = model(batch)
    print(receptor_coords.shape)
    print(ligand_coords.shape)
src/models/components/utils.py ADDED
@@ -0,0 +1,100 @@
import torch


# v: batch*n
def normalize_vector(v):
    batch = v.shape[0]
    v_mag = torch.sqrt(v.pow(2).sum(1))  # batch
    eps = torch.tensor(1e-8, device=v.device)
    v_mag = torch.max(v_mag, eps)
    v_mag = v_mag.view(batch, 1).expand(batch, v.shape[1])
    v = v / v_mag
    return v


# u, v: batch*n
def cross_product(u, v):
    batch = u.shape[0]
    i = u[:, 1] * v[:, 2] - u[:, 2] * v[:, 1]
    j = u[:, 2] * v[:, 0] - u[:, 0] * v[:, 2]
    k = u[:, 0] * v[:, 1] - u[:, 1] * v[:, 0]

    out = torch.cat((i.view(batch, 1), j.view(batch, 1), k.view(batch, 1)), 1)  # batch*3

    return out


# poses: batch*6 (continuous 6D rotation representation: two raw 3D vectors)
def compute_rotation_matrix_from_ortho6d(poses):
    x_raw = poses[:, 0:3]  # batch*3
    y_raw = poses[:, 3:6]  # batch*3

    x = normalize_vector(x_raw)  # batch*3
    z = cross_product(x, y_raw)  # batch*3
    z = normalize_vector(z)  # batch*3
    y = cross_product(z, x)  # batch*3

    x = x.view(-1, 3, 1)
    y = y.view(-1, 3, 1)
    z = z.view(-1, 3, 1)
    matrix = torch.cat((x, y, z), 2)  # batch*3*3
    return matrix


# input: batch*4*4 or batch*3*3 rotation matrices
# output: batch*3 tensor of x, y, z angles in radians
# the rotation is in the sequence of x, y, z
def compute_euler_angles_from_rotation_matrices(rotation_matrices):
    batch = rotation_matrices.shape[0]
    R = rotation_matrices
    sy = torch.sqrt(R[:, 0, 0] * R[:, 0, 0] + R[:, 1, 0] * R[:, 1, 0])
    singular = sy < 1e-6
    singular = singular.float()

    x = torch.atan2(R[:, 2, 1], R[:, 2, 2])
    y = torch.atan2(-R[:, 2, 0], sy)
    z = torch.atan2(R[:, 1, 0], R[:, 0, 0])

    xs = torch.atan2(-R[:, 1, 2], R[:, 1, 1])
    ys = torch.atan2(-R[:, 2, 0], sy)
    zs = R[:, 1, 0] * 0

    out_euler = torch.zeros(batch, 3, device=rotation_matrices.device)

    out_euler[:, 0] = x * (1 - singular) + xs * singular
    out_euler[:, 1] = y * (1 - singular) + ys * singular
    out_euler[:, 2] = z * (1 - singular) + zs * singular

    return out_euler


def get_R(x, y, z):
    """Get rotation matrix from three rotation angles (radians), right-handed.

    Args:
        x: rotation angle around x-axis
        y: rotation angle around y-axis
        z: rotation angle around z-axis
    Returns:
        R: [3, 3] rotation matrix.
    """
    # rotation about x
    Rx = torch.tensor(
        [[1, 0, 0], [0, torch.cos(x), -torch.sin(x)], [0, torch.sin(x), torch.cos(x)]],
        device=x.device,
    )
    # rotation about y
    Ry = torch.tensor(
        [[torch.cos(y), 0, torch.sin(y)], [0, 1, 0], [-torch.sin(y), 0, torch.cos(y)]],
        device=y.device,
    )
    # rotation about z
    Rz = torch.tensor(
        [[torch.cos(z), -torch.sin(z), 0], [torch.sin(z), torch.cos(z), 0], [0, 0, 1]],
        device=z.device,
    )

    R = torch.mm(Rz, torch.mm(Ry, Rx))
    return R
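A minimal sanity-check sketch for these helpers (illustrative; it only assumes the functions defined in this file): rotation matrices built from random 6D vectors should be orthonormal with determinant +1, and `get_R` should reproduce a matrix from the extracted x, y, z angles away from the gimbal-lock singularity.

import torch

from src.models.components.utils import (
    compute_euler_angles_from_rotation_matrices,
    compute_rotation_matrix_from_ortho6d,
    get_R,
)

poses = torch.randn(4, 6)
R = compute_rotation_matrix_from_ortho6d(poses)  # (4, 3, 3)

# Orthonormality and right-handedness of the reconstructed rotations
eye = torch.eye(3).expand(4, 3, 3)
print(torch.allclose(R @ R.transpose(1, 2), eye, atol=1e-5))  # expected: True
print(torch.allclose(torch.det(R), torch.ones(4), atol=1e-5))  # expected: True

# Round trip: matrix -> Euler angles (x, y, z order) -> matrix
angles = compute_euler_angles_from_rotation_matrices(R)  # (4, 3), radians
R0 = get_R(angles[0, 0], angles[0, 1], angles[0, 2])
print(torch.allclose(R0, R[0], atol=1e-4))  # expected: True away from sy ~ 0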
src/models/pinder_module.py ADDED
@@ -0,0 +1,297 @@
from typing import Any, Dict, Tuple

import torch
from lightning import LightningModule
from torchmetrics import MeanMetric, MinMetric
from torchmetrics.regression import MeanAbsoluteError, MeanSquaredError


class PinderLitModule(LightningModule):
    """A `LightningModule` for PINDER receptor-ligand coordinate regression.

    A `LightningModule` implements 8 key methods:

    ```python
    def __init__(self):
        # Define initialization code here.

    def setup(self, stage):
        # Things to setup before each stage, 'fit', 'validate', 'test', 'predict'.
        # This hook is called on every process when using DDP.

    def training_step(self, batch, batch_idx):
        # The complete training step.

    def validation_step(self, batch, batch_idx):
        # The complete validation step.

    def test_step(self, batch, batch_idx):
        # The complete test step.

    def predict_step(self, batch, batch_idx):
        # The complete predict step.

    def configure_optimizers(self):
        # Define and configure optimizers and LR schedulers.
    ```

    Docs:
        https://lightning.ai/docs/pytorch/latest/common/lightning_module.html
    """

    def __init__(
        self,
        net: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        scheduler: torch.optim.lr_scheduler,
        compile: bool,
    ) -> None:
        """Initialize a `PinderLitModule`.

        :param net: The model to train.
        :param optimizer: The optimizer to use for training.
        :param scheduler: The learning rate scheduler to use for training.
        :param compile: Whether to compile the model with `torch.compile` before fitting.
        """
        super().__init__()

        # this line allows to access init params with 'self.hparams' attribute
        # also ensures init params will be stored in ckpt
        self.save_hyperparameters(logger=False)

        self.net = net

        # loss function
        self.criterion = torch.nn.MSELoss()

        # metric objects for calculating and averaging errors across batches
        self.train_mse_ligand = MeanSquaredError()
        self.val_mse_ligand = MeanSquaredError()
        self.test_mse_ligand = MeanSquaredError()

        self.train_mse_receptor = MeanSquaredError()
        self.val_mse_receptor = MeanSquaredError()
        self.test_mse_receptor = MeanSquaredError()

        self.train_mae_receptor = MeanAbsoluteError()
        self.val_mae_receptor = MeanAbsoluteError()
        self.test_mae_receptor = MeanAbsoluteError()

        self.train_mae_ligand = MeanAbsoluteError()
        self.val_mae_ligand = MeanAbsoluteError()
        self.test_mae_ligand = MeanAbsoluteError()

        # for averaging loss across batches
        self.train_loss = MeanMetric()
        self.val_loss = MeanMetric()
        self.test_loss = MeanMetric()

        # for tracking best so far validation mse
        self.val_mse_best = MinMetric()

    def forward(self, x) -> Tuple[torch.Tensor, torch.Tensor]:
        """Perform a forward pass through the model `self.net`.

        :param x: A batch of paired receptor/ligand graphs.
        :return: Predicted receptor and ligand coordinates.
        """
        return self.net(x)

    def on_train_start(self) -> None:
        """Lightning hook that is called when training begins."""
        # by default lightning executes validation step sanity checks before training starts,
        # so it's worth to make sure validation metrics don't store results from these checks
        self.val_loss.reset()
        self.val_mse_ligand.reset()
        self.val_mse_receptor.reset()
        self.val_mae_receptor.reset()
        self.val_mae_ligand.reset()
        self.val_mse_best.reset()

    def model_step(
        self, batch: Tuple[torch.Tensor, torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Perform a single model step on a batch of data.

        :param batch: A batch of paired receptor/ligand graphs with target holo coordinates.

        :return: A tuple containing (in order):
            - A tensor of losses.
            - Predicted receptor coordinates.
            - Predicted ligand coordinates.
            - Target receptor coordinates.
            - Target ligand coordinates.
        """
        receptor_coords, ligand_coords = self.forward(batch)
        loss_receptor = self.criterion(receptor_coords, batch["receptor"].y)
        loss_ligand = self.criterion(ligand_coords, batch["ligand"].y)
        loss = loss_receptor + loss_ligand
        return loss, receptor_coords, ligand_coords, batch["receptor"].y, batch["ligand"].y

    def training_step(
        self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int
    ) -> torch.Tensor:
        """Perform a single training step on a batch of data from the training set.

        :param batch: A batch of paired receptor/ligand graphs with target coordinates.
        :param batch_idx: The index of the current batch.
        :return: A tensor of losses between model predictions and targets.
        """
        loss, receptor_coords, ligand_coords, receptor_targets, ligand_targets = self.model_step(
            batch
        )

        # update and log metrics
        self.train_loss(loss)
        self.train_mse_ligand(ligand_coords, ligand_targets)
        self.train_mse_receptor(receptor_coords, receptor_targets)
        self.train_mae_ligand(ligand_coords, ligand_targets)
        self.train_mae_receptor(receptor_coords, receptor_targets)
        self.log("train/loss", self.train_loss, on_step=True, on_epoch=False, prog_bar=True)
        self.log(
            "train/mse_ligand", self.train_mse_ligand, on_step=True, on_epoch=False, prog_bar=True
        )
        self.log(
            "train/mse_receptor",
            self.train_mse_receptor,
            on_step=True,
            on_epoch=False,
            prog_bar=True,
        )
        self.log(
            "train/mae_ligand", self.train_mae_ligand, on_step=True, on_epoch=False, prog_bar=True
        )
        self.log(
            "train/mae_receptor",
            self.train_mae_receptor,
            on_step=True,
            on_epoch=False,
            prog_bar=True,
        )

        # return loss or backpropagation will fail
        return loss

    def on_train_epoch_end(self) -> None:
        "Lightning hook that is called when a training epoch ends."
        pass

    def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None:
        """Perform a single validation step on a batch of data from the validation set.

        :param batch: A batch of paired receptor/ligand graphs with target coordinates.
        :param batch_idx: The index of the current batch.
        """
        loss, receptor_coords, ligand_coords, receptor_targets, ligand_targets = self.model_step(
            batch
        )

        # update and log metrics
        self.val_loss(loss)
        self.val_mse_ligand(ligand_coords, ligand_targets)
        self.val_mse_receptor(receptor_coords, receptor_targets)
        self.val_mae_ligand(ligand_coords, ligand_targets)
        self.val_mae_receptor(receptor_coords, receptor_targets)
        self.log("val/loss", self.val_loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log(
            "val/mse_ligand", self.val_mse_ligand, on_step=False, on_epoch=True, prog_bar=True
        )
        self.log(
            "val/mse_receptor", self.val_mse_receptor, on_step=False, on_epoch=True, prog_bar=True
        )
        self.log(
            "val/mae_ligand", self.val_mae_ligand, on_step=False, on_epoch=True, prog_bar=True
        )
        self.log(
            "val/mae_receptor", self.val_mae_receptor, on_step=False, on_epoch=True, prog_bar=True
        )

    def on_validation_epoch_end(self) -> None:
        "Lightning hook that is called when a validation epoch ends."
        mse = self.val_mse_ligand.compute()  # get current val mse
        self.val_mse_best(mse)  # update best so far val mse
        # log the best value through `.compute()` method, instead of as a metric object
        # otherwise metric would be reset by lightning after each epoch
        self.log("val/acc_best", self.val_mse_best.compute(), sync_dist=True, prog_bar=True)

    def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None:
        """Perform a single test step on a batch of data from the test set.

        :param batch: A batch of paired receptor/ligand graphs with target coordinates.
        :param batch_idx: The index of the current batch.
        """
        loss, receptor_coords, ligand_coords, receptor_targets, ligand_targets = self.model_step(
            batch
        )

        # update and log metrics
        self.test_loss(loss)
        self.test_mse_ligand(ligand_coords, ligand_targets)
        self.test_mse_receptor(receptor_coords, receptor_targets)
        self.test_mae_ligand(ligand_coords, ligand_targets)
        self.test_mae_receptor(receptor_coords, receptor_targets)
        self.log("test/loss", self.test_loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log(
            "test/mse_ligand", self.test_mse_ligand, on_step=False, on_epoch=True, prog_bar=True
        )
        self.log(
            "test/mse_receptor",
            self.test_mse_receptor,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
        )
        self.log(
            "test/mae_ligand", self.test_mae_ligand, on_step=False, on_epoch=True, prog_bar=True
        )
        self.log(
            "test/mae_receptor",
            self.test_mae_receptor,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
        )

    def on_test_epoch_end(self) -> None:
        """Lightning hook that is called when a test epoch ends."""
        pass

    def setup(self, stage: str) -> None:
        """Lightning hook that is called at the beginning of fit (train + validate), validate,
        test, or predict.

        This is a good hook when you need to build models dynamically or adjust something about
        them. This hook is called on every process when using DDP.

        :param stage: Either `"fit"`, `"validate"`, `"test"`, or `"predict"`.
        """
        if self.hparams.compile and stage == "fit":
            self.net = torch.compile(self.net)

    def configure_optimizers(self) -> Dict[str, Any]:
        """Choose what optimizers and learning-rate schedulers to use in your optimization.
        Normally you'd need one. But in the case of GANs or similar you might have multiple.

        Examples:
            https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers

        :return: A dict containing the configured optimizers and learning-rate schedulers to be used for training.
        """
        optimizer = self.hparams.optimizer(params=self.trainer.model.parameters())
        if self.hparams.scheduler is not None:
            scheduler = self.hparams.scheduler(optimizer=optimizer)
            return {
                "optimizer": optimizer,
                "lr_scheduler": {
                    "scheduler": scheduler,
                    "monitor": "val/loss",
                    "interval": "epoch",
                    "frequency": 1,
                },
            }
        return {"optimizer": optimizer}


if __name__ == "__main__":
    _ = PinderLitModule(None, None, None, None)
src/train.py ADDED
@@ -0,0 +1,133 @@
from typing import Any, Dict, List, Optional, Tuple

import hydra
import lightning as L
import rootutils
import torch
from lightning import Callback, LightningDataModule, LightningModule, Trainer
from lightning.pytorch.loggers import Logger
from omegaconf import DictConfig

rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
# ------------------------------------------------------------------------------------ #
# the setup_root above is equivalent to:
# - adding project root dir to PYTHONPATH
#       (so you don't need to force user to install project as a package)
#       (necessary before importing any local modules e.g. `from src import utils`)
# - setting up PROJECT_ROOT environment variable
#       (which is used as a base for paths in "configs/paths/default.yaml")
#       (this way all filepaths are the same no matter where you run the code)
# - loading environment variables from ".env" in root dir
#
# you can remove it if you:
# 1. either install project as a package or move entry files to project root dir
# 2. set `root_dir` to "." in "configs/paths/default.yaml"
#
# more info: https://github.com/ashleve/rootutils
# ------------------------------------------------------------------------------------ #

from src.utils import (
    RankedLogger,
    extras,
    get_metric_value,
    instantiate_callbacks,
    instantiate_loggers,
    log_hyperparameters,
    task_wrapper,
)

log = RankedLogger(__name__, rank_zero_only=True)


@task_wrapper
def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Trains the model. Can additionally evaluate on a testset, using best weights obtained during
    training.

    This method is wrapped in optional @task_wrapper decorator, that controls the behavior during
    failure. Useful for multiruns, saving info about the crash, etc.

    :param cfg: A DictConfig configuration composed by Hydra.
    :return: A tuple with metrics and dict with all instantiated objects.
    """
    # set seed for random number generators in pytorch, numpy and python.random
    if cfg.get("seed"):
        L.seed_everything(cfg.seed, workers=True)

    log.info(f"Instantiating datamodule <{cfg.data._target_}>")
    datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)

    log.info(f"Instantiating model <{cfg.model._target_}>")
    model: LightningModule = hydra.utils.instantiate(cfg.model)

    log.info("Instantiating callbacks...")
    callbacks: List[Callback] = instantiate_callbacks(cfg.get("callbacks"))

    log.info("Instantiating loggers...")
    logger: List[Logger] = instantiate_loggers(cfg.get("logger"))

    log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
    trainer: Trainer = hydra.utils.instantiate(cfg.trainer, callbacks=callbacks, logger=logger)

    object_dict = {
        "cfg": cfg,
        "datamodule": datamodule,
        "model": model,
        "callbacks": callbacks,
        "logger": logger,
        "trainer": trainer,
    }

    if logger:
        log.info("Logging hyperparameters!")
        log_hyperparameters(object_dict)

    if cfg.get("train"):
        log.info("Starting training!")
        trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path"))

    train_metrics = trainer.callback_metrics

    if cfg.get("test"):
        log.info("Starting testing!")
        ckpt_path = trainer.checkpoint_callback.best_model_path
        if ckpt_path == "":
            log.warning("Best ckpt not found! Using current weights for testing...")
            ckpt_path = None
        trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
        log.info(f"Best ckpt path: {ckpt_path}")

    test_metrics = trainer.callback_metrics

    # merge train and test metrics
    metric_dict = {**train_metrics, **test_metrics}

    return metric_dict, object_dict


@hydra.main(version_base="1.3", config_path="../configs", config_name="train.yaml")
def main(cfg: DictConfig) -> Optional[float]:
    """Main entry point for training.

    :param cfg: DictConfig configuration composed by Hydra.
    :return: Optional[float] with optimized metric value.
    """
    # apply extra utilities
    # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.)
    extras(cfg)

    # train the model
    metric_dict, _ = train(cfg)

    # safely retrieve metric value for hydra-based hyperparameter optimization
    metric_value = get_metric_value(
        metric_dict=metric_dict, metric_name=cfg.get("optimized_metric")
    )

    # return optimized metric
    return metric_value


if __name__ == "__main__":
    torch.set_float32_matmul_precision("high")
    main()
src/utils/__init__.py ADDED
@@ -0,0 +1,5 @@
from src.utils.instantiators import instantiate_callbacks, instantiate_loggers
from src.utils.logging_utils import log_hyperparameters
from src.utils.pylogger import RankedLogger
from src.utils.rich_utils import enforce_tags, print_config_tree
from src.utils.utils import extras, get_metric_value, task_wrapper
src/utils/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (546 Bytes).
 
src/utils/__pycache__/instantiators.cpython-310.pyc ADDED
Binary file (1.57 kB).
 
src/utils/__pycache__/logging_utils.cpython-310.pyc ADDED
Binary file (1.96 kB).
 
src/utils/__pycache__/pylogger.cpython-310.pyc ADDED
Binary file (2.55 kB).
 
src/utils/__pycache__/rich_utils.cpython-310.pyc ADDED
Binary file (3.21 kB).
 
src/utils/__pycache__/utils.cpython-310.pyc ADDED
Binary file (3.69 kB).
 
src/utils/instantiators.py ADDED
@@ -0,0 +1,56 @@
from typing import List

import hydra
from lightning import Callback
from lightning.pytorch.loggers import Logger
from omegaconf import DictConfig

from src.utils import pylogger

log = pylogger.RankedLogger(__name__, rank_zero_only=True)


def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]:
    """Instantiates callbacks from config.

    :param callbacks_cfg: A DictConfig object containing callback configurations.
    :return: A list of instantiated callbacks.
    """
    callbacks: List[Callback] = []

    if not callbacks_cfg:
        log.warning("No callback configs found! Skipping...")
        return callbacks

    if not isinstance(callbacks_cfg, DictConfig):
        raise TypeError("Callbacks config must be a DictConfig!")

    for _, cb_conf in callbacks_cfg.items():
        if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf:
            log.info(f"Instantiating callback <{cb_conf._target_}>")
            callbacks.append(hydra.utils.instantiate(cb_conf))

    return callbacks


def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
    """Instantiates loggers from config.

    :param logger_cfg: A DictConfig object containing logger configurations.
    :return: A list of instantiated loggers.
    """
    logger: List[Logger] = []

    if not logger_cfg:
        log.warning("No logger configs found! Skipping...")
        return logger

    if not isinstance(logger_cfg, DictConfig):
        raise TypeError("Logger config must be a DictConfig!")

    for _, lg_conf in logger_cfg.items():
        if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf:
            log.info(f"Instantiating logger <{lg_conf._target_}>")
            logger.append(hydra.utils.instantiate(lg_conf))

    return logger
src/utils/logging_utils.py ADDED
@@ -0,0 +1,57 @@
from typing import Any, Dict

from lightning_utilities.core.rank_zero import rank_zero_only
from omegaconf import OmegaConf

from src.utils import pylogger

log = pylogger.RankedLogger(__name__, rank_zero_only=True)


@rank_zero_only
def log_hyperparameters(object_dict: Dict[str, Any]) -> None:
    """Controls which config parts are saved by Lightning loggers.

    Additionally saves:
        - Number of model parameters

    :param object_dict: A dictionary containing the following objects:
        - `"cfg"`: A DictConfig object containing the main config.
        - `"model"`: The Lightning model.
        - `"trainer"`: The Lightning trainer.
    """
    hparams = {}

    cfg = OmegaConf.to_container(object_dict["cfg"])
    model = object_dict["model"]
    trainer = object_dict["trainer"]

    if not trainer.logger:
        log.warning("Logger not found! Skipping hyperparameter logging...")
        return

    hparams["model"] = cfg["model"]

    # save number of model parameters
    hparams["model/params/total"] = sum(p.numel() for p in model.parameters())
    hparams["model/params/trainable"] = sum(
        p.numel() for p in model.parameters() if p.requires_grad
    )
    hparams["model/params/non_trainable"] = sum(
        p.numel() for p in model.parameters() if not p.requires_grad
    )

    hparams["data"] = cfg["data"]
    hparams["trainer"] = cfg["trainer"]

    hparams["callbacks"] = cfg.get("callbacks")
    hparams["extras"] = cfg.get("extras")

    hparams["task_name"] = cfg.get("task_name")
    hparams["tags"] = cfg.get("tags")
    hparams["ckpt_path"] = cfg.get("ckpt_path")
    hparams["seed"] = cfg.get("seed")

    # send hparams to all loggers
    for logger in trainer.loggers:
        logger.log_hyperparams(hparams)
src/utils/pylogger.py ADDED
@@ -0,0 +1,51 @@
+ import logging
+ from typing import Mapping, Optional
+
+ from lightning_utilities.core.rank_zero import rank_prefixed_message, rank_zero_only
+
+
+ class RankedLogger(logging.LoggerAdapter):
+     """A multi-GPU-friendly python command line logger."""
+
+     def __init__(
+         self,
+         name: str = __name__,
+         rank_zero_only: bool = False,
+         extra: Optional[Mapping[str, object]] = None,
+     ) -> None:
+         """Initializes a multi-GPU-friendly python command line logger that logs on all processes
+         with their rank prefixed in the log message.
+
+         :param name: The name of the logger. Default is ``__name__``.
+         :param rank_zero_only: Whether to force all logs to only occur on the rank zero process. Default is `False`.
+         :param extra: (Optional) A dict-like object which provides contextual information. See `logging.LoggerAdapter`.
+         """
+         logger = logging.getLogger(name)
+         super().__init__(logger=logger, extra=extra)
+         self.rank_zero_only = rank_zero_only
+
+     def log(self, level: int, msg: str, rank: Optional[int] = None, *args, **kwargs) -> None:
+         """Delegate a log call to the underlying logger, after prefixing its message with the rank
+         of the process it's being logged from. If `'rank'` is provided, then the log will only
+         occur on that rank/process.
+
+         :param level: The level to log at. Look at `logging.__init__.py` for more information.
+         :param msg: The message to log.
+         :param rank: The rank to log at.
+         :param args: Additional args to pass to the underlying logging function.
+         :param kwargs: Any additional keyword args to pass to the underlying logging function.
+         """
+         if self.isEnabledFor(level):
+             msg, kwargs = self.process(msg, kwargs)
+             current_rank = getattr(rank_zero_only, "rank", None)
+             if current_rank is None:
+                 raise RuntimeError("The `rank_zero_only.rank` needs to be set before use")
+             msg = rank_prefixed_message(msg, current_rank)
+             if self.rank_zero_only:
+                 if current_rank == 0:
+                     self.logger.log(level, msg, *args, **kwargs)
+             else:
+                 if rank is None:
+                     self.logger.log(level, msg, *args, **kwargs)
+                 elif current_rank == rank:
+                     self.logger.log(level, msg, *args, **kwargs)
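A minimal sketch of how `RankedLogger` behaves; setting `rank_zero_only.rank` by hand stands in for what Lightning does when it initializes distributed state, and the messages are purely illustrative.
```
# Hypothetical sketch: rank-aware logging with RankedLogger.
import logging

from lightning_utilities.core.rank_zero import rank_zero_only

from src.utils.pylogger import RankedLogger

logging.basicConfig(level=logging.INFO)
rank_zero_only.rank = 0  # normally set by Lightning when distributed state is initialized

log = RankedLogger(__name__, rank_zero_only=True)
log.info("only rank 0 prints this")  # suppressed on every other rank

all_ranks_log = RankedLogger(__name__, rank_zero_only=False)
all_ranks_log.info("every rank prints this, rank-prefixed")
all_ranks_log.info("only rank 1 prints this", rank=1)  # per-call rank filter
```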
src/utils/rich_utils.py ADDED
@@ -0,0 +1,103 @@
+ from pathlib import Path
+ from typing import Sequence
+
+ import rich
+ import rich.syntax
+ import rich.tree
+ from hydra.core.hydra_config import HydraConfig
+ from lightning_utilities.core.rank_zero import rank_zero_only
+ from omegaconf import DictConfig, OmegaConf, open_dict
+ from rich.prompt import Prompt
+
+ from src.utils import pylogger
+
+ log = pylogger.RankedLogger(__name__, rank_zero_only=True)
+
+
+ @rank_zero_only
+ def print_config_tree(
+     cfg: DictConfig,
+     print_order: Sequence[str] = (
+         "data",
+         "model",
+         "callbacks",
+         "logger",
+         "trainer",
+         "paths",
+         "extras",
+     ),
+     resolve: bool = False,
+     save_to_file: bool = False,
+ ) -> None:
+     """Prints the contents of a DictConfig as a tree structure using the Rich library.
+
+     :param cfg: A DictConfig composed by Hydra.
+     :param print_order: Determines in what order config components are printed. Default is ``("data", "model",
+         "callbacks", "logger", "trainer", "paths", "extras")``.
+     :param resolve: Whether to resolve reference fields of DictConfig. Default is ``False``.
+     :param save_to_file: Whether to export config to the hydra output folder. Default is ``False``.
+     """
+     style = "dim"
+     tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)
+
+     queue = []
+
+     # add fields from `print_order` to queue
+     for field in print_order:
+         (
+             queue.append(field)
+             if field in cfg
+             else log.warning(
+                 f"Field '{field}' not found in config. Skipping '{field}' config printing..."
+             )
+         )
+
+     # add all the other fields to queue (not specified in `print_order`)
+     for field in cfg:
+         if field not in queue:
+             queue.append(field)
+
+     # generate config tree from queue
+     for field in queue:
+         branch = tree.add(field, style=style, guide_style=style)
+
+         config_group = cfg[field]
+         if isinstance(config_group, DictConfig):
+             branch_content = OmegaConf.to_yaml(config_group, resolve=resolve)
+         else:
+             branch_content = str(config_group)
+
+         branch.add(rich.syntax.Syntax(branch_content, "yaml"))
+
+     # print config tree
+     rich.print(tree)
+
+     # save config tree to file
+     if save_to_file:
+         with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file:
+             rich.print(tree, file=file)
+
+
+ @rank_zero_only
+ def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None:
+     """Prompts user to input tags from command line if no tags are provided in config.
+
+     :param cfg: A DictConfig composed by Hydra.
+     :param save_to_file: Whether to export tags to the hydra output folder. Default is ``False``.
+     """
+     if not cfg.get("tags"):
+         if "id" in HydraConfig().cfg.hydra.job:
+             raise ValueError("Specify tags before launching a multirun!")
+
+         log.warning("No tags provided in config. Prompting user to input tags...")
+         tags = Prompt.ask("Enter a list of comma separated tags", default="dev")
+         tags = [t.strip() for t in tags.split(",") if t != ""]
+
+         with open_dict(cfg):
+             cfg.tags = tags
+
+         log.info(f"Tags: {cfg.tags}")
+
+     if save_to_file:
+         with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file:
+             rich.print(cfg.tags, file=file)
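A minimal sketch of these two helpers, runnable outside of a Hydra run as long as `save_to_file` stays `False` and tags are already present (otherwise both need an active `HydraConfig` and output directory). The config fields shown are assumptions for illustration only.
```
# Hypothetical sketch: enforce tags, then pretty-print a toy config tree.
from lightning_utilities.core.rank_zero import rank_zero_only
from omegaconf import OmegaConf

from src.utils.rich_utils import enforce_tags, print_config_tree

rank_zero_only.rank = 0  # both helpers are @rank_zero_only; Lightning normally sets this

cfg = OmegaConf.create(
    {
        "data": {"batch_size": 32},
        "model": {"lr": 1e-3},
        "trainer": {"max_epochs": 10},
        "tags": ["dev"],
    }
)

enforce_tags(cfg)  # no prompt, since `tags` is already set
print_config_tree(cfg, resolve=True)  # prints the Rich tree; missing groups only trigger warnings
```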
src/utils/utils.py ADDED
@@ -0,0 +1,119 @@
+ import warnings
+ from importlib.util import find_spec
+ from typing import Any, Callable, Dict, Optional, Tuple
+
+ from omegaconf import DictConfig
+
+ from src.utils import pylogger, rich_utils
+
+ log = pylogger.RankedLogger(__name__, rank_zero_only=True)
+
+
+ def extras(cfg: DictConfig) -> None:
+     """Applies optional utilities before the task is started.
+
+     Utilities:
+         - Ignoring python warnings
+         - Setting tags from command line
+         - Rich config printing
+
+     :param cfg: A DictConfig object containing the config tree.
+     """
+     # return if no `extras` config
+     if not cfg.get("extras"):
+         log.warning("Extras config not found! <cfg.extras=null>")
+         return
+
+     # disable python warnings
+     if cfg.extras.get("ignore_warnings"):
+         log.info("Disabling python warnings! <cfg.extras.ignore_warnings=True>")
+         warnings.filterwarnings("ignore")
+
+     # prompt user to input tags from command line if none are provided in the config
+     if cfg.extras.get("enforce_tags"):
+         log.info("Enforcing tags! <cfg.extras.enforce_tags=True>")
+         rich_utils.enforce_tags(cfg, save_to_file=True)
+
+     # pretty print config tree using Rich library
+     if cfg.extras.get("print_config"):
+         log.info("Printing config tree with Rich! <cfg.extras.print_config=True>")
+         rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True)
+
+
+ def task_wrapper(task_func: Callable) -> Callable:
+     """Optional decorator that controls the failure behavior when executing the task function.
+
+     This wrapper can be used to:
+         - make sure loggers are closed even if the task function raises an exception (prevents multirun failure)
+         - save the exception to a `.log` file
+         - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later)
+         - etc. (adjust depending on your needs)
+
+     Example:
+     ```
+     @utils.task_wrapper
+     def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+         ...
+         return metric_dict, object_dict
+     ```
+
+     :param task_func: The task function to be wrapped.
+
+     :return: The wrapped task function.
+     """
+
+     def wrap(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+         # execute the task
+         try:
+             metric_dict, object_dict = task_func(cfg=cfg)
+
+         # things to do if exception occurs
+         except Exception as ex:
+             # save exception to `.log` file
+             log.exception("")
+
+             # some hyperparameter combinations might be invalid or cause out-of-memory errors
+             # so when using hparam search plugins like Optuna, you might want to disable
+             # raising the below exception to avoid multirun failure
+             raise ex
+
+         # things to always do after either success or exception
+         finally:
+             # display output dir path in terminal
+             log.info(f"Output dir: {cfg.paths.output_dir}")
+
+             # always close wandb run (even if exception occurs so multirun won't fail)
+             if find_spec("wandb"):  # check if wandb is installed
+                 import wandb
+
+                 if wandb.run:
+                     log.info("Closing wandb!")
+                     wandb.finish()
+
+         return metric_dict, object_dict
+
+     return wrap
+
+
+ def get_metric_value(metric_dict: Dict[str, Any], metric_name: Optional[str]) -> Optional[float]:
+     """Safely retrieves value of the metric logged in LightningModule.
+
+     :param metric_dict: A dict containing metric values.
+     :param metric_name: If provided, the name of the metric to retrieve.
+     :return: If a metric name was provided, the value of the metric.
+     """
+     if not metric_name:
+         log.info("Metric name is None! Skipping metric value retrieval...")
+         return None
+
+     if metric_name not in metric_dict:
+         raise Exception(
+             f"Metric value not found! <metric_name={metric_name}>\n"
+             "Make sure metric name logged in LightningModule is correct!\n"
+             "Make sure `optimized_metric` name in `hparams_search` config is correct!"
+         )
+
+     metric_value = metric_dict[metric_name].item()
+     log.info(f"Retrieved metric value! <{metric_name}={metric_value}>")
+
+     return metric_value
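To show how these pieces fit together, a minimal sketch of a task function wrapped with `task_wrapper` and read back with `get_metric_value`. The config keys used here (`paths.output_dir`, `optimized_metric`) follow the usual lightning-hydra-template layout and are assumptions, not guaranteed by the uploaded files.
```
# Hypothetical sketch: wrap a task, run it, then pull out the optimized metric.
from typing import Any, Dict, Tuple

import torch
from lightning_utilities.core.rank_zero import rank_zero_only
from omegaconf import DictConfig, OmegaConf

from src.utils.utils import get_metric_value, task_wrapper

rank_zero_only.rank = 0  # normally set by Lightning; the module-level RankedLogger checks it


@task_wrapper
def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    # stand-in for the real training task: report one metric and return it
    metric_dict = {"val/loss": torch.tensor(0.42)}
    return metric_dict, {"cfg": cfg}


cfg = OmegaConf.create({"paths": {"output_dir": "outputs/"}, "optimized_metric": "val/loss"})
metric_dict, _ = train(cfg)
best = get_metric_value(metric_dict, cfg.get("optimized_metric"))  # -> 0.42
```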