Spaces:
Running
Running
| """Module containing commands line scripts for training and planning steps.""" | |
| import os | |
| from pathlib import Path | |
| import warnings | |
| import click | |
| import yaml | |
| from synplan.chem.data.filtering import ReactionFilterConfig, filter_reactions_from_file | |
| from synplan.chem.data.standardizing import ( | |
| ReactionStandardizationConfig, | |
| standardize_reactions_from_file, | |
| ) | |
| from synplan.chem.reaction_routes.clustering import run_cluster_cli | |
| from synplan.chem.reaction_rules.extraction import extract_rules_from_reactions | |
| from synplan.chem.utils import standardize_building_blocks | |
| from synplan.mcts.search import run_search | |
| from synplan.ml.training.reinforcement import run_updating | |
| from synplan.ml.training.supervised import create_policy_dataset, run_policy_training | |
| from synplan.utils.config import ( | |
| PolicyEvaluationConfig, | |
| PolicyNetworkConfig, | |
| RDKitEvaluationConfig, | |
| RandomEvaluationConfig, | |
| RolloutEvaluationConfig, | |
| RuleExtractionConfig, | |
| TreeConfig, | |
| TuningConfig, | |
| ValueNetworkConfig, | |
| ValueNetworkEvaluationConfig, | |
| ) | |
| from synplan.utils.loading import ( | |
| download_all_data, | |
| load_building_blocks, | |
| load_policy_function, | |
| load_reaction_rules, | |
| ) | |
| try: | |
| from importlib.metadata import PackageNotFoundError, version as _dist_version | |
| except Exception: # pragma: no cover | |
| _dist_version = None # type: ignore[assignment] | |
| PackageNotFoundError = Exception # type: ignore[assignment] | |
| def _resolve_cli_version() -> str: | |
| # Prefer installed distribution version | |
| if _dist_version is not None: | |
| try: | |
| return _dist_version("SynPlanner") | |
| except PackageNotFoundError: | |
| pass | |
| # Fallback to package attribute in editable/dev mode | |
| try: | |
| from synplan import __version__ as _pkg_version | |
| return _pkg_version | |
| except Exception: | |
| return "0.0.0+unknown" | |
| warnings.filterwarnings("ignore") | |
def synplan():
    """SynPlanner command line interface."""
    # Top-level CLI entry point; the function body itself performs no work.
    return None
def download_all_data_cli(save_to: str = ".") -> None:
    """Downloads all data for training, planning and benchmarking SynPlanner."""
    # Thin wrapper: forward the destination directory unchanged.
    destination = save_to
    download_all_data(save_to=destination)
def building_blocks_standardizing_cli(input_file: str, output_file: str) -> None:
    """Standardizes building blocks."""
    # Delegate to the library helper, passing both paths by keyword.
    paths = {"input_file": input_file, "output_file": output_file}
    standardize_building_blocks(**paths)
def reaction_standardizing_cli(
    config_path: str, input_file: str, output_file: str, num_cpus: int
) -> None:
    """Standardizes reactions and remove duplicates."""
    # Parse the YAML settings first, then hand everything to the library
    # routine with a fixed batch_size of 100.
    call_kwargs = dict(
        config=ReactionStandardizationConfig.from_yaml(config_path),
        input_reaction_data_path=input_file,
        standardized_reaction_data_path=output_file,
        num_cpus=num_cpus,
        batch_size=100,
    )
    standardize_reactions_from_file(**call_kwargs)
def reaction_filtering_cli(
    config_path: str, input_file: str, output_file: str, num_cpus: int
) -> None:
    """Filters erroneous reactions."""
    # Parameters:
    #   config_path -- YAML file with the reaction filter settings.
    #   input_file  -- reaction data to filter.
    #   output_file -- destination for the filtered reactions.
    #   num_cpus    -- worker count forwarded to filter_reactions_from_file.
    #
    # ``from_yaml`` is called on the class, as the other commands do for their
    # configs, instead of on a throwaway default ``ReactionFilterConfig()``.
    # NOTE(review): assumes ReactionFilterConfig.from_yaml is a class/static
    # method, like the other configs' ``from_yaml`` used in this module.
    reaction_check_config = ReactionFilterConfig.from_yaml(config_path)
    filter_reactions_from_file(
        config=reaction_check_config,
        input_reaction_data_path=input_file,
        filtered_reaction_data_path=output_file,
        num_cpus=num_cpus,
        batch_size=100,
    )
