"""Module for the preparation and training of a policy network used in the expansion of
nodes in tree search.
This module includes functions for creating training datasets and running the training
process for the policy network.
"""
import warnings
from pathlib import Path
from typing import List, Union
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.utils.data import random_split
from torch_geometric.data.lightning import LightningDataset
from synplan.ml.networks.policy import PolicyNetwork
from synplan.ml.training.preprocessing import (
FilteringPolicyDataset,
RankingPolicyDataset,
)
from synplan.utils.config import PolicyNetworkConfig
from synplan.utils.logging import DisableLogger, HiddenPrints

warnings.filterwarnings("ignore")


def create_policy_dataset(
reaction_rules_path: str,
molecules_or_reactions_path: str,
output_path: str,
dataset_type: str = "filtering",
batch_size: int = 100,
num_cpus: int = 1,
training_data_ratio: float = 0.8,
) -> LightningDataset:
"""
Create a training dataset for a policy network.
:param reaction_rules_path: Path to the reaction rules file.
:param molecules_or_reactions_path: Path to the molecules or reactions file used to create the training set.
:param output_path: Path to store the processed dataset.
:param dataset_type: Type of the dataset to be created ('ranking' or 'filtering').
    :param batch_size: The batch size used by the training and validation data loaders.
    :param num_cpus: Number of CPUs to use for data processing.
    :param training_data_ratio: Ratio of training data to total data.
:return: A `LightningDataset` object containing training and validation datasets.
"""
with DisableLogger(), HiddenPrints():
if dataset_type == "filtering":
full_dataset = FilteringPolicyDataset(
reaction_rules_path=reaction_rules_path,
molecules_path=molecules_or_reactions_path,
output_path=output_path,
num_cpus=num_cpus,
)
elif dataset_type == "ranking":
full_dataset = RankingPolicyDataset(
reaction_rules_path=reaction_rules_path,
reactions_path=molecules_or_reactions_path,
output_path=output_path,
)
train_size = int(training_data_ratio * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = random_split(
full_dataset, [train_size, val_size], torch.Generator().manual_seed(42)
)
print(
f"Training set size: {len(train_dataset)}, validation set size: {len(val_dataset)}"
)
datamodule = LightningDataset(
train_dataset,
val_dataset,
batch_size=batch_size,
pin_memory=True,
drop_last=True,
)
return datamodule
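

# Example usage (illustrative sketch only; the file paths below are hypothetical
# placeholders and must point to real reaction rules and molecule files):
#
#     datamodule = create_policy_dataset(
#         reaction_rules_path="reaction_rules.pickle",
#         molecules_or_reactions_path="molecules.smi",
#         output_path="policy_dataset.pt",
#         dataset_type="filtering",
#         batch_size=100,
#         num_cpus=4,
#     )

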
def run_policy_training(
datamodule: LightningDataset,
config: PolicyNetworkConfig,
results_path: str,
weights_file_name: str = "policy_network",
accelerator: str = "gpu",
devices: Union[List[int], str, int] = "auto",
silent: bool = False,
) -> None:
"""
Trains a policy network using a given datamodule and training configuration.
    :param datamodule: A `LightningDataset` datamodule instance responsible for loading, processing, and preparing the training data for the model.
    :param config: The policy network configuration (`PolicyNetworkConfig`) containing the settings for the training process.
    :param results_path: Path to store the training results and logs.
    :param weights_file_name: The name of the weights file to be saved. Default: "policy_network".
    :param accelerator: The accelerator type to use ("cpu", "gpu", "tpu", "hpu", "mps", "auto") or a custom accelerator instance. Default: "gpu".
    :param devices: The devices to use. Can be set to a positive number (int or str), a sequence of device indices (list or str), the value -1 to indicate that all available devices should be used, or "auto" for automatic selection based on the chosen accelerator. Default: "auto".
    :param silent: Run in silent mode with no progress bars. Default: False.
:return: None.
"""
results_path = Path(results_path)
    results_path.mkdir(parents=True, exist_ok=True)
network = PolicyNetwork(
vector_dim=config.vector_dim,
n_rules=datamodule.train_dataset.dataset.num_classes,
batch_size=config.batch_size,
dropout=config.dropout,
num_conv_layers=config.num_conv_layers,
learning_rate=config.learning_rate,
policy_type=config.policy_type,
)
checkpoint = ModelCheckpoint(
dirpath=results_path, filename=weights_file_name, monitor="val_loss", mode="min"
)
    enable_progress_bar = not silent
trainer = Trainer(
accelerator=accelerator,
devices=devices,
max_epochs=config.num_epoch,
callbacks=[checkpoint],
logger=False,
gradient_clip_val=1.0,
enable_progress_bar=enable_progress_bar,
)
if silent:
with DisableLogger(), HiddenPrints():
trainer.fit(network, datamodule)
else:
trainer.fit(network, datamodule)
    balanced_accuracy = round(
        trainer.logged_metrics["train_balanced_accuracy_y_step"].item(), 3
    )
    print(f"Policy network balanced accuracy: {balanced_accuracy}")