def rule_extracting_cli(
    config_path: str, input_file: str, output_file: str, num_cpus: int
):
    """Reaction rules extraction."""
    # Load the extraction settings from YAML and run extraction over the input
    # reactions, writing rules to output_file (fixed batch_size of 100).
    extract_rules_from_reactions(
        config=RuleExtractionConfig.from_yaml(config_path),
        reaction_data_path=input_file,
        reaction_rules_path=output_file,
        num_cpus=num_cpus,
        batch_size=100,
    )
def ranking_policy_training_cli(
    config_path: str,
    reaction_data: str,
    reaction_rules: str,
    results_dir: str,
    num_cpus: int,
) -> None:
    """Ranking policy network training."""
    # Load the network settings and force the ranking policy type, whatever
    # the YAML file says.
    cfg = PolicyNetworkConfig.from_yaml(config_path)
    cfg.policy_type = "ranking"

    # Build the ranking dataset from reactions + rules; its output path is
    # placed inside results_dir.
    dataset_path = os.path.join(results_dir, "policy_dataset.dt")
    policy_datamodule = create_policy_dataset(
        reaction_rules_path=reaction_rules,
        molecules_or_reactions_path=reaction_data,
        output_path=dataset_path,
        dataset_type="ranking",
        batch_size=cfg.batch_size,
        num_cpus=num_cpus,
    )

    run_policy_training(policy_datamodule, config=cfg, results_path=results_dir)
def filtering_policy_training_cli(
    config_path: str,
    molecule_data: str,
    reaction_rules: str,
    results_dir: str,
    num_cpus: int,
):
    """Filtering policy network training."""
    # Load the network settings and force the filtering policy type, whatever
    # the YAML file says.
    cfg = PolicyNetworkConfig.from_yaml(config_path)
    cfg.policy_type = "filtering"

    # Build the filtering dataset from molecules + rules; its output path is
    # placed inside results_dir.
    dataset_path = os.path.join(results_dir, "policy_dataset.ckpt")
    policy_datamodule = create_policy_dataset(
        reaction_rules_path=reaction_rules,
        molecules_or_reactions_path=molecule_data,
        output_path=dataset_path,
        dataset_type="filtering",
        batch_size=cfg.batch_size,
        num_cpus=num_cpus,
    )

    run_policy_training(policy_datamodule, config=cfg, results_path=results_dir)
def value_network_tuning_cli(
    config_path: str,
    targets: str,
    reaction_rules: str,
    building_blocks: str,
    policy_network: str,
    value_network: str,
    results_dir: str,
):
    """Value network tuning."""
    # Parameters:
    #   config_path     -- YAML with "node_expansion", "value_network", "tree"
    #                      and "tuning" sections (all read below).
    #   targets         -- targets file forwarded to run_updating.
    #   reaction_rules  -- reaction rules file forwarded to run_updating.
    #   building_blocks -- building blocks file forwarded to run_updating.
    #   policy_network  -- policy network weights path (always applied).
    #   value_network   -- value network weights path, or None to use
    #                      <results_dir>/weights/value_network.ckpt.
    #   results_dir     -- root directory passed as results_root.
    with open(config_path, "r", encoding="utf-8") as file:
        config = yaml.safe_load(file)

    policy_config = PolicyNetworkConfig.from_dict(config["node_expansion"])
    policy_config.weights_path = policy_network

    value_config = ValueNetworkConfig.from_dict(config["value_network"])
    # Honour an explicitly supplied value network path (previously it was
    # silently ignored); only fall back to the default location when absent.
    if value_network is not None:
        value_config.weights_path = value_network
    else:
        value_config.weights_path = os.path.join(
            results_dir, "weights", "value_network.ckpt"
        )

    tree_config = TreeConfig.from_dict(config["tree"])
    tuning_config = TuningConfig.from_dict(config["tuning"])

    run_updating(
        targets_path=targets,
        tree_config=tree_config,
        policy_config=policy_config,
        value_config=value_config,
        reinforce_config=tuning_config,
        reaction_rules_path=reaction_rules,
        building_blocks_path=building_blocks,
        results_root=results_dir,
    )
def _build_evaluation_config(
    node_evaluation: dict,
    search_config: dict,
    policy_network: str,
    value_network: str,
    reaction_rules: str,
    building_blocks: str,
):
    """Build the node-evaluation config selected by ``evaluation_type``.

    Args:
        node_evaluation: the (possibly empty) "node_evaluation" config section;
            ``evaluation_type`` defaults to ``"rollout"`` and ``normalize`` to
            ``False`` when absent.
        search_config: merged tree + node_evaluation settings; rollout reads
            ``min_mol_size`` and ``max_depth`` from it (default 6 each).
        policy_network: policy weights path (used by rollout evaluation).
        value_network: value network weights path (required for "gcn").
        reaction_rules: reaction rules path (loaded for rollout evaluation).
        building_blocks: building blocks path (loaded for rollout evaluation).

    Raises:
        ValueError: if ``evaluation_type`` is "gcn" and ``value_network`` is
            None, or if ``evaluation_type`` is not one of the known kinds.
    """
    evaluation_type = node_evaluation.get("evaluation_type", "rollout")
    normalize = node_evaluation.get("normalize", False)

    if evaluation_type == "gcn":
        # Value network evaluation needs trained weights.
        if value_network is None:
            raise ValueError("value_network is required when evaluation_type is 'gcn'")
        return ValueNetworkEvaluationConfig(
            weights_path=value_network,
            normalize=normalize,
        )

    if evaluation_type == "rollout":
        # Rollout evaluation needs the policy, rules and building blocks
        # loaded up front (in this order).
        return RolloutEvaluationConfig(
            policy_network=load_policy_function(weights_path=policy_network),
            reaction_rules=load_reaction_rules(reaction_rules),
            building_blocks=load_building_blocks(building_blocks, standardize=False),
            min_mol_size=search_config.get("min_mol_size", 6),
            max_depth=search_config.get("max_depth", 6),
            normalize=normalize,
        )

    if evaluation_type == "random":
        return RandomEvaluationConfig(normalize=normalize)

    if evaluation_type == "policy":
        return PolicyEvaluationConfig(normalize=normalize)

    if evaluation_type == "rdkit":
        return RDKitEvaluationConfig(
            score_function=node_evaluation.get("score_function", "sascore"),
            normalize=normalize,
        )

    raise ValueError(
        f"Unknown evaluation_type: {evaluation_type}. "
        "Expected one of: 'gcn', 'rollout', 'random', 'policy', 'rdkit'"
    )


def planning_cli(
    config_path: str,
    targets: str,
    reaction_rules: str,
    building_blocks: str,
    policy_network: str,
    value_network: str,
    results_dir: str,
):
    """Retrosynthetic planning."""
    # Parameters:
    #   config_path     -- YAML with "tree", "node_expansion" and optional
    #                      "node_evaluation" sections.
    #   targets         -- targets file forwarded to run_search.
    #   reaction_rules  -- reaction rules path.
    #   building_blocks -- building blocks path.
    #   policy_network  -- policy network weights path.
    #   value_network   -- value network weights path (required for "gcn").
    #   results_dir     -- root directory passed as results_root.
    # Raises ValueError for a missing value network with "gcn" evaluation or
    # an unknown evaluation_type.
    with open(config_path, "r", encoding="utf-8") as file:
        config = yaml.safe_load(file)

    # "node_evaluation" is optional: a missing (or null) section now falls
    # back to the defaults instead of raising before they can apply.
    node_evaluation = config.get("node_evaluation") or {}
    search_config = {**config["tree"], **node_evaluation}

    policy_config = PolicyNetworkConfig.from_dict(
        {**config["node_expansion"], "weights_path": policy_network}
    )

    evaluation_config = _build_evaluation_config(
        node_evaluation,
        search_config,
        policy_network=policy_network,
        value_network=value_network,
        reaction_rules=reaction_rules,
        building_blocks=building_blocks,
    )

    run_search(
        targets_path=targets,
        search_config=search_config,
        policy_config=policy_config,
        evaluation_config=evaluation_config,
        reaction_rules_path=reaction_rules,
        building_blocks_path=building_blocks,
        results_root=results_dir,
    )
def cluster_route_from_file_cli(
    targets: str,
    routes_file: str,
    cluster_results_dir: str,
    perform_subcluster: bool,
    subcluster_results_dir: str,
):
    """Clustering the routes from planning"""
    # NOTE(review): ``targets`` is accepted but not forwarded to run_cluster_cli.
    # The subcluster output directory is only passed on when subclustering is
    # requested; otherwise None is sent.
    subcluster_dir = None
    if perform_subcluster:
        subcluster_dir = subcluster_results_dir

    run_cluster_cli(
        routes_file=routes_file,
        cluster_results_dir=cluster_results_dir,
        perform_subcluster=perform_subcluster,
        subcluster_results_dir=subcluster_dir,
    )
if __name__ == "__main__":
    # Executing this module directly invokes the ``synplan`` entry point.
    synplan()