Gilmullin Almaz committed on
Commit
72a3513
·
1 Parent(s): 3e5f8cc

Refactor code structure for improved readability and maintainability

Browse files
Files changed (50) hide show
  1. .gitattributes +0 -35
  2. Dockerfile +0 -21
  3. README.md +11 -14
  4. app.py +1349 -0
  5. pre-requirements.txt +7 -0
  6. requirements.txt +6 -3
  7. src/streamlit_app.py +0 -40
  8. synplan/__init__.py +3 -0
  9. synplan/chem/__init__.py +3 -0
  10. synplan/chem/data/__init__.py +0 -0
  11. synplan/chem/data/filtering.py +962 -0
  12. synplan/chem/data/standardizing.py +1187 -0
  13. synplan/chem/precursor.py +100 -0
  14. synplan/chem/reaction.py +125 -0
  15. synplan/chem/reaction_routes/__init__.py +0 -0
  16. synplan/chem/reaction_routes/clustering.py +859 -0
  17. synplan/chem/reaction_routes/io.py +286 -0
  18. synplan/chem/reaction_routes/leaving_groups.py +131 -0
  19. synplan/chem/reaction_routes/route_cgr.py +570 -0
  20. synplan/chem/reaction_routes/visualisation.py +903 -0
  21. synplan/chem/reaction_rules/__init__.py +0 -0
  22. synplan/chem/reaction_rules/extraction.py +744 -0
  23. synplan/chem/reaction_rules/manual/__init__.py +6 -0
  24. synplan/chem/reaction_rules/manual/decompositions.py +413 -0
  25. synplan/chem/reaction_rules/manual/transformations.py +532 -0
  26. synplan/chem/utils.py +225 -0
  27. synplan/interfaces/__init__.py +0 -0
  28. synplan/interfaces/cli.py +506 -0
  29. synplan/interfaces/gui.py +1323 -0
  30. synplan/mcts/__init__.py +8 -0
  31. synplan/mcts/evaluation.py +45 -0
  32. synplan/mcts/expansion.py +96 -0
  33. synplan/mcts/node.py +47 -0
  34. synplan/mcts/search.py +199 -0
  35. synplan/mcts/tree.py +635 -0
  36. synplan/ml/__init__.py +0 -0
  37. synplan/ml/networks/__init__.py +0 -0
  38. synplan/ml/networks/modules.py +234 -0
  39. synplan/ml/networks/policy.py +137 -0
  40. synplan/ml/networks/value.py +67 -0
  41. synplan/ml/training/__init__.py +11 -0
  42. synplan/ml/training/preprocessing.py +516 -0
  43. synplan/ml/training/reinforcement.py +379 -0
  44. synplan/ml/training/supervised.py +153 -0
  45. synplan/utils/__init__.py +4 -0
  46. synplan/utils/config.py +543 -0
  47. synplan/utils/files.py +226 -0
  48. synplan/utils/loading.py +151 -0
  49. synplan/utils/logging.py +179 -0
  50. synplan/utils/visualisation.py +1365 -0
.gitattributes DELETED
@@ -1,35 +0,0 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
Dockerfile DELETED
@@ -1,21 +0,0 @@
1
- FROM python:3.9-slim
2
-
3
- WORKDIR /app
4
-
5
- RUN apt-get update && apt-get install -y \
6
- build-essential \
7
- curl \
8
- software-properties-common \
9
- git \
10
- && rm -rf /var/lib/apt/lists/*
11
-
12
- COPY requirements.txt ./
13
- COPY src/ ./src/
14
-
15
- RUN pip3 install -r requirements.txt
16
-
17
- EXPOSE 8501
18
-
19
- HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
20
-
21
- ENTRYPOINT ["streamlit", "run", "src/streamlit_app.py", "--server.port=8501", "--server.address=0.0.0.0"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
README.md CHANGED
@@ -1,20 +1,17 @@
1
  ---
2
- title: Synplanner Dev
3
- emoji: 🚀
4
- colorFrom: red
5
- colorTo: red
6
- sdk: docker
7
- app_port: 8501
8
- tags:
9
- - streamlit
10
  pinned: false
11
- short_description: Developers mode for synplanner
12
  license: mit
 
13
  ---
14
 
15
- # Welcome to Streamlit!
 
16
 
17
- Edit `/src/streamlit_app.py` to customize this app to your heart's desire. :heart:
18
-
19
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
20
- forums](https://discuss.streamlit.io).
 
1
  ---
2
+ title: SynPlanner GUI
3
+ emoji: 🧪
4
+ colorFrom: pink
5
+ colorTo: blue
6
+ sdk: streamlit
7
+ sdk_version: 1.37.0
8
+ app_file: app.py
 
9
  pinned: false
 
10
  license: mit
11
+ python_version: 3.11.9
12
  ---
13
 
14
+ # SynPlanner Graphical User Interface (GUI)
15
+ Try the GUI to find reaction paths...
16
 
17
+ **documentation to be done**
 
 
 
app.py ADDED
@@ -0,0 +1,1349 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import pickle
3
+ import re
4
+ import uuid
5
+ import io
6
+ import zipfile
7
+
8
+ import pandas as pd
9
+ import streamlit as st
10
+ from CGRtools.files import SMILESRead
11
+ from streamlit_ketcher import st_ketcher
12
+ from huggingface_hub import hf_hub_download
13
+ from huggingface_hub.utils import disable_progress_bars
14
+
15
+
16
+ from synplan.mcts.expansion import PolicyNetworkFunction
17
+ from synplan.mcts.search import extract_tree_stats
18
+ from synplan.mcts.tree import Tree
19
+ from synplan.chem.utils import mol_from_smiles
20
+ from synplan.chem.reaction_routes.route_cgr import *
21
+ from synplan.chem.reaction_routes.clustering import *
22
+
23
+ from synplan.utils.visualisation import (
24
+ routes_clustering_report,
25
+ routes_subclustering_report,
26
+ generate_results_html,
27
+ html_top_routes_cluster,
28
+ get_route_svg,
29
+ get_route_svg_from_json,
30
+ get_route_svg_mod
31
+ )
32
+ from synplan.utils.config import TreeConfig, PolicyNetworkConfig
33
+ from synplan.utils.loading import load_reaction_rules, load_building_blocks
34
+
35
+
36
+ import psutil
37
+ import gc
38
+
39
+
40
+ disable_progress_bars("huggingface_hub")
41
+
42
+ smiles_parser = SMILESRead.create_parser(ignore=True)
43
+ DEFAULT_MOL = "c1cc(ccc1Cl)C(CCO)NC(C2(CCN(CC2)c3c4cc[nH]c4ncn3)N)=O"
44
+
45
+
46
+ # --- Helper Functions ---
47
def download_button(
    object_to_download, download_filename, button_text, pickle_it=False
):
    """Build a styled HTML download link for an arbitrary object.

    Params:
    ------
    object_to_download: The object to be downloaded (str, bytes,
        pandas.DataFrame, or anything picklable when pickle_it=True).
    download_filename (str): filename and extension of file, e.g. mydata.csv,
        some_txt_output.txt.
    button_text (str): Text to display on the download button
        (e.g. 'click here to download file').
    pickle_it (bool): If True, serialize the object with pickle first.

    Returns:
    -------
    (str): an HTML snippet (CSS + anchor tag) to render with st.markdown,
        or None if pickling failed.

    Examples:
    --------
    download_button(your_df, 'YOUR_DF.csv', 'Click to download data!')
    download_button(your_str, 'YOUR_STRING.txt', 'Click to download text!')
    """
    if pickle_it:
        try:
            object_to_download = pickle.dumps(object_to_download)
        except pickle.PicklingError as e:
            st.write(e)
            return None
    else:
        if isinstance(object_to_download, bytes):
            pass  # already raw bytes, embed as-is
        elif isinstance(object_to_download, pd.DataFrame):
            object_to_download = object_to_download.to_csv(index=False).encode("utf-8")

    # Strings need encoding before base64; bytes go straight through (EAFP).
    try:
        b64 = base64.b64encode(object_to_download.encode()).decode()
    except AttributeError:
        b64 = base64.b64encode(object_to_download).decode()

    # Letters-only element id so the CSS selector below is always valid.
    # NOTE: raw string r"\d+" — the previous "\d+" was an invalid escape
    # sequence (SyntaxWarning on Python 3.12+).
    button_uuid = str(uuid.uuid4()).replace("-", "")
    button_id = re.sub(r"\d+", "", button_uuid)

    custom_css = f"""
        <style>
            #{button_id} {{
                background-color: rgb(255, 255, 255);
                color: rgb(38, 39, 48);
                text-decoration: none;
                border-radius: 4px;
                border-width: 1px;
                border-style: solid;
                border-color: rgb(230, 234, 241);
                border-image: initial;
            }}
            #{button_id}:hover {{
                border-color: rgb(246, 51, 102);
                color: rgb(246, 51, 102);
            }}
            #{button_id}:active {{
                box-shadow: none;
                background-color: rgb(246, 51, 102);
                color: white;
            }}
        </style> """

    dl_link = (
        custom_css
        + f'<a download="{download_filename}" id="{button_id}" href="data:file/txt;base64,{b64}">{button_text}</a><br></br>'
    )
    return dl_link
119
+
120
+
121
@st.cache_resource
def load_planning_resources_cached():  # Renamed to avoid conflict if main calls it directly
    """Download and cache the SynPlanner planning resources from the HF Hub.

    Returns a 3-tuple of local paths: building blocks (.smi), ranking policy
    network checkpoint (.ckpt), and reaction rules (.pickle). Cached by
    Streamlit so the downloads run at most once per session/server.
    """
    repo = "Laboratoire-De-Chemoinformatique/SynPlanner"
    # (filename, subfolder) pairs, in the order the caller unpacks them.
    specs = (
        ("building_blocks_em_sa_ln.smi", "building_blocks"),
        ("ranking_policy_network.ckpt", "uspto/weights"),
        ("uspto_reaction_rules.pickle", "uspto"),
    )
    return tuple(
        hf_hub_download(
            repo_id=repo,
            filename=fname,
            subfolder=subfolder,
            local_dir=".",
        )
        for fname, subfolder in specs
    )
142
+
143
+
144
+ # --- GUI Sections ---
145
+
146
+
147
def initialize_app():
    """1. Initialization: Setting up the main window, layout, and initial widgets."""
    st.set_page_config(page_title="SynPlanner GUI", page_icon="🧪", layout="wide")

    # Session-state defaults, applied only for keys that do not exist yet so
    # reruns never clobber in-progress state.
    session_defaults = {
        # Planning state
        "planning_done": False,
        "tree": None,
        "res": None,
        "target_smiles": "",  # initial value, might be overwritten by ketcher
        # Clustering state
        "clustering_done": False,
        "clusters": None,
        "reactions_dict": None,
        "num_clusters_setting": 10,  # store the setting used
        "route_cgrs_dict": None,
        "sb_cgrs_dict": None,
        "route_json": None,
        # Subclustering state
        "subclustering_done": False,
        "subclusters": None,
        # Download state (less critical now with direct download links)
        "clusters_downloaded": False,
        # Ketcher persistence
        "ketcher": DEFAULT_MOL,
    }
    for key, default in session_defaults.items():
        if key not in st.session_state:
            st.session_state[key] = default

    intro_text = """
    This is a demo of the graphical user interface of
    [SynPlanner](https://github.com/Laboratoire-de-Chemoinformatique/SynPlanner/).
    SynPlanner is a comprehensive tool for reaction data curation, rule extraction, model training and retrosynthetic planning.

    More information on SynPlanner is available in the [official docs](https://synplanner.readthedocs.io/en/latest/index.html).
    """
    st.title("`SynPlanner GUI`")
    st.write(intro_text)
201
+
202
+
203
def setup_sidebar():
    """2. Sidebar: Handling the widgets and logic within the sidebar area."""
    # st.sidebar.image("img/logo.png") # Assuming img/logo.png is available

    # (title, markdown body) pairs rendered in order.
    sidebar_sections = (
        ("Docs", "https://synplanner.readthedocs.io/en/latest/"),
        (
            "Tutorials",
            "https://github.com/Laboratoire-de-Chemoinformatique/SynPlanner/tree/main/tutorials",
        ),
        (
            "Paper",
            "https://chemrxiv.org/engage/chemrxiv/article-details/66add90bc9c6a5c07ae65796",
        ),
        (
            "Issues",
            "[Report a bug 🐞](https://github.com/Laboratoire-de-Chemoinformatique/SynPlanner/issues/new?assignees=&labels=bug&projects=&template=bug_report.md&title=%5BBUG%5D)",
        ),
    )
    for title, body in sidebar_sections:
        st.sidebar.title(title)
        st.sidebar.markdown(body)
223
+
224
+
225
def handle_molecule_input():
    """3. Molecule Input: Managing the input area for molecule data with two-way synchronization.

    Keeps the SMILES text box and the ketcher drawing widget in sync via
    ``st.session_state.shared_smiles``; returns the SMILES string currently
    selected for planning.
    """
    st.header("Molecule input")
    st.markdown(
        """
    You can provide a molecular structure by either providing:
    * SMILES string + Enter
    * Draw it + Apply
    """
    )

    # Single source of truth for the current structure, seeded from the
    # ketcher state (or the default molecule) on first run.
    if "shared_smiles" not in st.session_state:
        st.session_state.shared_smiles = st.session_state.get("ketcher", DEFAULT_MOL)

    # Bumping this counter changes the ketcher widget key below, which forces
    # Streamlit to re-create (re-render) the drawing when the text changes.
    if "ketcher_render_count" not in st.session_state:
        st.session_state.ketcher_render_count = 0

    def text_input_changed_callback():
        # Text box -> shared state -> force ketcher re-render.
        new_text_value = (
            st.session_state.smiles_text_input_key_for_sync
        )  # Key of the text_input
        if new_text_value != st.session_state.shared_smiles:
            st.session_state.shared_smiles = new_text_value
            st.session_state.ketcher = new_text_value
            st.session_state.ketcher_render_count += 1

    # SMILES Text Input
    st.text_input(
        "SMILES:",
        value=st.session_state.shared_smiles,
        key="smiles_text_input_key_for_sync",  # Unique key for this widget
        on_change=text_input_changed_callback,
        help="Enter SMILES string and press Enter. The drawing will update, and vice-versa.",
    )

    # Fresh key per render generation (see counter above) so edits typed in
    # the text box propagate into the drawing widget.
    ketcher_key = f"ketcher_widget_for_sync_{st.session_state.ketcher_render_count}"
    smile_code_output_from_ketcher = st_ketcher(
        st.session_state.shared_smiles, key=ketcher_key
    )

    # Drawing -> shared state; rerun so the text box picks up the new value.
    if smile_code_output_from_ketcher != st.session_state.shared_smiles:
        st.session_state.shared_smiles = smile_code_output_from_ketcher
        st.session_state.ketcher = smile_code_output_from_ketcher
        st.rerun()

    current_smiles_for_planning = st.session_state.shared_smiles

    # Warn when the structure differs from the one used in the last
    # successful planning run, so stale results are not misread.
    last_planned_smiles = st.session_state.get("target_smiles")
    if (
        last_planned_smiles
        and current_smiles_for_planning != last_planned_smiles
        and st.session_state.get("planning_done", False)
    ):
        st.warning(
            "Molecule structure has changed since the last successful planning run. "
            "Results shown below (if any) are for the previous molecule. "
            "Please re-run planning for the current structure."
        )

    # Ensure st.session_state.ketcher is consistent for other parts of the app
    if st.session_state.get("ketcher") != current_smiles_for_planning:
        st.session_state.ketcher = current_smiles_for_planning

    return current_smiles_for_planning
289
+
290
+
291
def setup_planning_options():
    """4. Planning: Encapsulating the logic related to the "planning" functionality.

    Renders the MCTS option widgets and, when the launch button is pressed,
    resets downstream state, loads resources, runs the tree search and stores
    the tree/statistics in ``st.session_state``.
    """
    st.header("Launch calculation")
    st.markdown(
        """If you modified the structure, please ensure you clicked on `Apply` (bottom right of the molecular editor)."""
    )

    st.markdown(
        f"The molecule SMILES is actually: ``{st.session_state.get('ketcher', DEFAULT_MOL)}``"
    )

    st.subheader("Planning options")
    st.markdown(
        """
    The description of each option can be found in the
    [Retrosynthetic Planning Tutorial](https://synplanner.readthedocs.io/en/latest/tutorial_files/retrosynthetic_planning.html#Configuring-search-tree).
    """
    )

    # Two columns of search-configuration widgets.
    col_options_1, col_options_2 = st.columns(2, gap="medium")
    with col_options_1:
        search_strategy_input = st.selectbox(
            label="Search strategy",
            options=(
                "Expansion first",
                "Evaluation first",
            ),
            index=0,
            key="search_strategy_input",
        )
        ucb_type = st.selectbox(
            label="UCB type",
            options=("uct", "puct", "value"),
            index=0,
            key="ucb_type_input",
        )
        c_ucb = st.number_input(
            "C coefficient of UCB",
            value=0.1,
            placeholder="Type a number...",
            key="c_ucb_input",
        )

    with col_options_2:
        max_iterations = st.slider(
            "Total number of MCTS iterations",
            min_value=50,
            max_value=3000,
            value=1000,
            key="max_iterations_slider",
        )
        max_depth = st.slider(
            "Maximal number of reaction steps",
            min_value=3,
            max_value=9,
            value=6,
            key="max_depth_slider",
        )
        min_mol_size = st.slider(
            "Minimum size of a molecule to be precursor",
            min_value=0,
            max_value=7,
            value=0,
            key="min_mol_size_slider",
            help="Number of non-hydrogen atoms in molecule",
        )

    # Map display labels to the identifiers expected by TreeConfig.
    search_strategy_translator = {
        "Expansion first": "expansion_first",
        "Evaluation first": "evaluation_first",
    }
    search_strategy = search_strategy_translator[search_strategy_input]

    planning_params = {
        "search_strategy": search_strategy,
        "ucb_type": ucb_type,
        "c_ucb": c_ucb,
        "max_iterations": max_iterations,
        "max_depth": max_depth,
        "min_mol_size": min_mol_size,
    }

    if st.button("Start retrosynthetic planning", key="submit_planning_button"):
        # Reset downstream states if replanning
        st.session_state.planning_done = False
        st.session_state.clustering_done = False
        st.session_state.subclustering_done = False
        st.session_state.tree = None
        st.session_state.res = None
        st.session_state.clusters = None
        st.session_state.reactions_dict = None
        st.session_state.subclusters = None
        st.session_state.route_cgrs_dict = None
        st.session_state.sb_cgrs_dict = None
        st.session_state.route_json = None
        active_smile_code = st.session_state.get(
            "ketcher", DEFAULT_MOL
        )  # Get current SMILES
        st.session_state.target_smiles = (
            active_smile_code  # Store the SMILES used for this run
        )

        try:
            target_molecule = mol_from_smiles(active_smile_code, clean_stereo=True)
            if target_molecule is None:
                raise ValueError(f"Could not parse the input SMILES: {active_smile_code}")

            (
                building_blocks_path,
                ranking_policy_weights_path,
                reaction_rules_path,
            ) = load_planning_resources_cached()
            with st.spinner("Running retrosynthetic planning..."):
                with st.status("Loading resources...", expanded=False) as status:
                    st.write("Loading building blocks...")
                    building_blocks = load_building_blocks(
                        building_blocks_path, standardize=False
                    )
                    st.write("Loading reaction rules...")
                    reaction_rules = load_reaction_rules(reaction_rules_path)
                    st.write("Loading policy network...")
                    policy_config = PolicyNetworkConfig(
                        weights_path=ranking_policy_weights_path
                    )
                    policy_function = PolicyNetworkFunction(
                        policy_config=policy_config
                    )
                    status.update(label="Resources loaded!", state="complete")

                tree_config = TreeConfig(
                    search_strategy=planning_params["search_strategy"],
                    evaluation_type="rollout",
                    max_iterations=planning_params["max_iterations"],
                    max_depth=planning_params["max_depth"],
                    min_mol_size=planning_params["min_mol_size"],
                    init_node_value=0.5,
                    ucb_type=planning_params["ucb_type"],
                    c_ucb=planning_params["c_ucb"],
                    silent=True,
                )

                tree = Tree(
                    target=target_molecule,
                    config=tree_config,
                    reaction_rules=reaction_rules,
                    building_blocks=building_blocks,
                    expansion_function=policy_function,
                    evaluation_function=None,
                )

                # Iterating the tree runs the MCTS search one step at a time;
                # the yielded (solved, route_id) pair is only used to drive
                # the progress bar here.
                mcts_progress_text = "Running MCTS iterations..."
                mcts_bar = st.progress(0, text=mcts_progress_text)
                for step, (solved, route_id) in enumerate(tree):
                    progress_value = min(
                        1.0, (step + 1) / planning_params["max_iterations"]
                    )
                    mcts_bar.progress(
                        progress_value,
                        text=f"{mcts_progress_text} ({step+1}/{planning_params['max_iterations']})",
                    )

                res = extract_tree_stats(tree, target_molecule)

                st.session_state["tree"] = tree
                st.session_state["res"] = res
                st.session_state.planning_done = True
                st.rerun()

        except (ValueError, KeyError, FileNotFoundError, TypeError) as e:
            st.error(f"An error occurred during planning: {e}")
            st.session_state.planning_done = False
462
+
463
+
464
def display_planning_results():
    """5. Planning Results Display: render routes found by the planner.

    Reads ``planning_done``, ``res`` and ``tree`` from ``st.session_state``.
    When the target was solved, up to 3 example routes are drawn; otherwise a
    sample of unfinished pathways is shown for diagnostics.
    """
    if not st.session_state.get("planning_done", False):
        return

    res = st.session_state.res
    tree = st.session_state.tree

    if res is None or tree is None:
        st.error(
            "Planning results are missing from session state. Please re-run planning."
        )
        st.session_state.planning_done = False  # Reset state
        return  # Exit this function if no results

    st.header("Planning Results")
    if res.get("solved", False):  # Use .get for safety
        # Deduplicate and order winning nodes; guard against a missing/empty
        # attribute on the tree object.
        winning_nodes = (
            sorted(set(tree.winning_nodes))
            if hasattr(tree, "winning_nodes") and tree.winning_nodes
            else []
        )
        st.subheader(f"Number of unique routes found: {len(winning_nodes)}")

        st.subheader("Examples of found retrosynthetic routes")
        if not winning_nodes:
            st.warning(
                "Planning solved, but no winning nodes found in the tree object."
            )
            return

        image_counter = 0
        for route_id in winning_nodes:
            if image_counter >= 3:  # show at most 3 example routes
                break
            try:
                num_steps = len(tree.synthesis_route(route_id))
                route_score = round(tree.route_score(route_id), 3)
                svg = get_route_svg(tree, route_id)
                if svg:
                    st.image(
                        svg,
                        caption=f"Route {route_id}; {num_steps} steps; Route score: {route_score}",
                    )
                    image_counter += 1
                else:
                    st.warning(
                        f"Could not generate SVG for route {route_id}."
                    )
            except Exception as e:
                st.error(f"Error displaying route {route_id}: {e}")
    else:  # Not solved
        st.warning(
            "No reaction path found for the target molecule with the current settings."
        )
        st.write(
            "Find below the unfinished pathways"
        )
        image_counter = 0
        # Sample every 50th node (skipping the root) to illustrate partial
        # routes; cap the output at 20 rendered images.
        for route_id in list(tree.nodes.keys())[1:tree.config.max_iterations:50]:
            svg = get_route_svg_mod(tree, route_id)
            if svg:
                st.image(
                    svg,
                    caption=f"Route {route_id};",
                )
                image_counter += 1
                reactions = tree.synthesis_route(route_id)
                for reaction in reactions:
                    st.write(reaction)
            else:
                st.warning(
                    f"Could not generate SVG for route {route_id}."
                )
            if image_counter >= 20:
                break
577
+
578
+
579
def download_planning_results():
    """6. Planning Results Download: offer the HTML report and CSV stats.

    Renders nothing unless a successful (solved) planning run is stored in
    session state. Ideally this would live inside the results column of
    display_planning_results; for now the links are emitted in-place.
    """
    # Guard clause: only offer downloads after a solved planning run.
    if not (
        st.session_state.get("planning_done", False)
        and st.session_state.res
        and st.session_state.res.get("solved", False)
    ):
        return

    res = st.session_state.res
    tree = st.session_state.tree

    try:
        # Full HTML report of the search results.
        html_body = generate_results_html(tree, html_path=None, extended=True)
        html_link = download_button(
            html_body,
            f"results_synplanner_{st.session_state.target_smiles}.html",
            "Download results (HTML)",
        )
        if html_link:
            st.markdown(html_link, unsafe_allow_html=True)

        try:
            # One-row statistics table for this run.
            stats_df = pd.DataFrame(res, index=[0])
            csv_link = download_button(
                stats_df,
                f"stats_synplanner_{st.session_state.target_smiles}.csv",
                "Download statistics (CSV)",
            )
            if csv_link:
                st.markdown(csv_link, unsafe_allow_html=True)
        except Exception as e:
            st.error(f"Could not prepare statistics CSV for download: {e}")

    except Exception as e:
        st.error(f"Error generating download links for planning results: {e}")
621
+
622
+
623
def setup_clustering():
    """7. Clustering: Encapsulating the logic related to the "clustering" functionality.

    Shown only after a successful planning run; on button press it computes
    RouteCGRs / SB-CGRs, clusters the found routes, and stores everything in
    ``st.session_state`` before rerunning the app.
    """
    if (
        st.session_state.get("planning_done", False)
        and st.session_state.res
        and st.session_state.res.get("solved", False)
    ):
        st.divider()
        st.header("Clustering the retrosynthetic routes")

        if st.button("Run Clustering", key="submit_clustering_button"):
            # st.session_state.num_clusters_setting = num_clusters_input
            # Reset all clustering/subclustering state before recomputing.
            st.session_state.clustering_done = False
            st.session_state.subclustering_done = False
            st.session_state.clusters = None
            st.session_state.reactions_dict = None
            st.session_state.subclusters = None
            st.session_state.route_cgrs_dict = None
            st.session_state.sb_cgrs_dict = None
            st.session_state.route_json = None

            with st.spinner("Performing clustering..."):
                try:
                    current_tree = st.session_state.tree
                    if not current_tree:
                        st.error("Tree object not found. Please re-run planning.")
                        return

                    st.write("Calculating RoutesCGRs...")
                    route_cgrs_dict = compose_all_route_cgrs(current_tree)
                    st.write("Processing SB-CGRs...")
                    sb_cgrs_dict = compose_all_sb_cgrs(route_cgrs_dict)

                    results = cluster_routes(
                        sb_cgrs_dict, use_strat=False
                    )  # num_clusters was removed from args
                    # Keys appear to be numeric strings; sort clusters numerically.
                    results = dict(sorted(results.items(), key=lambda x: float(x[0])))

                    st.session_state.clusters = results
                    st.session_state.route_cgrs_dict = route_cgrs_dict
                    st.session_state.sb_cgrs_dict = sb_cgrs_dict
                    st.write("Extracting reactions...")
                    st.session_state.reactions_dict = extract_reactions(current_tree)
                    st.session_state.route_json = make_json(st.session_state.reactions_dict)

                    if (
                        st.session_state.clusters is not None
                        and st.session_state.reactions_dict is not None
                    ):  # Check for None explicitly
                        st.session_state.clustering_done = True
                        st.success(
                            f"Clustering complete. Found {len(st.session_state.clusters)} clusters."
                        )
                    else:
                        st.error("Clustering failed or returned empty results.")
                        st.session_state.clustering_done = False

                    del results  # route_cgrs_dict, sb_cgrs_dict are stored
                    gc.collect()
                    # NOTE: st.rerun() raises internally to restart the script;
                    # Streamlit's control-flow exception is not swallowed by the
                    # except clause below (it only catches on real failures).
                    st.rerun()
                except Exception as e:
                    st.error(f"An error occurred during clustering: {e}")
                    st.session_state.clustering_done = False
686
+
687
+
688
def _render_cluster_route(tree, cluster_num, group_data):
    """Render one cluster's best route: SB-CGR thumbnail plus route SVG.

    Shared by the direct list and the "... more clusters" expander in
    display_clustering_results (previously two verbatim copies).
    """
    if (
        not group_data
        or "route_ids" not in group_data
        or not group_data["route_ids"]
    ):
        st.warning(f"Cluster {cluster_num} has no data or route_ids.")
        return
    st.markdown(
        f"**Cluster {cluster_num}** (Size: {group_data.get('group_size', 'N/A')})"
    )
    # The first route id is the cluster representative ("best" route).
    route_id = group_data["route_ids"][0]
    try:
        num_steps = len(tree.synthesis_route(route_id))
        route_score = round(tree.route_score(route_id), 3)
        svg = get_route_svg_from_json(st.session_state.route_json, route_id)
        sb_cgr = group_data.get("sb_cgr")  # Safely get sb_cgr
        sb_cgr_svg = None
        if sb_cgr:
            sb_cgr.clean2d()
            sb_cgr_svg = cgr_display(sb_cgr)

        if svg and sb_cgr_svg:
            col1, col2 = st.columns([0.2, 0.8])
            with col1:
                st.image(sb_cgr_svg, caption="SB-CGR")
            with col2:
                st.image(
                    svg,
                    caption=f"Route {route_id}; {num_steps} steps; Route score: {route_score}",
                )
        elif svg:  # Only route SVG available
            st.image(
                svg,
                caption=f"Route {route_id}; {num_steps} steps; Route score: {route_score}",
            )
            st.warning(f"SB-CGR could not be displayed for cluster {cluster_num}.")
        else:
            st.warning(f"Could not generate SVG for route {route_id} or its SB-CGR.")
    except Exception as e:
        st.error(f"Error displaying route {route_id} for cluster {cluster_num}: {e}")


def display_clustering_results():
    """8. Clustering Results Display: Handling the presentation of results.

    Shows the representative route of each cluster; the first
    MAX_DISPLAY_CLUSTERS_DATA clusters inline, the rest inside an expander.
    """
    if not st.session_state.get("clustering_done", False):
        return

    clusters = st.session_state.clusters
    tree = st.session_state.tree
    MAX_DISPLAY_CLUSTERS_DATA = 10

    if clusters is None or tree is None:
        st.error(
            "Clustering results (clusters or tree) are missing. Please re-run clustering."
        )
        st.session_state.clustering_done = False
        return

    st.subheader(f"Best routes from {len(clusters)} Found Clusters")
    clusters_items = list(clusters.items())
    first_items = clusters_items[:MAX_DISPLAY_CLUSTERS_DATA]
    remaining_items = clusters_items[MAX_DISPLAY_CLUSTERS_DATA:]

    for cluster_num, group_data in first_items:
        _render_cluster_route(tree, cluster_num, group_data)

    if remaining_items:
        with st.expander(f"... and {len(remaining_items)} more clusters"):
            for cluster_num, group_data in remaining_items:
                _render_cluster_route(tree, cluster_num, group_data)
813
def _cluster_report_download(tree, clusters, sb_cgrs, cluster_key, widget_key):
    """Generate one cluster's HTML report and render its download button.

    Shared by the direct list and the expander in download_clustering_results
    (previously two verbatim copies).
    """
    try:
        html_content = routes_clustering_report(
            tree,
            clusters,  # Pass the whole dict
            str(cluster_key),  # Pass the key of the cluster
            sb_cgrs,  # Pass the sb_cgrs dict
            aam=False,
        )
        st.download_button(
            label=f"Download report for cluster {cluster_key}",
            data=html_content,
            file_name=f"cluster_{cluster_key}_{st.session_state.target_smiles}.html",
            mime="text/html",
            key=widget_key,
        )
    except Exception as e:
        st.error(f"Error generating report for cluster {cluster_key}: {e}")


def download_clustering_results():
    """10. Clustering Results Download: Providing functionality to download.

    Renders one download button per cluster (first 10 inline, remainder in an
    expander) plus a ZIP archive of all reports.
    """
    if not st.session_state.get("clustering_done", False):
        return

    tree_for_html = st.session_state.get("tree")
    clusters_for_html = st.session_state.get("clusters")
    # sb_cgrs_dict was used instead of reactions_dict in the original report call.
    sb_cgrs_for_html = st.session_state.get("sb_cgrs_dict")

    if not tree_for_html:
        st.warning("MCTS Tree data not found. Cannot generate cluster reports.")
        return
    if not clusters_for_html:
        st.warning("Cluster data not found. Cannot generate cluster reports.")
        return
    # sb_cgrs_for_html is optional for routes_clustering_report if not essential

    st.subheader("Cluster Reports")
    st.write("Generate downloadable HTML reports for each cluster:")

    MAX_DOWNLOAD_LINKS_DISPLAYED = 10
    clusters_items = list(clusters_for_html.items())

    for cluster_idx, _ in clusters_items[:MAX_DOWNLOAD_LINKS_DISPLAYED]:
        _cluster_report_download(
            tree_for_html,
            clusters_for_html,
            sb_cgrs_for_html,
            cluster_idx,
            f"download_cluster_{cluster_idx}",
        )

    remaining_items = clusters_items[MAX_DOWNLOAD_LINKS_DISPLAYED:]
    if remaining_items:
        with st.expander(f"Show remaining {len(remaining_items)} cluster reports"):
            for group_index, _ in remaining_items:
                _cluster_report_download(
                    tree_for_html,
                    clusters_for_html,
                    sb_cgrs_for_html,
                    group_index,
                    f"download_cluster_expanded_{group_index}",
                )

    # Single ZIP with every cluster report, built in memory.
    try:
        buffer = io.BytesIO()
        with zipfile.ZipFile(
            buffer, mode="w", compression=zipfile.ZIP_DEFLATED
        ) as zf:
            for idx, _ in clusters_items:
                html_content_zip = routes_clustering_report(
                    tree_for_html,
                    clusters_for_html,
                    str(idx),
                    sb_cgrs_for_html,
                    aam=False,
                )
                filename = f"cluster_{idx}_{st.session_state.target_smiles}.html"
                zf.writestr(filename, html_content_zip)
        buffer.seek(0)

        st.download_button(
            label="📦 Download all cluster reports as ZIP",
            data=buffer,
            file_name=f"all_cluster_reports_{st.session_state.target_smiles}.zip",
            mime="application/zip",
            key="download_all_clusters_zip",
        )
    except Exception as e:
        st.error(f"Error generating ZIP file for cluster reports: {e}")
917
def setup_subclustering():
    """11. Subclustering: Encapsulating the logic related to the "subclustering" functionality.

    Offers a "Run Subclustering Analysis" button once clustering has finished
    and stores the resulting subgroups in ``st.session_state.subclusters``.
    """
    if st.session_state.get(
        "clustering_done", False
    ):  # Subclustering depends on clustering being done
        st.divider()
        st.header("Sub-Clustering within a selected Cluster")

        if st.button("Run Subclustering Analysis", key="submit_subclustering_button"):
            st.session_state.subclustering_done = False
            st.session_state.subclusters = None
            rerun_needed = False
            with st.spinner("Performing subclustering analysis..."):
                try:
                    clusters_for_sub = st.session_state.get("clusters")
                    sb_cgrs_dict_for_sub = st.session_state.get("sb_cgrs_dict")
                    route_cgrs_dict_for_sub = st.session_state.get("route_cgrs_dict")

                    if (
                        clusters_for_sub
                        and sb_cgrs_dict_for_sub
                        and route_cgrs_dict_for_sub
                    ):  # Ensure all are present
                        all_subgroups = subcluster_all_clusters(
                            clusters_for_sub,
                            sb_cgrs_dict_for_sub,
                            route_cgrs_dict_for_sub,
                        )
                        st.session_state.subclusters = all_subgroups
                        st.session_state.subclustering_done = True
                        st.success("Subclustering analysis complete.")
                        gc.collect()
                        # BUG FIX: st.rerun() raises a control-flow exception
                        # that the broad `except Exception` below used to
                        # swallow, turning every successful run into an error
                        # message.  Defer the rerun until outside the handler.
                        rerun_needed = True
                    else:
                        missing = []
                        if not clusters_for_sub:
                            missing.append("clusters")
                        if not sb_cgrs_dict_for_sub:
                            missing.append("SB-CGRs dictionary")
                        if not route_cgrs_dict_for_sub:
                            missing.append("RouteCGRs dictionary")
                        st.error(
                            f"Cannot run subclustering. Missing data: {', '.join(missing)}. Please ensure clustering ran successfully."
                        )
                        st.session_state.subclustering_done = False

                except Exception as e:
                    st.error(f"An error occurred during subclustering: {e}")
                    st.session_state.subclustering_done = False
            if rerun_needed:
                st.rerun()
967
def _render_subcluster_route(tree, route_id):
    """Render a single route inside the subcluster details panel.

    Shared by the direct list and the "... more routes" expander in
    display_subclustering_results (previously two verbatim copies).
    """
    try:
        route_score_sub = round(tree.route_score(route_id), 3)
        svg_sub = get_route_svg_from_json(st.session_state.route_json, route_id)
        if svg_sub:
            st.image(
                svg_sub,
                caption=f"Route {route_id}; Score: {route_score_sub}",
            )
        else:
            st.warning(f"Could not generate SVG for route {route_id}.")
    except Exception as e:
        st.error(f"Error displaying route {route_id} in subcluster: {e}")


def display_subclustering_results():
    """12. Subclustering Results Display: Handling the presentation of results.

    Left column: cluster / subcluster selectors plus the parent SB-CGR.
    Right column: the subcluster's Markush-like pseudo reaction and its routes
    (first 5 inline, the rest in an expander).
    """
    if not st.session_state.get("subclustering_done", False):
        return

    sub = st.session_state.get("subclusters")
    tree = st.session_state.get("tree")

    if not sub or not tree:
        st.error(
            "Subclustering results (subclusters or tree) are missing. Please re-run subclustering."
        )
        st.session_state.subclustering_done = False
        return

    sub_input_col, sub_display_col = st.columns([0.25, 0.75])

    with sub_input_col:
        st.subheader("Select Cluster and Subcluster")
        available_cluster_nums = list(sub.keys())
        if not available_cluster_nums:
            st.warning("No clusters available in subclustering results.")
            return  # Exit if no clusters to select

        user_input_cluster_num_display = st.selectbox(
            "Select Cluster #:",
            options=sorted(available_cluster_nums),
            key="subcluster_num_select_key",
        )

        # Default index used when the selected cluster has no subclusters.
        selected_subcluster_idx = 0

        if user_input_cluster_num_display in sub:
            sub_step_cluster = sub[user_input_cluster_num_display]
            allowed_subclusters_indices = sorted(list(sub_step_cluster.keys()))

            if not allowed_subclusters_indices:
                st.warning(
                    f"No reaction steps (subclusters) found for Cluster {user_input_cluster_num_display}."
                )
            else:
                selected_subcluster_idx = st.selectbox(
                    "Select Subcluster Index:",
                    options=allowed_subclusters_indices,
                    key="subcluster_index_select_key",
                )
                if selected_subcluster_idx in sub[user_input_cluster_num_display]:
                    current_subcluster_data = sub[user_input_cluster_num_display][
                        selected_subcluster_idx
                    ]
                    if "sb_cgr" in current_subcluster_data:
                        cluster_sb_cgr_display = current_subcluster_data["sb_cgr"]
                        cluster_sb_cgr_display.clean2d()
                        st.image(
                            cluster_sb_cgr_display.depict(),
                            caption=f"SB-CGR of parent Cluster {user_input_cluster_num_display}",
                        )
                    else:
                        st.warning("SB-CGR for this subcluster not found.")
        else:
            st.warning(
                f"Selected cluster {user_input_cluster_num_display} not found in subclustering results."
            )
            return

    with sub_display_col:
        st.subheader("Subcluster Details")
        if (
            user_input_cluster_num_display in sub
            and selected_subcluster_idx in sub[user_input_cluster_num_display]
        ):
            subcluster_content = sub[user_input_cluster_num_display][
                selected_subcluster_idx
            ]

            # subcluster_to_display = post_process_subgroup(subcluster_content)  # Under development
            subcluster_to_display = subcluster_content
            if (
                not subcluster_to_display
                or "routes_data" not in subcluster_to_display
                or not subcluster_to_display["routes_data"]
            ):
                st.info("No routes or data found for this subcluster selection.")
            else:
                MAX_ROUTES_PER_SUBCLUSTER = 5
                all_route_ids_in_subcluster = list(
                    subcluster_to_display["routes_data"].keys()
                )
                routes_to_display_direct = all_route_ids_in_subcluster[
                    :MAX_ROUTES_PER_SUBCLUSTER
                ]
                remaining_routes_sub = all_route_ids_in_subcluster[
                    MAX_ROUTES_PER_SUBCLUSTER:
                ]

                st.markdown(
                    f"--- \n**Subcluster {user_input_cluster_num_display}.{selected_subcluster_idx}** (Size: {len(all_route_ids_in_subcluster)})"
                )

                if "synthon_reaction" in subcluster_to_display:
                    synthon_reaction = subcluster_to_display["synthon_reaction"]
                    try:
                        synthon_reaction.clean2d()
                        st.image(
                            depict_custom_reaction(synthon_reaction),
                            caption=f"Markush-like pseudo reaction of subcluster",
                        )
                    except Exception as e_depict:
                        st.warning(f"Could not depict synthon reaction: {e_depict}")
                else:
                    st.info("No synthon reaction data for this subcluster.")

                with st.container(height=500):
                    for route_id in routes_to_display_direct:
                        _render_subcluster_route(tree, route_id)

                    if remaining_routes_sub:
                        with st.expander(
                            f"... and {len(remaining_routes_sub)} more routes in this subcluster"
                        ):
                            for route_id in remaining_routes_sub:
                                _render_subcluster_route(tree, route_id)
        else:
            st.info("Select a valid cluster and subcluster index to see details.")
1126
def download_subclustering_results():
    """13. Subclustering Results Download: Providing functionality to download.

    Builds an HTML report for the cluster/subcluster currently selected in
    display_subclustering_results (read back via the selectbox widget keys)
    and offers it as a download button.
    """
    if (
        st.session_state.get("subclustering_done", False)
        and "subcluster_num_select_key" in st.session_state
        and "subcluster_index_select_key" in st.session_state
    ):

        sub = st.session_state.get("subclusters")
        tree = st.session_state.get("tree")
        sb_cgrs_for_report = st.session_state.get(
            "sb_cgrs_dict"
        )  # Used by routes_subclustering_report

        # Current selections, written by the selectboxes in the display step.
        user_input_cluster_num_display = st.session_state.subcluster_num_select_key
        selected_subcluster_idx = st.session_state.subcluster_index_select_key

        if not tree or not sub or not sb_cgrs_for_report:
            st.warning(
                "Missing data for subclustering report generation (tree, subclusters, or SB-CGRs)."
            )
            return

        if (
            user_input_cluster_num_display in sub
            and selected_subcluster_idx in sub[user_input_cluster_num_display]
        ):

            subcluster_data_for_report = sub[user_input_cluster_num_display][
                selected_subcluster_idx
            ]
            # Apply the same post-processing as in display
            # NOTE(review): display_subclustering_results currently shows the
            # *raw* subcluster (its post_process_subgroup call is commented
            # out), so the downloaded report may not match what is on screen —
            # confirm this asymmetry is intended.
            processed_subcluster_data = post_process_subgroup(
                subcluster_data_for_report
            )
            # Attach leaving-group groupings only when routes_data has the
            # expected dict shape; otherwise fall back to an empty mapping.
            if "routes_data" in subcluster_data_for_report and isinstance(
                subcluster_data_for_report["routes_data"], dict
            ):
                processed_subcluster_data["group_lgs"] = group_by_identical_values(
                    subcluster_data_for_report["routes_data"]
                )
            else:
                processed_subcluster_data["group_lgs"] = {}

            try:
                subcluster_html_content = routes_subclustering_report(
                    tree,
                    processed_subcluster_data,  # Pass the specific post-processed subcluster data
                    user_input_cluster_num_display,
                    selected_subcluster_idx,
                    sb_cgrs_for_report,  # Pass the whole sb_cgrs dict
                    if_lg_group=True,  # This parameter was in the original call
                )
                st.download_button(
                    label=f"Download report for subcluster {user_input_cluster_num_display}.{selected_subcluster_idx}",
                    data=subcluster_html_content,
                    file_name=f"subcluster_{user_input_cluster_num_display}.{selected_subcluster_idx}_{st.session_state.target_smiles}.html",
                    mime="text/html",
                    key=f"download_subcluster_{user_input_cluster_num_display}_{selected_subcluster_idx}",
                )
            except Exception as e:
                st.error(
                    f"Error generating download report for subcluster {user_input_cluster_num_display}.{selected_subcluster_idx}: {e}"
                )
        # else:
        # This case is handled by the display logic mostly, download button just won't appear or will be for previous valid selection.
1194
def implement_restart():
    """14. Restart: Implementing the logic to reset or restart the application state."""
    st.divider()
    st.header("Restart Application State")
    if not st.button("Clear All Results & Restart", key="restart_button"):
        return

    # Every piece of derived state produced by planning / clustering /
    # subclustering, plus the widget keys remembering user selections.
    stale_keys = (
        "planning_done",
        "tree",
        "res",
        "target_smiles",
        "clustering_done",
        "clusters",
        "reactions_dict",
        "num_clusters_setting",
        "route_cgrs_dict",
        "sb_cgrs_dict",
        "route_json",
        "subclustering_done",
        "subclusters",  # "sub" was renamed
        "clusters_downloaded",
        # Potentially ketcher related keys if they need manual reset beyond new input
        "ketcher_widget",
        "smiles_text_input_key",  # Keys for widgets
        "subcluster_num_select_key",
        "subcluster_index_select_key",
    )
    for stale_key in stale_keys:
        # pop() with a default is a no-op for keys that are not present.
        st.session_state.pop(stale_key, None)

    # Reset ketcher input to default by resetting its session state variable,
    # and explicitly blank target_smiles to avoid stale data.
    st.session_state.ketcher = DEFAULT_MOL
    st.session_state.target_smiles = ""

    # It's generally better to let Streamlit manage widget state if possible,
    # but for a full reset, clearing their explicit session state keys is needed.
    st.rerun()
1233
+
1234
+ # --- Main Application Flow ---
1235
def main():
    """Top-level Streamlit page flow.

    Order matters: molecule input -> planning -> planning stats/downloads ->
    clustering -> cluster stats/downloads -> subclustering -> restart.  Each
    stage gates the next through ``st.session_state`` flags.
    """
    initialize_app()
    setup_sidebar()
    current_smile_code = handle_molecule_input()
    # Update session_state.ketcher if current_smile_code has changed from ketcher output
    if st.session_state.get("ketcher") != current_smile_code:
        st.session_state.ketcher = current_smile_code
        # No rerun here, let the flow continue. handle_molecule_input already warns.

    setup_planning_options()  # This function now also handles the button press and logic for planning

    # Display planning results and download options together
    if st.session_state.get("planning_done", False):
        display_planning_results()  # Displays stats and routes
        if st.session_state.res and st.session_state.res.get("solved", False):
            stat_col, download_col = st.columns(
                2, gap="medium"
            )  # Placeholder for download column
            with stat_col:
                st.subheader("Statistics")
                try:
                    res = st.session_state.res
                    # Backfill the target SMILES into the result dict so the
                    # statistics table can show it.
                    if (
                        "target_smiles" not in res
                        and "target_smiles" in st.session_state
                    ):
                        res["target_smiles"] = st.session_state.target_smiles
                    cols_to_show = [
                        col
                        for col in [
                            "target_smiles",
                            "num_routes",
                            "num_nodes",
                            "num_iter",
                            "search_time",
                        ]
                        if col in res
                    ]
                    if cols_to_show:  # Ensure there are columns to show
                        df = pd.DataFrame(res, index=[0])[cols_to_show]
                        st.dataframe(df)
                    else:
                        st.write("No statistics to display from planning results.")
                except Exception as e:
                    st.error(f"Error displaying statistics: {e}")
                    st.write(res)  # Show raw dict if DataFrame fails
            with download_col:
                st.subheader("Planning Downloads")  # Adding a subheader for clarity
                download_planning_results()

    # Clustering section (setup button, display, download)
    if (
        st.session_state.get("planning_done", False)
        and st.session_state.res
        and st.session_state.res.get("solved", False)
    ):
        setup_clustering()  # Contains the "Run Clustering" button and logic
        if st.session_state.get("clustering_done", False):
            display_clustering_results()  # Displays cluster routes and stats
            cluster_stat_col, cluster_download_col = st.columns(2, gap="medium")

            with cluster_stat_col:
                clusters = st.session_state.clusters
                cluster_sizes = [
                    cluster.get("group_size", 0)
                    for cluster in clusters.values()
                    if cluster
                ]  # Safe get
                st.subheader("Cluster Statistics")
                if cluster_sizes:
                    cluster_df = pd.DataFrame(
                        {
                            "Cluster": [
                                k for k, v in clusters.items() if v
                            ],  # Filter out empty clusters
                            "Number of Routes": [
                                v["group_size"] for v in clusters.values() if v
                            ],
                        }
                    )
                    if not cluster_df.empty:
                        # 1-based index reads better in the UI table.
                        cluster_df.index += 1
                        st.dataframe(cluster_df)
                        best_route_html = html_top_routes_cluster(
                            clusters,
                            st.session_state.tree,
                            st.session_state.target_smiles,
                        )
                        st.download_button(
                            label=f"Download best route from each cluster",
                            data=best_route_html,
                            file_name=f"cluster_best_{st.session_state.target_smiles}.html",
                            mime="text/html",
                            key=f"download_cluster_best",
                        )
                    else:
                        st.write("No valid cluster data to display statistics for.")
                    # download_top_routes_cluster()
                else:
                    st.write("No cluster data to display statistics for.")
            with cluster_download_col:
                download_clustering_results()

    # Subclustering section (setup button, display, download)
    if st.session_state.get("clustering_done", False):  # Depends on clustering
        setup_subclustering()  # Contains "Run Subclustering" button
        if st.session_state.get("subclustering_done", False):
            display_subclustering_results()  # Displays subcluster details and routes
            download_subclustering_results()  # This needs to be called after selections are made in display.

    implement_restart()
1347
+
1348
+ if __name__ == "__main__":
1349
+ main()
pre-requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ --find-links https://download.pytorch.org/whl/torch_stable.html
2
+ torch==2.2.2+cpu
3
+ scikit-learn==1.5.1
4
+ scipy==1.14.0
5
+ fastcluster==1.2.6
6
+ matplotlib==3.10.1
7
+ seaborn==0.13.2
requirements.txt CHANGED
@@ -1,3 +1,6 @@
1
- altair
2
- pandas
3
- streamlit
 
 
 
 
1
+ streamlit
2
+ streamlit_ketcher
3
+ git+https://github.com/Laboratoire-de-Chemoinformatique/SynPlanner.git
4
+
5
+ git+https://github.com/cimm-kzn/StructureFingerprint.git
6
+ scikit-learn
src/streamlit_app.py DELETED
@@ -1,40 +0,0 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
- import streamlit as st
5
-
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
synplan/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .mcts import *
2
+
3
+ __all__ = ["Tree"]
synplan/chem/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from CGRtools.files import SMILESRead
2
+
3
+ smiles_parser = SMILESRead.create_parser(ignore=True)
synplan/chem/data/__init__.py ADDED
File without changes
synplan/chem/data/filtering.py ADDED
@@ -0,0 +1,962 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Module containing classes abd functions for reactions filtering."""
2
+
3
+ import logging
4
+ from dataclasses import dataclass
5
+ from io import TextIOWrapper
6
+ from typing import Any, Dict, Iterable, List, Optional, Tuple
7
+
8
+ import numpy as np
9
+ import ray
10
+ import yaml
11
+ from CGRtools.containers import CGRContainer, MoleculeContainer, ReactionContainer
12
+ from chython.algorithms.fingerprints.morgan import MorganFingerprint
13
+ from tqdm import tqdm
14
+
15
+ from synplan.chem.data.standardizing import (
16
+ AromaticFormStandardizer,
17
+ KekuleFormStandardizer,
18
+ RemoveReagentsStandardizer,
19
+ )
20
+ from synplan.chem.utils import cgrtools_to_chython_molecule
21
+ from synplan.utils.config import ConfigABC, convert_config_to_dict
22
+ from synplan.utils.files import ReactionReader, ReactionWriter
23
+
24
+
25
@dataclass
class CompeteProductsConfig(ConfigABC):
    """Configuration for the compete-products reaction filter.

    Holds the two similarity thresholds used by CompeteProductsFilter.
    """

    # Minimal Morgan-fingerprint Tanimoto similarity before the MCS check runs.
    fingerprint_tanimoto_threshold: float = 0.3
    # Minimal MCS-based Tanimoto similarity to flag a competing product.
    mcs_tanimoto_threshold: float = 0.6

    @staticmethod
    def from_dict(config_dict: Dict[str, Any]) -> "CompeteProductsConfig":
        """Create an instance of CompeteProductsConfig from a dictionary."""
        return CompeteProductsConfig(**config_dict)

    @staticmethod
    def from_yaml(file_path: str) -> "CompeteProductsConfig":
        """Deserialize a YAML file into a CompeteProductsConfig object."""
        with open(file_path, "r", encoding="utf-8") as file:
            config_dict = yaml.safe_load(file)
        return CompeteProductsConfig.from_dict(config_dict)

    def _validate_params(self, params: Dict[str, Any]) -> None:
        """Validate configuration parameters.

        :param params: Mapping with the two threshold keys.
        :raises ValueError: If a threshold is missing, not a float, or outside [0, 1].

        NOTE(review): ``isinstance(..., float)`` rejects integer values such as
        0 or 1 (and YAML parses "1" as int) — confirm whether ints should be
        accepted here.
        """
        if not isinstance(params.get("fingerprint_tanimoto_threshold"), float) or not (
            0 <= params["fingerprint_tanimoto_threshold"] <= 1
        ):
            raise ValueError(
                "Invalid 'fingerprint_tanimoto_threshold'; expected a float between 0 and 1"
            )

        if not isinstance(params.get("mcs_tanimoto_threshold"), float) or not (
            0 <= params["mcs_tanimoto_threshold"] <= 1
        ):
            raise ValueError(
                "Invalid 'mcs_tanimoto_threshold'; expected a float between 0 and 1"
            )
59
+ class CompeteProductsFilter:
60
+ """Checks if there are compete reactions."""
61
+
62
+ def __init__(
63
+ self,
64
+ fingerprint_tanimoto_threshold: float = 0.3,
65
+ mcs_tanimoto_threshold: float = 0.6,
66
+ ):
67
+ self.fingerprint_tanimoto_threshold = fingerprint_tanimoto_threshold
68
+ self.mcs_tanimoto_threshold = mcs_tanimoto_threshold
69
+
70
+ @staticmethod
71
+ def from_config(config: CompeteProductsConfig) -> "CompeteProductsFilter":
72
+ """Creates an instance of CompeteProductsFilter from a configuration object."""
73
+ return CompeteProductsFilter(
74
+ config.fingerprint_tanimoto_threshold, config.mcs_tanimoto_threshold
75
+ )
76
+
77
+ def __call__(self, reaction: ReactionContainer) -> bool:
78
+ """Checks if the reaction has competing products, else False.
79
+
80
+ :param reaction: Input reaction.
81
+ :return: Returns True if the reaction has competing products, else False.
82
+ """
83
+ mf = MorganFingerprint()
84
+ is_compete = False
85
+
86
+ # check for compete products using both fingerprint similarity and maximum common substructure (MCS) similarity
87
+ for mol in reaction.reagents:
88
+ for other_mol in reaction.products:
89
+ if len(mol) > 6 and len(other_mol) > 6:
90
+ # compute fingerprint similarity
91
+ molf = mf.transform([cgrtools_to_chython_molecule(mol)])
92
+ other_molf = mf.transform([cgrtools_to_chython_molecule(other_mol)])
93
+ fingerprint_tanimoto = tanimoto_kernel(molf, other_molf)[0][0]
94
+
95
+ # if fingerprint similarity is high enough, check for MCS similarity
96
+ if fingerprint_tanimoto > self.fingerprint_tanimoto_threshold:
97
+ try:
98
+ # find the maximum common substructure (MCS) and compute its size
99
+ clique_size = len(
100
+ next(mol.get_mcs_mapping(other_mol, limit=100))
101
+ )
102
+
103
+ # calculate MCS similarity based on MCS size
104
+ mcs_tanimoto = clique_size / (
105
+ len(mol) + len(other_mol) - clique_size
106
+ )
107
+
108
+ # if MCS similarity is also high enough, mark the reaction as having compete products
109
+ if mcs_tanimoto > self.mcs_tanimoto_threshold:
110
+ is_compete = True
111
+ break
112
+ except StopIteration:
113
+ continue
114
+
115
+ return is_compete
116
+
117
+
118
@dataclass
class DynamicBondsConfig(ConfigABC):
    """Configuration for the dynamic-bonds filter.

    :ivar min_bonds_number: Minimum allowed number of dynamic bonds in the CGR.
    :ivar max_bonds_number: Maximum allowed number of dynamic bonds in the CGR.
    """

    min_bonds_number: int = 1
    max_bonds_number: int = 6

    @staticmethod
    def from_dict(config_dict: Dict[str, Any]) -> "DynamicBondsConfig":
        """Create an instance of DynamicBondsConfig from a dictionary."""
        return DynamicBondsConfig(**config_dict)

    @staticmethod
    def from_yaml(file_path: str) -> "DynamicBondsConfig":
        """Deserialize a YAML file into a DynamicBondsConfig object."""
        # explicit encoding: the default is platform-dependent, and the sibling
        # config classes already read YAML as UTF-8
        with open(file_path, "r", encoding="utf-8") as file:
            config_dict = yaml.safe_load(file)
        return DynamicBondsConfig.from_dict(config_dict)

    def _validate_params(self, params: Dict[str, Any]) -> None:
        """Validate configuration parameters.

        :param params: Mapping with 'min_bonds_number' and 'max_bonds_number'.
        :raises ValueError: If either bound is not a non-negative integer, or if
            the minimum exceeds the maximum.
        """
        if (
            not isinstance(params.get("min_bonds_number"), int)
            or params["min_bonds_number"] < 0
        ):
            raise ValueError(
                "Invalid 'min_bonds_number'; expected a non-negative integer"
            )

        if (
            not isinstance(params.get("max_bonds_number"), int)
            or params["max_bonds_number"] < 0
        ):
            raise ValueError(
                "Invalid 'max_bonds_number'; expected a non-negative integer"
            )

        if params["min_bonds_number"] > params["max_bonds_number"]:
            raise ValueError(
                "'min_bonds_number' cannot be greater than 'max_bonds_number'"
            )
class DynamicBondsFilter:
    """Checks if there is an unacceptable number of dynamic bonds in CGR."""

    def __init__(self, min_bonds_number: int = 1, max_bonds_number: int = 6):
        # inclusive bounds on the allowed number of dynamic (changed) bonds
        self.min_bonds_number = min_bonds_number
        self.max_bonds_number = max_bonds_number

    @staticmethod
    def from_config(config: DynamicBondsConfig):
        """Creates an instance of DynamicBondsChecker from a configuration object."""
        return DynamicBondsFilter(config.min_bonds_number, config.max_bonds_number)

    def __call__(self, reaction: ReactionContainer) -> bool:
        """Return True when the CGR's dynamic-bond count falls outside the
        configured [min, max] range (i.e. the reaction should be filtered)."""
        n_dynamic_bonds = len((~reaction).center_bonds)
        if n_dynamic_bonds < self.min_bonds_number:
            return True
        return n_dynamic_bonds > self.max_bonds_number
@dataclass
class SmallMoleculesConfig(ConfigABC):
    """Configuration for the small-molecules filter.

    :ivar mol_max_size: Maximum atom count for a molecule to count as "small".
    """

    mol_max_size: int = 6

    @staticmethod
    def from_dict(config_dict: Dict[str, Any]) -> "SmallMoleculesConfig":
        """Creates an instance of SmallMoleculesConfig from a dictionary."""
        return SmallMoleculesConfig(**config_dict)

    @staticmethod
    def from_yaml(file_path: str) -> "SmallMoleculesConfig":
        """Deserialize a YAML file into a SmallMoleculesConfig object."""
        # explicit encoding: the default is platform-dependent, and the sibling
        # config classes already read YAML as UTF-8
        with open(file_path, "r", encoding="utf-8") as file:
            config_dict = yaml.safe_load(file)
        return SmallMoleculesConfig.from_dict(config_dict)

    def _validate_params(self, params: Dict[str, Any]) -> None:
        """Validate configuration parameters.

        :param params: Mapping with 'mol_max_size'.
        :raises ValueError: If 'mol_max_size' is not a positive integer.
        """
        if (
            not isinstance(params.get("mol_max_size"), int)
            or params["mol_max_size"] < 1
        ):
            raise ValueError("Invalid 'mol_max_size'; expected a positive integer")
class SmallMoleculesFilter:
    """Checks if there are only small molecules in the reaction or if there is only one
    small reactant or product."""

    def __init__(self, mol_max_size: int = 6):
        # molecules with at most this many atoms count as "small"
        self.limit = mol_max_size

    @staticmethod
    def from_config(config: SmallMoleculesConfig) -> "SmallMoleculesFilter":
        """Creates an instance of SmallMoleculesChecker from a configuration object."""
        return SmallMoleculesFilter(config.mol_max_size)

    def __call__(self, reaction: ReactionContainer) -> bool:
        """Return True when the reaction is dominated by small molecules:
        a single small reactant, a single small product, or small molecules
        on both sides."""
        reactants_small = self.are_only_small_molecules(reaction.reactants)
        products_small = self.are_only_small_molecules(reaction.products)
        if len(reaction.reactants) == 1 and reactants_small:
            return True
        if len(reaction.products) == 1 and products_small:
            return True
        return reactants_small and products_small

    def are_only_small_molecules(self, molecules: Iterable[MoleculeContainer]) -> bool:
        """Checks if all molecules in the given iterable are small molecules."""
        for molecule in molecules:
            if len(molecule) > self.limit:
                return False
        return True
@dataclass
class CGRConnectedComponentsConfig:
    """Configuration for the CGR connected-components filter (no tunable parameters)."""
class CGRConnectedComponentsFilter:
    """Checks if CGR contains unrelated components (without reagents)."""

    @staticmethod
    def from_config(
        config: CGRConnectedComponentsConfig,
    ) -> "CGRConnectedComponentsFilter":
        """Creates an instance of CGRConnectedComponentsChecker from a configuration
        object."""
        return CGRConnectedComponentsFilter()

    def __call__(self, reaction: ReactionContainer) -> bool:
        """Return True when the reagent-free CGR splits into more than one
        connected component (reactants and products are not linked)."""
        # rebuild the reaction without reagents so they cannot bridge components
        stripped = ReactionContainer(reaction.reactants, reaction.products)
        return (~stripped).connected_components_count > 1
@dataclass
class RingsChangeConfig:
    """Configuration for the ring-count change filter (no tunable parameters)."""
class RingsChangeFilter:
    """Checks if there is changing rings number in the reaction."""

    @staticmethod
    def from_config(config: RingsChangeConfig) -> "RingsChangeFilter":
        """Creates an instance of RingsChecker from a configuration object."""
        return RingsChangeFilter()

    def __call__(self, reaction: ReactionContainer):
        """
        Return True when the total ring count or the aromatic ring count differs
        between reactants and products (i.e. the reaction creates or destroys
        rings).

        :param reaction: Input reaction.
        :return: True if the ring counts mismatch, else False.
        """
        reactant_counts = self._calc_rings(reaction.reactants)
        product_counts = self._calc_rings(reaction.products)
        # compare (total rings, aromatic rings) pairs in one shot
        return reactant_counts != product_counts

    @staticmethod
    def _calc_rings(molecules: Iterable) -> Tuple[int, int]:
        """
        Count all rings and aromatic rings over the given molecules.

        :param molecules: Set of molecules.
        :return: (total ring count, aromatic ring count).
        """
        total_rings, aromatic_rings = 0, 0
        for molecule in molecules:
            total_rings += molecule.rings_count
            aromatic_rings += len(molecule.aromatic_rings)
        return total_rings, aromatic_rings
@dataclass
class StrangeCarbonsConfig:
    """Configuration for the strange-carbons filter (no tunable parameters yet;
    may be extended in the future)."""
class StrangeCarbonsFilter:
    """Checks if there are 'strange' carbons in the reaction."""

    @staticmethod
    def from_config(config: StrangeCarbonsConfig) -> "StrangeCarbonsFilter":
        """Creates an instance of StrangeCarbonsChecker from a configuration object."""
        return StrangeCarbonsFilter()

    def __call__(self, reaction: ReactionContainer) -> bool:
        """Return True when any molecule is carbon-only and is either a lone
        atom (methane) or built from a single non-aromatic bond type."""
        for molecule in reaction.reactants + reaction.products:
            symbols = {atom.atomic_symbol for _, atom in molecule.atoms()}
            if symbols != {"C"}:
                continue
            if len(molecule) == 1:  # a bare carbon atom (methane)
                return True
            orders = {int(bond) for _, _, bond in molecule.bonds()}
            # exactly one bond order that is not aromatic (order 4) is suspicious
            if len(orders) == 1 and orders.pop() != 4:
                return True
        return False
@dataclass
class NoReactionConfig:
    """Configuration for the no-reaction filter (no tunable parameters yet;
    may be extended in the future)."""
class NoReactionFilter:
    """Checks if there is no reaction in the provided reaction container."""

    @staticmethod
    def from_config(config: NoReactionConfig) -> "NoReactionFilter":
        """Creates an instance of NoReactionChecker from a configuration object."""
        return NoReactionFilter()

    def __call__(self, reaction: ReactionContainer) -> bool:
        """Return True when the CGR has an empty reaction centre — no changed
        atoms and no changed bonds — i.e. nothing actually reacts."""
        cgr = ~reaction
        if cgr.center_atoms:
            return False
        return not cgr.center_bonds
@dataclass
class MultiCenterConfig:
    """Configuration for the multi-center reaction filter (no tunable parameters)."""
class MultiCenterFilter:
    """Checks if there is a multicenter reaction."""

    @staticmethod
    def from_config(config: MultiCenterConfig) -> "MultiCenterFilter":
        """Creates an instance of MultiCenterFilter from a configuration object."""
        return MultiCenterFilter()

    def __call__(self, reaction: ReactionContainer) -> bool:
        """Return True when the reaction's CGR decomposes into more than one
        distinct reaction centre."""
        centers = (~reaction).centers_list
        return len(centers) > 1
@dataclass
class WrongCHBreakingConfig:
    """Configuration for the wrong C-H breaking filter (no tunable parameters)."""
class WrongCHBreakingFilter:
    """Checks for incorrect C-C bond formation from breaking a C-H bond."""

    @staticmethod
    def from_config(config: WrongCHBreakingConfig) -> "WrongCHBreakingFilter":
        """Creates an instance of WrongCHBreakingFilter from a configuration object."""
        return WrongCHBreakingFilter()

    def __call__(self, reaction: ReactionContainer) -> bool:
        """
        Determines if a reaction involves incorrect C-C bond formation from breaking
        a C-H bond.

        :param reaction: The reaction to be filtered.
        :return: True if incorrect C-C bond formation is found, False otherwise.
        """

        # reactions with valence errors are left for the valence-oriented checks
        if reaction.check_valence():
            return False

        # work on a copy so the caller's reaction keeps implicit hydrogens
        copy_reaction = reaction.copy()
        copy_reaction.explicify_hydrogens()
        cgr = ~copy_reaction
        # restrict the analysis to the reaction centre plus its first shell
        reduced_cgr = cgr.augmented_substructure(cgr.center_atoms, deep=1)

        return self.is_wrong_c_h_breaking(reduced_cgr)

    @staticmethod
    def is_wrong_c_h_breaking(cgr: CGRContainer) -> bool:
        """
        Checks for incorrect C-C bond formation from breaking a C-H bond in a CGR.

        :param cgr: The CGR with explicified hydrogens.
        :return: True if incorrect C-C bond formation is found, False otherwise.
        """
        for atom_id in cgr.center_atoms:
            if cgr.atom(atom_id).atomic_symbol == "C":
                is_c_h_breaking, is_c_c_formation = False, False
                c_with_h_id, another_c_id = None, None

                # NOTE(review): relies on CGRContainer's private _bonds mapping
                for neighbour_id, bond in cgr._bonds[atom_id].items():
                    neighbour = cgr.atom(neighbour_id)

                    # bond exists before the reaction but not after -> C-H broken
                    if (
                        bond.order
                        and not bond.p_order
                        and neighbour.atomic_symbol == "H"
                    ):
                        is_c_h_breaking = True
                        c_with_h_id = atom_id

                    # bond absent before the reaction but present after -> C-C formed
                    elif (
                        not bond.order
                        and bond.p_order
                        and neighbour.atomic_symbol == "C"
                    ):
                        is_c_c_formation = True
                        another_c_id = neighbour_id

                if is_c_h_breaking and is_c_c_formation:
                    # check for presence of heteroatoms in the first environment of 2 bonding carbons;
                    # a heteroatom neighbour makes the transformation plausible, so it is not flagged
                    if any(
                        cgr.atom(neighbour_id).atomic_symbol not in ("C", "H")
                        for neighbour_id in cgr._bonds[c_with_h_id]
                    ) or any(
                        cgr.atom(neighbour_id).atomic_symbol not in ("C", "H")
                        for neighbour_id in cgr._bonds[another_c_id]
                    ):
                        return False
                    return True

        return False
@dataclass
class CCsp3BreakingConfig:
    """Configuration for the C(sp3)-C bond-breaking filter (no tunable parameters)."""
class CCsp3BreakingFilter:
    """Checks if there is C(sp3)-C bond breaking."""

    @staticmethod
    def from_config(config: CCsp3BreakingConfig) -> "CCsp3BreakingFilter":
        """Creates an instance of CCsp3BreakingFilter from a configuration object."""
        return CCsp3BreakingFilter()

    def __call__(self, reaction: ReactionContainer) -> bool:
        """
        Returns True if there is C(sp3)-C bonds breaking, else False.

        :param reaction: Input reaction
        :return: Returns True if there is C(sp3)-C bonds breaking, else False.
        """
        cgr = ~reaction
        # reaction centre plus the first shell of neighbouring atoms
        reaction_center = cgr.augmented_substructure(cgr.center_atoms, deep=1)
        for atom_id, neighbour_id, bond in reaction_center.bonds():
            atom = reaction_center.atom(atom_id)
            neighbour = reaction_center.atom(neighbour_id)

            # order set before the reaction, absent after -> the bond is broken
            is_bond_broken = bond.order is not None and bond.p_order is None
            are_atoms_carbons = (
                atom.atomic_symbol == "C" and neighbour.atomic_symbol == "C"
            )
            # hybridization value 1 means sp3 here
            is_atom_sp3 = atom.hybridization == 1 or neighbour.hybridization == 1

            if is_bond_broken and are_atoms_carbons and is_atom_sp3:
                return True
        return False
@dataclass
class CCRingBreakingConfig:
    """Pass to ReactionFilterConfig to enable the C-C ring-breaking filter
    (no tunable parameters)."""
class CCRingBreakingFilter:
    """Checks if a reaction involves ring C-C bond breaking."""

    @staticmethod
    def from_config(config: CCRingBreakingConfig):
        """Creates an instance of CCRingBreakingFilter from a configuration object."""
        return CCRingBreakingFilter()

    def __call__(self, reaction: ReactionContainer) -> bool:
        """
        Returns True if the reaction involves ring C-C bond breaking, else False.

        :param reaction: Input reaction
        :return: Returns True if the reaction involves ring C-C bond breaking, else
            False.
        """
        cgr = ~reaction

        # Extract reactants' center atoms and their rings
        reactants_center_atoms = {}
        reactants_rings = set()
        for reactant in reaction.reactants:
            reactants_rings.update(reactant.sssr)
            for n, atom in reactant.atoms():
                # keep only atoms that belong to the reaction centre
                if n in cgr.center_atoms:
                    reactants_center_atoms[n] = atom

        # identify reaction center based on center atoms (deep=0: centre only)
        reaction_center = cgr.augmented_substructure(atoms=cgr.center_atoms, deep=0)

        # iterate over bonds in the reaction center and filter for ring C-C bond breaking
        for atom_id, neighbour_id, bond in reaction_center.bonds():
            try:
                # Retrieve corresponding atoms from reactants
                atom = reactants_center_atoms[atom_id]
                neighbour = reactants_center_atoms[neighbour_id]
            except KeyError:
                # bond endpoints not found among reactant centre atoms -> skip
                continue
            else:
                # Check if the bond is broken and both atoms are carbons in rings of size 5, 6, or 7
                is_bond_broken = (bond.order is not None) and (bond.p_order is None)
                are_atoms_carbons = (
                    atom.atomic_symbol == "C" and neighbour.atomic_symbol == "C"
                )
                # both atoms must sit in a 5/6/7-membered ring AND share one ring
                are_atoms_in_ring = (
                    set(atom.ring_sizes).intersection({5, 6, 7})
                    and set(neighbour.ring_sizes).intersection({5, 6, 7})
                    and any(
                        atom_id in ring and neighbour_id in ring
                        for ring in reactants_rings
                    )
                )

                # If all conditions are met, indicate ring C-C bond breaking
                if is_bond_broken and are_atoms_carbons and are_atoms_in_ring:
                    return True

        return False
@dataclass
class ReactionFilterConfig(ConfigABC):
    """
    Configuration class for reaction filtering. This class manages configuration
    settings for various reaction filters; a filter is enabled by assigning its
    config object and disabled by leaving it as None.

    :ivar dynamic_bonds_config: Configuration for dynamic bonds checking.
    :ivar small_molecules_config: Configuration for small molecules checking.
    :ivar strange_carbons_config: Configuration for strange carbons checking.
    :ivar compete_products_config: Configuration for competing products checking.
    :ivar cgr_connected_components_config: Configuration for CGR connected components checking.
    :ivar rings_change_config: Configuration for rings change checking.
    :ivar no_reaction_config: Configuration for no reaction checking.
    :ivar multi_center_config: Configuration for multi-center checking.
    :ivar wrong_ch_breaking_config: Configuration for wrong C-H breaking checking.
    :ivar cc_sp3_breaking_config: Configuration for CC sp3 breaking checking.
    :ivar cc_ring_breaking_config: Configuration for CC ring breaking checking.
    """

    # configuration for reaction filters
    dynamic_bonds_config: Optional[DynamicBondsConfig] = None
    small_molecules_config: Optional[SmallMoleculesConfig] = None
    strange_carbons_config: Optional[StrangeCarbonsConfig] = None
    compete_products_config: Optional[CompeteProductsConfig] = None
    cgr_connected_components_config: Optional[CGRConnectedComponentsConfig] = None
    rings_change_config: Optional[RingsChangeConfig] = None
    no_reaction_config: Optional[NoReactionConfig] = None
    multi_center_config: Optional[MultiCenterConfig] = None
    wrong_ch_breaking_config: Optional[WrongCHBreakingConfig] = None
    cc_sp3_breaking_config: Optional[CCsp3BreakingConfig] = None
    cc_ring_breaking_config: Optional[CCRingBreakingConfig] = None

    # field names of parameter-less "flag" configs, in serialization order;
    # shared by to_dict/from_dict to keep them consistent
    _FLAG_CONFIG_CLASSES = {
        "cgr_connected_components_config": CGRConnectedComponentsConfig,
        "rings_change_config": RingsChangeConfig,
        "strange_carbons_config": StrangeCarbonsConfig,
        "no_reaction_config": NoReactionConfig,
        "multi_center_config": MultiCenterConfig,
        "wrong_ch_breaking_config": WrongCHBreakingConfig,
        "cc_sp3_breaking_config": CCsp3BreakingConfig,
        "cc_ring_breaking_config": CCRingBreakingConfig,
    }
    # configs that carry real parameters and are (de)serialized via their fields
    _PARAM_CONFIG_CLASSES = {
        "dynamic_bonds_config": DynamicBondsConfig,
        "small_molecules_config": SmallMoleculesConfig,
        "compete_products_config": CompeteProductsConfig,
    }

    def to_dict(self):
        """Converts the configuration into a dictionary, omitting disabled filters."""
        config_dict = {
            key: convert_config_to_dict(getattr(self, key), config_cls)
            for key, config_cls in self._PARAM_CONFIG_CLASSES.items()
        }
        # parameter-less configs serialize to an empty mapping when enabled
        for key in self._FLAG_CONFIG_CLASSES:
            config_dict[key] = {} if getattr(self, key) is not None else None

        return {k: v for k, v in config_dict.items() if v is not None}

    @staticmethod
    def from_dict(config_dict: Dict[str, Any]) -> "ReactionFilterConfig":
        """Create an instance of ReactionFilterConfig from a dictionary.

        A filter is enabled when its key is present in the dictionary;
        parameterized configs are built from the nested mapping, flag configs
        are built with defaults.
        """
        kwargs: Dict[str, Any] = {}
        for key, config_cls in ReactionFilterConfig._PARAM_CONFIG_CLASSES.items():
            kwargs[key] = (
                config_cls(**config_dict[key]) if key in config_dict else None
            )
        for key, config_cls in ReactionFilterConfig._FLAG_CONFIG_CLASSES.items():
            kwargs[key] = config_cls() if key in config_dict else None
        return ReactionFilterConfig(**kwargs)

    @staticmethod
    def from_yaml(file_path: str) -> "ReactionFilterConfig":
        """Deserializes a YAML file into a ReactionFilterConfig object."""
        with open(file_path, "r", encoding="utf-8") as file:
            config_dict = yaml.safe_load(file)
        return ReactionFilterConfig.from_dict(config_dict)

    def _validate_params(self, params: Dict[str, Any]):
        """Each per-filter config validates itself; nothing to check here."""

    def create_filters(self):
        """Instantiate a filter object for every enabled config.

        :return: The list of filter instances, in the application order expected
            by ``filter_reaction``.
        """
        ordered_filters = (
            (self.dynamic_bonds_config, DynamicBondsFilter),
            (self.small_molecules_config, SmallMoleculesFilter),
            (self.strange_carbons_config, StrangeCarbonsFilter),
            (self.compete_products_config, CompeteProductsFilter),
            (self.cgr_connected_components_config, CGRConnectedComponentsFilter),
            (self.rings_change_config, RingsChangeFilter),
            (self.no_reaction_config, NoReactionFilter),
            (self.multi_center_config, MultiCenterFilter),
            (self.wrong_ch_breaking_config, WrongCHBreakingFilter),
            (self.cc_sp3_breaking_config, CCsp3BreakingFilter),
            (self.cc_ring_breaking_config, CCRingBreakingFilter),
        )
        return [
            filter_cls.from_config(config)
            for config, filter_cls in ordered_filters
            if config is not None
        ]
def tanimoto_kernel(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Calculate the Tanimoto coefficient between each row of x and each row of y.

    :param x: 2-D fingerprint array of shape (n, d).
    :param y: 2-D fingerprint array of shape (m, d).
    :return: Array of shape (n, m) where entry [i][j] is the Tanimoto similarity
        of x[i] and y[j]; pairs with a zero denominator yield 0.
    """
    x = x.astype(np.float64)
    y = y.astype(np.float64)
    dot = np.dot(x, y.T)
    x_sq = np.sum(x**2, axis=1)
    y_sq = np.sum(y**2, axis=1)

    # |A ∪ B| = |A| + |B| - |A ∩ B|; broadcasting builds the (n, m) denominator
    # without the Python-level list replication of np.array([x_sq] * m)
    denominator = x_sq[:, None] + y_sq[None, :] - dot
    return np.divide(dot, denominator, out=np.zeros_like(dot), where=denominator != 0)
def filter_reaction(
    reaction: ReactionContainer, config: ReactionFilterConfig, filters: list
) -> Tuple[bool, ReactionContainer]:
    """Checks the input reaction. Returns True if reaction is detected as erroneous and
    returns reaction itself, which sometimes is modified and does not necessarily
    correspond to the initial reaction.

    :param reaction: Reaction to be filtered.
    :param config: Reaction filtration configuration.
        NOTE(review): currently unused inside this function — the filters arrive
        pre-built via ``filters``.
    :param filters: The list of reaction filters.
    :return: False and reaction if reaction is correct and True and reaction if reaction
        is filtered (erroneous).
    """

    is_filtered = False

    # run reaction standardization

    # a falsy result from any standardizer marks the reaction as filtered
    standardizers = [
        RemoveReagentsStandardizer(),
        KekuleFormStandardizer(),
        AromaticFormStandardizer(),
    ]

    for reaction_standardizer in standardizers:
        reaction = reaction_standardizer(reaction)
        if not reaction:
            is_filtered = True
            break

    # run reaction filtration
    if not is_filtered:
        # note: no early exit — every filter runs, so "filtration_log" records
        # the last filter that rejected the reaction
        for reaction_filter in filters:
            try:  # CGRTools ValueError: mapping of graphs is not disjoint
                if reaction_filter(reaction):
                    # if filter returns True it means the reaction doesn't pass the filter
                    reaction.meta["filtration_log"] = reaction_filter.__class__.__name__
                    is_filtered = True
            except Exception as e:
                # a crashing filter counts as a rejection instead of aborting the run
                logging.debug(e)
                is_filtered = True

    return is_filtered, reaction
@ray.remote
def process_batch(
    batch: List[ReactionContainer],
    config: ReactionFilterConfig,
    filters: list,
) -> List[Tuple[bool, ReactionContainer]]:
    """
    Filters a batch of reactions based on the given configuration. This function
    operates as a remote task in a distributed system using Ray.

    :param batch: The list of reactions to process (the caller submits plain
        ReactionContainer objects, not (index, reaction) tuples).
    :param config: Reaction filtration configuration.
    :param filters: The list of reaction filters.
    :return: A list of tuples, one per input reaction: a flag that is True when
        the reaction was filtered out (erroneous) and the possibly-modified
        reaction itself.
    """

    processed_reaction_list = []
    for reaction in batch:
        try:  # CGRtools.exceptions.MappingError: atoms with number {52} not equal
            is_filtered, processed_reaction = filter_reaction(reaction, config, filters)
            processed_reaction_list.append((is_filtered, processed_reaction))
        except Exception as e:
            # a crashing reaction is kept in the output but marked as filtered
            logging.debug(e)
            processed_reaction_list.append((True, reaction))
    return processed_reaction_list
def process_completed_batch(
    futures: Dict,
    result_file: TextIOWrapper,
    n_filtered: int = 0,
) -> int:
    """
    Waits for one Ray batch future to complete and writes its surviving
    reactions to the output.

    :param futures: A dictionary of futures representing ongoing batch processing tasks.
    :param result_file: The open writer receiving reactions that passed all
        filters. NOTE(review): despite the TextIOWrapper annotation, the caller
        passes a ReactionWriter — confirm before tightening the type.
    :param n_filtered: Running count of reactions written so far.
        NOTE(review): the name is misleading — this counts reactions that were
        NOT filtered out.
    :return: The updated count of written (kept) reactions.
    """

    # block until at least one submitted batch has finished
    ready_id, running_id = ray.wait(list(futures.keys()), num_returns=1)
    completed_batch = ray.get(ready_id[0])

    # write results of the completed batch to file
    for is_filtered, reaction in completed_batch:
        if not is_filtered:
            result_file.write(reaction)
            n_filtered += 1

    # remove completed future and update progress bar
    del futures[ready_id[0]]

    return n_filtered
def filter_reactions_from_file(
    config: ReactionFilterConfig,
    input_reaction_data_path: str,
    filtered_reaction_data_path: str = "reaction_data_filtered.smi",
    num_cpus: int = 1,
    batch_size: int = 100,
) -> None:
    """
    Processes reaction data, applying reaction filters based on the provided
    configuration, and writes the results to specified files.

    :param config: ReactionFilterConfig object containing all filtration
        configuration settings.
    :param input_reaction_data_path: Path to the reaction data file.
    :param filtered_reaction_data_path: Name for the file that will contain filtered
        reactions.
    :param num_cpus: Number of CPUs to use for processing.
    :param batch_size: Size of the batch for processing reactions.
    :return: None. The function writes the processed reactions to specified RDF/smi
        files.
    """

    filters = config.create_filters()

    ray.init(num_cpus=num_cpus, ignore_reinit_error=True, logging_level=logging.ERROR)
    max_concurrent_batches = num_cpus  # limit the number of concurrent batches
    lines_counter = 0  # total reactions read from the input file
    with ReactionReader(input_reaction_data_path) as reactions, ReactionWriter(
        filtered_reaction_data_path
    ) as result_file:

        batches_to_process, batch = {}, []
        # NOTE(review): counts reactions KEPT (written to output), not removed
        n_filtered = 0
        for index, reaction in tqdm(
            enumerate(reactions),
            desc="Number of reactions processed: ",
            bar_format="{desc}{n} [{elapsed}]",
        ):
            lines_counter += 1
            batch.append(reaction)
            if len(batch) == batch_size:
                # submit the full batch to Ray and start collecting a new one
                batch_results = process_batch.remote(batch, config, filters)
                batches_to_process[batch_results] = None
                batch = []

            # check and process completed tasks if we've reached the concurrency limit
            while len(batches_to_process) >= max_concurrent_batches:
                n_filtered = process_completed_batch(
                    batches_to_process,
                    result_file,
                    n_filtered,
                )

        # process the last batch if it's not empty
        if batch:
            batch_results = process_batch.remote(batch, config, filters)
            batches_to_process[batch_results] = None

        # process remaining batches
        while batches_to_process:
            n_filtered = process_completed_batch(
                batches_to_process,
                result_file,
                n_filtered,
            )

    ray.shutdown()
    print(f"Initial number of reactions: {lines_counter}")
    # NOTE(review): despite the label, this prints the number of reactions KEPT
    print(f"Filtered number of reactions: {n_filtered}")
synplan/chem/data/standardizing.py ADDED
@@ -0,0 +1,1187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Module containing classes and functions for reactions standardizing.
2
+
3
+ This module contains the open-source code from
4
+ https://github.com/Laboratoire-de-Chemoinformatique/Reaction_Data_Cleaning/blob/master/scripts/standardizer.py
5
+ """
6
+
7
from __future__ import annotations

import hashlib
import logging
import sys
from abc import ABC, abstractmethod
from contextlib import suppress
from dataclasses import dataclass
from io import TextIOWrapper
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Sequence, TextIO, Tuple, Union

import ray
import yaml
from CGRtools import smiles as smiles_cgrtools
from CGRtools.containers import MoleculeContainer
from CGRtools.containers import ReactionContainer
from CGRtools.containers import ReactionContainer as ReactionContainerCGRTools
from chython import ReactionContainer as ReactionContainerChython
from chython import smiles as smiles_chython
from tqdm.auto import tqdm

from synplan.chem.utils import unite_molecules
from synplan.utils.config import ConfigABC
from synplan.utils.files import ReactionReader, ReactionWriter
from synplan.utils.logging import init_logger, init_ray_logging
33
+
34
+ logger = logging.getLogger("synplan.chem.data.standardizing")
35
+
36
+
37
class StandardizationError(RuntimeError):
    """Signals that a single standardization stage failed.

    Keeps the stage name, the offending reaction string and the underlying
    exception so callers can log, filter or inspect the failure.
    """

    def __init__(self, stage: str, reaction: str, original: Exception):
        message = f"{stage} failed on {reaction}: {original}"
        super().__init__(message)
        self.stage = stage
        self.reaction = reaction
        self.original = original
45
+
46
+
47
class BaseStandardizer(ABC):
    """Template for standardization steps: subclasses override `_run` only.

    Calling the instance delegates to :meth:`_run`; any unexpected failure
    is converted into a :class:`StandardizationError` with the original
    exception chained (``__cause__``), while a ``StandardizationError``
    raised by the subclass itself is propagated untouched (previously it
    was wrapped a second time).
    """

    @classmethod
    def from_config(cls, _cfg: object) -> "BaseStandardizer":
        """Build a standardizer from its (unused) config object."""
        return cls()

    @abstractmethod
    def _run(self, rxn: ReactionContainer) -> ReactionContainer:
        """Run the standardization step on the reaction.

        Args:
            rxn: The reaction to standardize

        Returns:
            The standardized reaction

        Raises:
            StandardizationError: If standardization fails
        """
        ...

    def __call__(self, rxn: ReactionContainer) -> ReactionContainer:
        """Execute the standardization step with proper error handling.

        Args:
            rxn: The reaction to standardize

        Returns:
            The standardized reaction

        Raises:
            StandardizationError: If standardization fails
        """
        try:
            return self._run(rxn)
        except StandardizationError:
            # Already carries stage/reaction context — do not re-wrap.
            raise
        except Exception as exc:
            logging.debug("%s: %s", self.__class__.__name__, exc, exc_info=True)
            # Chain the root cause so tracebacks stay informative.
            raise StandardizationError(
                self.__class__.__name__, str(rxn), exc
            ) from exc
86
+
87
+
88
+ # Configuration classes
89
@dataclass
class ReactionMappingConfig:
    """Marker configuration enabling the atom-to-atom mapping step (no parameters)."""

    pass
92
+
93
+
94
class ReactionMappingStandardizer(BaseStandardizer):
    """Maps atoms of the reaction using chython (chytorch).

    The reaction is round-tripped through chython: the mapping is
    recomputed, reagents are removed, and the result is parsed back into
    a CGRtools container with the original metadata preserved.
    """

    def _map_and_remove_reagents(
        self, reaction: ReactionContainerChython
    ) -> ReactionContainerChython:
        """Recompute the atom mapping and drop reagents in place.

        Args:
            reaction: Input chython reaction

        Returns:
            The mapped reaction with reagents removed
        """
        reaction.reset_mapping()
        reaction.remove_reagents()
        return reaction

    def _run(self, rxn: ReactionContainerCGRTools) -> ReactionContainerCGRTools:
        """Map atoms of the reaction using chython.

        Args:
            rxn: Input reaction (CGRtools container or reaction SMILES string)

        Returns:
            The mapped reaction

        Raises:
            StandardizationError: If mapping fails
        """
        try:
            # Convert to chython format
            if isinstance(rxn, str):
                chython_reaction = smiles_chython(rxn)
            else:
                # Convert CGRtools reaction to SMILES string, preserving reagents
                reactants = ".".join(str(m) for m in rxn.reactants)
                reagents = ".".join(str(m) for m in rxn.reagents)
                products = ".".join(str(m) for m in rxn.products)
                smiles = f"{reactants}>{reagents}>{products}"
                # Parse SMILES string with chython
                chython_reaction = smiles_chython(smiles)

            # Map and remove reagents
            reaction_mapped = self._map_and_remove_reagents(chython_reaction)
            if not reaction_mapped:
                raise StandardizationError(
                    "ReactionMapping", str(rxn), ValueError("Mapping failed")
                )

            # Convert back to CGRtools format ("m" keeps mapping in the SMILES)
            mapped_smiles = format(reaction_mapped, "m")
            result = smiles_cgrtools(mapped_smiles)
            # A plain SMILES string input carries no metadata to copy;
            # previously `rxn.meta` raised AttributeError for str input.
            if not isinstance(rxn, str):
                result.meta.update(rxn.meta)  # Preserve metadata
            return result
        except StandardizationError:
            raise  # do not wrap our own error a second time
        except Exception as e:
            raise StandardizationError("ReactionMapping", str(rxn), e) from e
151
+
152
+
153
@dataclass
class FunctionalGroupsConfig:
    """Marker configuration enabling functional-group standardization (no parameters)."""

    pass
156
+
157
+
158
class FunctionalGroupsStandardizer(BaseStandardizer):
    """Functional groups standardization."""

    def _run(self, rxn: ReactionContainer) -> ReactionContainer:
        """Normalize functional-group representations in place.

        Args:
            rxn: Input reaction

        Returns:
            The same reaction object after in-place standardization

        Raises:
            StandardizationError: If standardization fails
        """
        standardized = rxn
        standardized.standardize()
        return standardized
175
+
176
+
177
@dataclass
class KekuleFormConfig:
    """Marker configuration enabling kekulization of all reaction parts (no parameters)."""

    pass
180
+
181
+
182
class KekuleFormStandardizer(BaseStandardizer):
    """Reactants/reagents/products kekulization."""

    def _run(self, rxn: ReactionContainer) -> ReactionContainer:
        """Convert aromatic rings to a Kekulé representation in place.

        Args:
            rxn: The reaction to kekulize

        Returns:
            The same reaction object after kekulization

        Raises:
            StandardizationError: If kekulization fails
        """
        kekulized = rxn
        kekulized.kekule()
        return kekulized
199
+
200
+
201
@dataclass
class CheckValenceConfig:
    """Marker configuration enabling atom valence checking (no parameters)."""

    pass
204
+
205
+
206
class CheckValenceStandardizer(BaseStandardizer):
    """Reject reactions containing atoms with invalid valences."""

    def _run(self, rxn: ReactionContainer) -> ReactionContainer:
        """Check the valence of every atom in the reaction.

        Args:
            rxn: Input reaction

        Returns:
            The unchanged reaction when all valences are correct

        Raises:
            StandardizationError: If any molecule reports valence mistakes
        """
        # Same molecule order as reactants + products + reagents.
        for part in (rxn.reactants, rxn.products, rxn.reagents):
            for molecule in part:
                mistakes = molecule.check_valence()
                if mistakes:
                    raise StandardizationError(
                        "CheckValence",
                        str(rxn),
                        ValueError(f"Valence errors: {mistakes}"),
                    )
        return rxn
230
+
231
+
232
@dataclass
class ImplicifyHydrogensConfig:
    """Marker configuration enabling explicit-hydrogen removal (no parameters)."""

    pass
235
+
236
+
237
class ImplicifyHydrogensStandardizer(BaseStandardizer):
    """Convert explicit hydrogens to implicit ones."""

    def _run(self, rxn: ReactionContainer) -> ReactionContainer:
        """Make hydrogens implicit in place.

        Args:
            rxn: Input reaction

        Returns:
            The same reaction object with implicit hydrogens

        Raises:
            StandardizationError: If hydrogen implicification fails
        """
        result = rxn
        result.implicify_hydrogens()
        return result
254
+
255
+
256
@dataclass
class CheckIsotopesConfig:
    """Marker configuration enabling isotope detection and cleaning (no parameters)."""

    pass
259
+
260
+
261
class CheckIsotopesStandardizer(BaseStandardizer):
    """Detect isotopic labels and strip them when present."""

    def _run(self, rxn: ReactionContainer) -> ReactionContainer:
        """Clean isotope labels from the reaction if any atom carries one.

        Args:
            rxn: Input reaction

        Returns:
            The reaction with isotopes cleaned (in place) when needed

        Raises:
            StandardizationError: If isotope check/cleaning fails
        """
        # any() short-circuits exactly like the original nested breaks.
        has_isotope = any(
            atom.isotope
            for molecule in rxn.reactants + rxn.products
            for _, atom in molecule.atoms()
        )
        if has_isotope:
            rxn.clean_isotopes()
        return rxn
289
+
290
+
291
@dataclass
class SplitIonsConfig:
    """Marker configuration enabling ion splitting / charge balancing (no parameters)."""

    pass
294
+
295
+
296
class SplitIonsStandardizer(BaseStandardizer):
    """Split multi-component molecules into ions and check charge balance."""

    def _run(self, rxn: ReactionContainer) -> ReactionContainer:
        """Split ions in the reaction.

        Args:
            rxn: Input reaction

        Returns:
            The reaction with split ions

        Raises:
            StandardizationError: If ions were split but the resulting
                reaction is charge-imbalanced (return code 2)
        """
        reaction, return_code = self._split_ions(rxn)
        if return_code == 2:  # ions were split but the reaction is imbalanced
            raise StandardizationError(
                "SplitIons",
                str(rxn),
                ValueError("Reaction is imbalanced after ion splitting"),
            )
        return reaction

    def _calc_charge(self, molecule: MoleculeContainer) -> int:
        """Compute total charge of a molecule.

        Args:
            molecule: Input molecule

        Returns:
            The total charge of the molecule
        """
        # NOTE(review): reads CGRtools private ``_charges`` mapping
        # (atom number -> formal charge) — confirm against the pinned
        # CGRtools version if it is ever upgraded.
        return sum(molecule._charges.values())

    def _split_ions(self, reaction: ReactionContainer) -> Tuple[ReactionContainer, int]:
        """Split ions in a reaction.

        Args:
            reaction: Input reaction

        Returns:
            A tuple containing:
            - The reaction with split ions
            - Return code (0: nothing changed, 1: ions split, 2: ions split but imbalanced)
        """
        meta = reaction.meta
        reaction_parts = []
        return_codes = []

        # Each of the three parts gets its own return code; the worst
        # (max) one decides the overall outcome.
        for molecules in (reaction.reactants, reaction.reagents, reaction.products):
            # Split molecules into individual components
            divided_molecules = []
            for molecule in molecules:
                if isinstance(molecule, str):
                    # If it's a string, try to parse it as a molecule
                    try:
                        molecule: MoleculeContainer = smiles_cgrtools(molecule)
                    except Exception as e:
                        # Unparseable entries are silently dropped (best effort).
                        logging.warning("Failed to parse molecule %s: %s", molecule, e)
                        continue

                # Use the split method from CGRtools
                try:
                    components = molecule.split()
                    divided_molecules.extend(components)
                except Exception as e:
                    # Keep the molecule whole if splitting fails.
                    logging.warning("Failed to split molecule %s: %s", molecule, e)
                    divided_molecules.append(molecule)

            total_charge = 0
            ions_present = False
            for molecule in divided_molecules:
                try:
                    mol_charge = self._calc_charge(molecule)
                    total_charge += mol_charge
                    if mol_charge != 0:
                        ions_present = True
                except Exception as e:
                    # Charge failures are tolerated; the component still stays
                    # in the reaction but does not contribute to the balance.
                    logging.warning(
                        "Failed to calculate charge for molecule %s: %s", molecule, e
                    )
                    continue

            # 2: charged components with a non-zero net charge (imbalanced);
            # 1: charged components that cancel out; 0: no ions at all.
            if ions_present and total_charge:
                return_codes.append(2)
            elif ions_present:
                return_codes.append(1)
            else:
                return_codes.append(0)

            reaction_parts.append(tuple(divided_molecules))

        return (
            ReactionContainer(
                reactants=reaction_parts[0],
                reagents=reaction_parts[1],
                products=reaction_parts[2],
                meta=meta,
            ),
            max(return_codes),
        )
398
+
399
+
400
@dataclass
class AromaticFormConfig:
    """Marker configuration enabling aromatization of molecules (no parameters)."""

    pass
403
+
404
+
405
class AromaticFormStandardizer(BaseStandardizer):
    """Aromatize molecules in reaction."""

    def _run(self, rxn: ReactionContainer) -> ReactionContainer:
        """Convert rings to the aromatic (Thiele) form in place.

        Args:
            rxn: Input reaction

        Returns:
            The same reaction object with aromatized molecules

        Raises:
            StandardizationError: If aromatization fails
        """
        aromatized = rxn
        aromatized.thiele()
        return aromatized
422
+
423
+
424
@dataclass
class MappingFixConfig:
    """Marker configuration enabling atom-to-atom mapping fixes (no parameters)."""

    pass
427
+
428
+
429
class MappingFixStandardizer(BaseStandardizer):
    """Fix atom-to-atom mapping in reaction."""

    def _run(self, rxn: ReactionContainer) -> ReactionContainer:
        """Repair the atom-to-atom mapping in place.

        Args:
            rxn: Input reaction

        Returns:
            The same reaction object with fixed atom-to-atom mapping

        Raises:
            StandardizationError: If mapping fix fails
        """
        fixed = rxn
        fixed.fix_mapping()
        return fixed
446
+
447
+
448
@dataclass
class UnchangedPartsConfig:
    """Marker configuration enabling removal of unchanged reaction parts (no parameters)."""

    pass
451
+
452
+
453
class UnchangedPartsStandardizer(BaseStandardizer):
    """Ungroup molecules, remove unchanged parts from reactants and products."""

    def __init__(
        self,
        add_reagents_to_reactants: bool = False,
        keep_reagents: bool = False,
    ):
        self.add_reagents_to_reactants = add_reagents_to_reactants
        self.keep_reagents = keep_reagents

    @classmethod
    def from_config(cls, config: UnchangedPartsConfig) -> "UnchangedPartsStandardizer":
        """Build with default options (the marker config carries none)."""
        return cls()

    def _run(self, rxn: ReactionContainer) -> ReactionContainer:
        """Move molecules that appear unchanged on both sides to the reagents.

        Args:
            rxn: Input reaction

        Returns:
            The reaction with unchanged parts removed

        Raises:
            StandardizationError: If either side ends up empty
        """
        meta = rxn.meta
        reactants = list(rxn.reactants)
        reagents = list(rxn.reagents)
        if self.add_reagents_to_reactants:
            reactants += reagents
            reagents = []
        products = list(rxn.products)

        # Iterate over a snapshot: removing from `reactants` while iterating
        # it directly would skip elements.
        for candidate in list(reactants):
            if candidate in products:
                reagents.append(candidate)
                reactants.remove(candidate)
                products.remove(candidate)

        if not self.keep_reagents:
            reagents = []

        # Guard clauses: an empty side after pruning makes the reaction invalid.
        if not reactants and not products:
            raise StandardizationError(
                "UnchangedParts", str(rxn), ValueError("No molecules left")
            )
        if not reactants:
            raise StandardizationError(
                "UnchangedParts", str(rxn), ValueError("No reactants left")
            )
        if not products:
            raise StandardizationError(
                "UnchangedParts", str(rxn), ValueError("No products left")
            )

        new_reaction = ReactionContainer(
            reactants=tuple(reactants),
            reagents=tuple(reagents),
            products=tuple(products),
            meta=meta,
        )
        new_reaction.name = rxn.name
        return new_reaction
518
+
519
+
520
@dataclass
class SmallMoleculesConfig:
    """Configuration for the small-molecule removal step.

    Attributes:
        mol_max_size: Maximum number of atoms for a molecule to be
            considered "small" (and thus removed). Must be a positive
            integer.
    """

    mol_max_size: int = 6

    @staticmethod
    def from_dict(config_dict: Dict[str, Any]) -> "SmallMoleculesConfig":
        """Create an instance of SmallMoleculesConfig from a dictionary."""
        return SmallMoleculesConfig(**config_dict)

    @staticmethod
    def from_yaml(file_path: str) -> "SmallMoleculesConfig":
        """Deserialize a YAML file into a SmallMoleculesConfig object."""
        with open(file_path, "r", encoding="utf-8") as file:
            config_dict = yaml.safe_load(file)
        return SmallMoleculesConfig.from_dict(config_dict)

    def _validate_params(self, params: Dict[str, Any]) -> None:
        """Validate configuration parameters.

        Raises:
            ValueError: If 'mol_max_size' is not a positive integer.
        """
        mol_max_size = params.get("mol_max_size", self.mol_max_size)
        # bool is an int subclass, so exclude it explicitly; the previous
        # error message ("more than 1") misstated the accepted range (>= 1).
        if (
            not isinstance(mol_max_size, int)
            or isinstance(mol_max_size, bool)
            or mol_max_size <= 0
        ):
            raise ValueError("Invalid 'mol_max_size'; expected a positive integer")
541
+
542
+
543
class SmallMoleculesStandardizer(BaseStandardizer):
    """Remove small molecules from a reaction, archiving them in the meta."""

    def __init__(self, mol_max_size: int = 6):
        self.mol_max_size = mol_max_size

    @classmethod
    def from_config(cls, config: SmallMoleculesConfig) -> "SmallMoleculesStandardizer":
        """Build from a SmallMoleculesConfig."""
        return cls(config.mol_max_size)

    def _split_molecules(
        self, molecules: Iterable, number_of_atoms: int
    ) -> Tuple[List[MoleculeContainer], List[MoleculeContainer]]:
        """Partition molecules by heavy-atom count.

        Args:
            molecules: Iterable of molecules
            number_of_atoms: Size threshold; strictly larger molecules are "big"

        Returns:
            Tuple of (big molecules, small molecules)
        """
        big, small = [], []
        for mol in molecules:
            bucket = big if len(mol) > number_of_atoms else small
            bucket.append(mol)
        return big, small

    def _run(self, rxn: ReactionContainer) -> ReactionContainer:
        """Drop small molecules, keeping them in the meta for reference.

        Args:
            rxn: Input reaction

        Returns:
            The reaction without small molecules

        Raises:
            StandardizationError: If nothing but small molecules remains
        """
        kept_reactants, small_reactants = self._split_molecules(
            rxn.reactants, self.mol_max_size
        )
        kept_products, small_products = self._split_molecules(
            rxn.products, self.mol_max_size
        )

        if not kept_reactants or not kept_products:
            raise StandardizationError(
                "SmallMolecules",
                str(rxn),
                ValueError("No molecules left after removing small ones"),
            )

        new_reaction = ReactionContainer(
            kept_reactants, kept_products, rxn.reagents, rxn.meta
        )
        new_reaction.name = rxn.name

        # Archive the removed molecules so the information is not lost.
        new_reaction.meta["small_reactants"] = str(unite_molecules(small_reactants))
        new_reaction.meta["small_products"] = str(unite_molecules(small_products))

        return new_reaction
611
+
612
+
613
@dataclass
class RemoveReagentsConfig:
    """Configuration for the reagent removal step.

    Attributes:
        reagent_max_size: Maximum number of atoms for a detected reagent
            to be kept in the reagents list; larger ones are dropped.
            Must be a positive integer.
    """

    reagent_max_size: int = 7

    @staticmethod
    def from_dict(config_dict: Dict[str, Any]) -> "RemoveReagentsConfig":
        """Create an instance of RemoveReagentsConfig from a dictionary."""
        return RemoveReagentsConfig(**config_dict)

    @staticmethod
    def from_yaml(file_path: str) -> "RemoveReagentsConfig":
        """Deserialize a YAML file into a RemoveReagentsConfig object."""
        with open(file_path, "r", encoding="utf-8") as file:
            config_dict = yaml.safe_load(file)
        return RemoveReagentsConfig.from_dict(config_dict)

    def _validate_params(self, params: Dict[str, Any]) -> None:
        """Validate configuration parameters.

        Raises:
            ValueError: If 'reagent_max_size' is not a positive integer.
        """
        reagent_max_size = params.get("reagent_max_size", self.reagent_max_size)
        # bool is an int subclass, so exclude it explicitly; the previous
        # error message ("more than 1") misstated the accepted range (>= 1).
        if (
            not isinstance(reagent_max_size, int)
            or isinstance(reagent_max_size, bool)
            or reagent_max_size <= 0
        ):
            raise ValueError("Invalid 'reagent_max_size'; expected a positive integer")
635
+ )
636
+
637
+
638
class RemoveReagentsStandardizer(BaseStandardizer):
    """Move molecules not touching the reaction centre into the reagents."""

    def __init__(self, reagent_max_size: int = 7):
        # Detected reagents larger than this atom count are dropped entirely.
        self.reagent_max_size = reagent_max_size

    @classmethod
    def from_config(cls, config: RemoveReagentsConfig) -> "RemoveReagentsStandardizer":
        """Build from a RemoveReagentsConfig."""
        return cls(config.reagent_max_size)

    def _run(self, rxn: ReactionContainer) -> ReactionContainer:
        """Remove reagents from the reaction.

        A molecule is treated as a reagent if it shares no atoms with the
        CGR reaction centre, or if it appears unchanged on both sides.

        Args:
            rxn: Input reaction

        Returns:
            The reaction without reagents

        Raises:
            StandardizationError: If reagent removal leaves a side empty
        """
        not_changed_molecules = set(rxn.reactants).intersection(rxn.products)
        # ~rxn composes the condensed graph of the reaction (CGR);
        # its center_atoms are the atom numbers changed by the reaction.
        cgr = ~rxn
        center_atoms = set(cgr.center_atoms)

        new_reactants = []
        new_products = []
        new_reagents = []

        # NOTE(review): `center_atoms.isdisjoint(molecule)` assumes iterating
        # a molecule yields its atom numbers — confirm against the pinned
        # CGRtools version.
        for molecule in rxn.reactants:
            if center_atoms.isdisjoint(molecule) or molecule in not_changed_molecules:
                new_reagents.append(molecule)
            else:
                new_reactants.append(molecule)

        for molecule in rxn.products:
            if center_atoms.isdisjoint(molecule) or molecule in not_changed_molecules:
                new_reagents.append(molecule)
            else:
                new_products.append(molecule)

        if not new_reactants or not new_products:
            raise StandardizationError(
                "RemoveReagents",
                str(rxn),
                ValueError("No molecules left after removing reagents"),
            )

        # Filter reagents by size
        # NOTE(review): the set comprehension also deduplicates reagents and
        # discards their order — confirm this is intended.
        new_reagents = {
            molecule
            for molecule in new_reagents
            if len(molecule) <= self.reagent_max_size
        }

        new_reaction = ReactionContainer(
            new_reactants, new_products, new_reagents, rxn.meta
        )
        new_reaction.name = rxn.name

        return new_reaction
700
+
701
+
702
@dataclass
class RebalanceReactionConfig:
    """Marker configuration enabling reaction rebalancing via CGR (no parameters)."""

    pass
705
+
706
+
707
class RebalanceReactionStandardizer(BaseStandardizer):
    """Rebalance reaction via CGR composition/decomposition."""

    @classmethod
    def from_config(
        cls, config: RebalanceReactionConfig
    ) -> "RebalanceReactionStandardizer":
        """Build with default options (the marker config carries none)."""
        return cls()

    def _run(self, rxn: ReactionContainer) -> ReactionContainer:
        """Rebalances the reaction by assembling CGR and then decomposing it. Works for
        all reactions for which the correct CGR can be assembled.

        Args:
            rxn: Input reaction

        Returns:
            The rebalanced reaction

        Raises:
            StandardizationError: If rebalancing fails; carries the root
                cause instead of a generic message.
        """
        try:
            # Compose the CGR from reactants/products only, then decompose it
            # back; this restores atoms missing from either side.
            tmp_rxn = ReactionContainer(rxn.reactants, rxn.products)
            cgr = ~tmp_rxn
            reactants, products = ~cgr
            new_rxn = ReactionContainer(
                reactants.split(), products.split(), rxn.reagents, rxn.meta
            )
            new_rxn.name = rxn.name
            return new_rxn
        except Exception as e:
            # Lazy %-formatting and exception chaining: the previous version
            # discarded the root cause behind a generic ValueError.
            logging.debug("Rebalancing attempt failed: %s", e)
            raise StandardizationError("RebalanceReaction", str(rxn), e) from e
745
+
746
+
747
@dataclass
class DuplicateReactionConfig:
    """Marker configuration enabling duplicate reaction removal (no parameters)."""

    pass
750
+
751
+
752
class DuplicateReactionStandardizer(BaseStandardizer):
    """Cluster-wide duplicate removal via a Ray actor."""

    def __init__(self, dedup_actor: "ray.actor.ActorHandle"):
        self._actor = dedup_actor  # global singleton handle
        # local fast-path cache to avoid actor call on obvious repeats *in
        # the same worker*; purely an optimisation, not required.
        self._local_seen: set[int] = set()

    @classmethod
    def from_config(cls, config: DuplicateReactionConfig):
        """Attach to the shared dedup actor, or run standalone without Ray."""
        # fallback for single-process mode: no actor, local set only
        if ray.is_initialized():
            dedup_actor = ray.get_actor("duplicate_rxn_actor")
        else:
            dedup_actor = None
        return cls(dedup_actor)

    @staticmethod
    def _stable_hash(smi: str) -> int:
        """Process-independent 64-bit hash of a reaction SMILES.

        Builtin ``hash()`` on str is randomized per process (PYTHONHASHSEED),
        so hashes computed in different Ray workers would never collide in
        the shared actor and duplicates would slip through.
        """
        return int.from_bytes(hashlib.sha256(smi.encode("utf-8")).digest()[:8], "big")

    def safe_reaction_smiles(self, reaction: ReactionContainer) -> str:
        """Reaction SMILES without reagents (``reactants>>products``)."""
        reactants_smi = ".".join(str(i) for i in reaction.reactants)
        products_smi = ".".join(str(i) for i in reaction.products)
        return f"{reactants_smi}>>{products_smi}"

    def _run(self, rxn: ReactionContainer) -> ReactionContainer:
        """Pass through new reactions; fail on previously seen ones.

        Raises:
            StandardizationError: If the reaction was already seen.
        """
        h = self._stable_hash(self.safe_reaction_smiles(rxn))

        # local cache fast-path (helps in large batches processed by same
        # worker; no correctness impact).
        if h in self._local_seen:
            raise StandardizationError(
                "DuplicateReaction", str(rxn), ValueError("Duplicate reaction found")
            )

        # ------------------- cluster-wide check ------------------------
        if self._actor is None:  # single-CPU fall-back
            is_new = h not in self._local_seen
        else:
            # synchronous, returns True/False
            is_new = ray.get(self._actor.check_and_add.remote(h))

        if is_new:
            self._local_seen.add(h)
            return rxn

        raise StandardizationError(
            "DuplicateReaction", str(rxn), ValueError("Duplicate reaction found")
        )
800
+
801
+
802
@ray.remote
class DedupActor:
    """Cluster-wide set of reaction hashes."""

    def __init__(self):
        self._seen: set[int] = set()

    def check_and_add(self, h: int) -> bool:
        """
        Returns True **iff** the hash was not present yet; it is recorded
        either way. Cluster-wide uniqueness is guaranteed because this
        method executes serially inside the actor process.
        """
        size_before = len(self._seen)
        self._seen.add(h)
        return len(self._seen) > size_before
819
+
820
+
821
# Registry mapping config field names to standardizer classes.
# Insertion order matters: create_standardizers() instantiates entries in
# this order, which is the order standardize_reaction() applies them.
STANDARDIZER_REGISTRY = {
    "reaction_mapping_config": ReactionMappingStandardizer,
    "functional_groups_config": FunctionalGroupsStandardizer,
    "kekule_form_config": KekuleFormStandardizer,
    "check_valence_config": CheckValenceStandardizer,
    "implicify_hydrogens_config": ImplicifyHydrogensStandardizer,
    "check_isotopes_config": CheckIsotopesStandardizer,
    "split_ions_config": SplitIonsStandardizer,
    "aromatic_form_config": AromaticFormStandardizer,
    "mapping_fix_config": MappingFixStandardizer,
    "unchanged_parts_config": UnchangedPartsStandardizer,
    "small_molecules_config": SmallMoleculesStandardizer,
    "remove_reagents_config": RemoveReagentsStandardizer,
    "rebalance_reaction_config": RebalanceReactionStandardizer,
    "duplicate_reaction_config": DuplicateReactionStandardizer,
}
838
+
839
+
840
@dataclass
class ReactionStandardizationConfig(ConfigABC):
    """Configuration class for reaction standardization. This class manages
    configuration settings for the various reaction standardizers; a field set
    to a (marker) config object enables the corresponding step.

    :param reaction_mapping_config: Configuration for reaction mapping.
    :param functional_groups_config: Configuration for functional groups
        standardization.
    :param kekule_form_config: Configuration for reactants/reagents/products
        kekulization.
    :param check_valence_config: Configuration for atom valence checking.
    :param implicify_hydrogens_config: Configuration for hydrogens removal.
    :param check_isotopes_config: Configuration for isotopes checking and cleaning.
    :param split_ions_config: Configuration for computing charge of molecule.
    :param aromatic_form_config: Configuration for molecules aromatization.
    :param mapping_fix_config: Configuration for atom-to-atom mapping fixes.
    :param unchanged_parts_config: Configuration for removal of unchanged parts in
        reaction.
    :param small_molecules_config: Configuration for removal of small molecule from
        reaction.
    :param remove_reagents_config: Configuration for removal of reagents from reaction.
    :param rebalance_reaction_config: Configuration for reaction rebalancing.
    :param duplicate_reaction_config: Configuration for removal of duplicate reactions.
    """

    # configuration for reaction standardizers
    reaction_mapping_config: Optional[ReactionMappingConfig] = None
    functional_groups_config: Optional[FunctionalGroupsConfig] = None
    kekule_form_config: Optional[KekuleFormConfig] = None
    check_valence_config: Optional[CheckValenceConfig] = None
    implicify_hydrogens_config: Optional[ImplicifyHydrogensConfig] = None
    check_isotopes_config: Optional[CheckIsotopesConfig] = None
    split_ions_config: Optional[SplitIonsConfig] = None
    aromatic_form_config: Optional[AromaticFormConfig] = None
    mapping_fix_config: Optional[MappingFixConfig] = None
    unchanged_parts_config: Optional[UnchangedPartsConfig] = None
    small_molecules_config: Optional[SmallMoleculesConfig] = None
    remove_reagents_config: Optional[RemoveReagentsConfig] = None
    rebalance_reaction_config: Optional[RebalanceReactionConfig] = None
    duplicate_reaction_config: Optional[DuplicateReactionConfig] = None

    def _validate_params(self, params: Dict[str, Any]) -> None:
        """Delegate validation to every enabled sub-configuration."""
        for field_name, config in self.__dict__.items():
            if config is not None and hasattr(config, "_validate_params"):
                config._validate_params(params.get(field_name, {}))

    def to_dict(self):
        """Converts the configuration into a dictionary."""
        config_dict = {}
        for field_name in STANDARDIZER_REGISTRY:
            config = getattr(self, field_name)
            if config is not None:
                config_dict[field_name] = {}
        return config_dict

    @staticmethod
    def from_dict(config_dict: Dict[str, Any]) -> "ReactionStandardizationConfig":
        """Create an instance of ReactionStandardizationConfig from a dictionary."""
        config_kwargs = {}
        for field_name, std_cls in STANDARDIZER_REGISTRY.items():
            if field_name in config_dict:
                # Resolve the matching *Config class by name and instantiate it.
                # (Previously the code called the name *string* itself, which
                # raised ``TypeError: 'str' object is not callable``.)
                config_cls_name = std_cls.__name__.replace("Standardizer", "Config")
                config_kwargs[field_name] = globals()[config_cls_name]()
        return ReactionStandardizationConfig(**config_kwargs)

    @staticmethod
    def from_yaml(file_path: str) -> "ReactionStandardizationConfig":
        """Deserializes a YAML file into a ReactionStandardizationConfig object."""
        with open(file_path, "r", encoding="utf-8") as file:
            config_dict = yaml.safe_load(file)
        return ReactionStandardizationConfig.from_dict(config_dict)

    def create_standardizers(self):
        """Create standardizer instances based on configuration (registry order)."""
        standardizers = []
        for field_name, std_cls in STANDARDIZER_REGISTRY.items():
            config = getattr(self, field_name)
            if config is not None:
                standardizers.append(std_cls.from_config(config))
        return standardizers
922
+
923
+
924
def standardize_reaction(
    reaction: ReactionContainer,
    standardizers: Sequence,
) -> ReactionContainer | None:
    """
    Apply each standardizer in order.

    Returns
    -------
    ReactionContainer | None
        - the fully standardized reaction, or
        - None if *any* standardizer decides to filter it out (by
          returning None).

    Raises
    ------
    StandardizationError
        Propagated untouched so the caller can decide what to do.
    """
    current = reaction
    for standardizer in standardizers:
        logger.debug(" › %s(%s)", standardizer.__class__.__name__, current)
        try:
            current = standardizer(current)  # may return None
        except StandardizationError as exc:
            # Log *once*, then re-raise the same object with its traceback.
            logger.warning(
                "%s failed on reaction %s : %s",
                standardizer.__class__.__name__,
                current,
                exc,
            )
            raise
        if current is None:  # soft filter
            logger.info("%s filtered out reaction", standardizer.__class__.__name__)
            return None
    return current
960
+
961
+
962
def safe_standardize(
    item: str | ReactionContainer,
    standardizers: Sequence,
) -> Tuple[ReactionContainer, bool]:
    """Standardize one reaction without letting standardization errors escape.

    Args:
        item: Reaction container or reaction SMILES string.
        standardizers: Ordered standardizer instances to apply.

    Returns:
        ``(reaction, ok)``: the standardized reaction and True on success,
        or the original (parsed) reaction and False when it was filtered
        out or any standardizer failed.

    Raises:
        Whatever ``smiles_cgrtools`` raises for an unparseable SMILES
        string; parsing happens before the guarded section.
    """
    # Parse exactly once, up-front. The previous version re-parsed inside
    # the exception handler, which raised again for unparseable input
    # (breaking the "always returns" contract) and did redundant work
    # for parseable input.
    reaction = item if isinstance(item, ReactionContainer) else smiles_cgrtools(item)
    try:
        std = standardize_reaction(reaction, standardizers)
    except Exception as exc:  # noqa: BLE001
        logger.debug("standardization failed: %s", exc, exc_info=True)
        return reaction, False
    if std is None:
        return reaction, False  # filtered → keep original
    return std, True
983
+
984
+
985
def _process_batch(
    batch: Sequence[str | ReactionContainer],
    standardizers: Sequence,
) -> Tuple[List[ReactionContainer], int]:
    """Standardize one batch serially; return reactions and success count."""
    processed: List[ReactionContainer] = []
    success_count = 0
    for entry in batch:
        reaction, succeeded = safe_standardize(entry, standardizers)
        processed.append(reaction)
        if succeeded:
            success_count += 1
    return processed, success_count
996
+
997
+
998
@ray.remote
def process_batch_remote(
    batch: Sequence[str | ReactionContainer],
    std_param: ray.ObjectRef,  # <-- receives a ref
    log_file_path: str | Path | None = None,
) -> Tuple[List[ReactionContainer], int]:
    """Ray task: standardize one batch of reactions inside a worker process.

    :param batch: Reaction SMILES strings or parsed ReactionContainer objects.
    :param std_param: Either a ``ray.ObjectRef`` pointing at the shared list
        of standardizers (broadcast once with ``ray.put``) or the plain list.
    :param log_file_path: Optional log file; when given, a FileHandler is
        attached (once) to this module's logger inside the worker.
    :return: ``(reactions, n_standardized)`` as produced by ``_process_batch``.
    """
    # Ray keeps a local cache of fetched objects, so the list is
    # deserialised only once per worker process, not once per task.
    if isinstance(std_param, ray.ObjectRef):  # handle? get it
        standardizers = ray.get(std_param)  # • O(once)
    else:  # plain list? use as is
        standardizers = std_param

    # --- Worker-specific logging setup ---
    worker_logger = logging.getLogger("synplan.chem.data.standardizing")
    if log_file_path:
        log_file_path = Path(log_file_path)  # Ensure it's a Path object
        # Check if a handler for this file already exists for this logger
        # (a worker process may execute many tasks — avoid stacking handlers).
        handler_exists = any(
            isinstance(h, logging.FileHandler) and Path(h.baseFilename) == log_file_path
            for h in worker_logger.handlers
        )
        if not handler_exists:
            try:
                fh = logging.FileHandler(log_file_path, encoding="utf-8")
                # Use a simple format for worker logs, or match driver's format
                formatter = logging.Formatter(
                    "%(asctime)s | %(name)s (worker) | %(levelname)-8s | %(message)s",
                    datefmt="%Y-%m-%d %H:%M:%S",
                )
                fh.setFormatter(formatter)
                fh.setLevel(logging.INFO)  # Or DEBUG, or use worker_log_level if passed
                worker_logger.addHandler(fh)
                worker_logger.setLevel(
                    logging.INFO
                )  # Ensure logger passes messages to handler
                worker_logger.propagate = (
                    False  # Avoid double logging if driver also logs
                )
                # Optional: Log that the handler was added
                # worker_logger.info(f"Worker process attached file handler: {log_file_path}")
            except Exception as e:
                # Log error if handler creation fails (e.g., permissions)
                logging.error(
                    f"Worker failed to create file handler {log_file_path}: {e}"
                )

    return _process_batch(batch, standardizers)
1046
+
1047
+
1048
def chunked(iterable: Iterable, size: int):
    """Yield consecutive lists of at most ``size`` items from ``iterable``.

    The final chunk may be shorter than ``size``; an empty iterable yields
    nothing.
    """
    buffer = []
    for element in iterable:
        buffer.append(element)
        if len(buffer) < size:
            continue
        yield buffer
        buffer = []
    if buffer:
        yield buffer
1057
+
1058
+
1059
def standardize_reactions_from_file(
    config: "ReactionStandardizationConfig",
    input_reaction_data_path: str | Path,
    standardized_reaction_data_path: str | Path = "reaction_data_standardized.smi",
    *,
    num_cpus: int = 1,
    batch_size: int = 1_000,  # larger batches amortise overhead
    silent: bool = True,
    max_pending_factor: int = 4,  # tasks in flight = factor × CPUs
    worker_log_level: int | str = logging.WARNING,
    log_file_path: str | Path | None = None,
) -> None:
    """
    Reads reactions, standardises them in parallel with Ray, writes results.

    The function keeps at most `max_pending_factor * num_cpus` Ray tasks in
    flight to avoid flooding the scheduler and blowing up the object store.
    Standardisers are broadcast once with `ray.put`, removing per‑task
    pickling cost. All other logic is unchanged.

    Args:
        config: Configuration object for standardizers.
        input_reaction_data_path: Path to the input reaction data file.
        standardized_reaction_data_path: Path to save the standardized reactions.
        num_cpus: Number of CPU cores to use for parallel processing.
        batch_size: Number of reactions to process in each batch.
        silent: If True, suppress the progress bar.
        max_pending_factor: Controls the number of pending Ray tasks.
        worker_log_level: Logging level for Ray workers (e.g., logging.INFO, logging.WARNING).
        log_file_path: Path to the log file for workers to write to.
    """
    output_path = Path(standardized_reaction_data_path)
    standardizers = config.create_standardizers()

    logger.info(
        "Standardizers: %s",
        ", ".join(s.__class__.__name__ for s in standardizers),
    )

    # ----------------------- Ray initialisation -----------------------
    if num_cpus > 1:
        if not ray.is_initialized():
            ray.init(
                num_cpus=num_cpus,
                ignore_reinit_error=True,
                logging_level=worker_log_level,
                log_to_driver=False,
            )

        DEDUP_NAME = "duplicate_rxn_actor"

        # NOTE(review): dedup_actor is created (or re-attached) here but is
        # never referenced again in this function — presumably the detached
        # actor is consumed elsewhere; confirm before removing.
        try:
            dedup_actor = ray.get_actor(DEDUP_NAME)  # already running?
        except ValueError:
            dedup_actor = DedupActor.options(
                name=DEDUP_NAME, lifetime="detached"  # survives driver exit
            ).remote()

    std_ref: ray.ObjectRef | None = None
    if num_cpus > 1 and std_ref is None:  # broadcast once
        std_ref = ray.put(standardizers)

    max_pending = max_pending_factor * num_cpus
    pending: Dict[ray.ObjectRef, None] = {}  # dict used as an ordered set of refs

    n_processed = n_std = 0
    bar = tqdm(
        total=0,
        unit="rxn",
        desc="Standardising",
        disable=silent,
        dynamic_ncols=True,
    )

    # ------------------------ Helper function ------------------------
    def _flush(ref: ray.ObjectRef, write_fn) -> None:
        """Fetch finished task, write its results, update counters & bar."""
        nonlocal n_processed, n_std
        res, ok = ray.get(ref)
        write_fn(res)
        bar.update(len(res))
        n_processed += len(res)
        n_std += ok

    # ----------------------------- I/O -------------------------------
    with ReactionReader(input_reaction_data_path) as reader, ReactionWriter(
        output_path
    ) as writer:

        write_fn = lambda reactions: [writer.write(r) for r in reactions]

        # --------------------- Main read/compute loop -----------------
        for chunk in chunked(reader, batch_size):
            # tqdm total grows as chunks stream in (total size unknown upfront)
            bar.total += len(chunk)
            bar.refresh()

            if num_cpus > 1:
                # ---------- back‑pressure: keep ≤ max_pending ----------
                while len(pending) >= max_pending:
                    done, _ = ray.wait(list(pending), num_returns=1)
                    _flush(done[0], write_fn)
                    pending.pop(done[0], None)

                # ----------- schedule new task -------------------------
                ref = process_batch_remote.remote(chunk, std_ref, log_file_path)
                pending[ref] = None
            else:
                # --------------- serial fall‑back ----------------------
                res, ok = _process_batch(chunk, standardizers)
                write_fn(res)
                bar.update(len(res))
                n_processed += len(res)
                n_std += ok

        # ------------------ Drain remaining Ray tasks -----------------
        # (pending stays empty in the serial path, so this loop is a no-op then)
        while pending:
            done, _ = ray.wait(list(pending), num_returns=1)
            _flush(done[0], write_fn)
            pending.pop(done[0], None)

    bar.close()
    # NOTE(review): shutdown is called even when Ray was never initialised
    # (num_cpus == 1); ray.shutdown() tolerates that, but confirm intent.
    ray.shutdown()

    logger.info(
        "Finished: processed %d, standardised %d, filtered %d",
        n_processed,
        n_std,
        n_processed - n_std,
    )
synplan/chem/precursor.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Module containing a class Precursor that represents a precursor (extend molecule object) in
2
+ the search tree."""
3
+
4
+ from typing import Set
5
+
6
+ from CGRtools.containers import MoleculeContainer
7
+
8
+ from synplan.chem.utils import safe_canonicalization
9
+
10
+
11
class Precursor:
    """Precursor class is used to extend the molecule behavior needed for
    interaction with a tree in MCTS.

    Wraps a MoleculeContainer and adds hashing/equality plus the
    building-block lookup used by the search tree.
    """

    def __init__(self, molecule: "MoleculeContainer", canonicalize: bool = True):
        """Initialize a Precursor with a molecule container.

        :param molecule: A molecule.
        :param canonicalize: If True, the molecule is canonicalized first.
        """
        self.molecule = safe_canonicalization(molecule) if canonicalize else molecule
        # Precursors this one was derived from during the tree search.
        self.prev_precursors: list = []

    def __len__(self) -> int:
        """Return the number of atoms in the Precursor."""
        return len(self.molecule)

    def __hash__(self) -> int:
        # Fixed: annotation was `-> hash` (the builtin function, not a type).
        """Return the hash value of the underlying molecule."""
        return hash(self.molecule)

    def __str__(self) -> str:
        """Return a SMILES of the Precursor."""
        return str(self.molecule)

    def __repr__(self) -> str:
        """Return a SMILES of the Precursor."""
        return str(self.molecule)

    def __eq__(self, other: "Precursor") -> bool:
        """Check if the current Precursor equals another by molecule identity."""
        return self.molecule == other.molecule

    def is_building_block(self, bb_stock: Set[str], min_mol_size: int = 6) -> bool:
        """Check if the Precursor is a building block.

        :param bb_stock: The set of building blocks. Each building block is
            represented by a canonical SMILES.
        :param min_mol_size: If the size of the Precursor is equal or smaller
            than min_mol_size, it is automatically classified as a building
            block.
        :return: True if the Precursor is a building block.
        """
        if len(self.molecule) <= min_mol_size:
            return True
        return str(self.molecule) in bb_stock
56
+
57
+
58
def compose_precursors(
    precursors: list = None, exclude_small: bool = True, min_mol_size: int = 6
) -> MoleculeContainer:
    """
    Compose a list of precursors into a single molecule.

    The composed molecule is used for the prediction of synthesisability
    characterizing the possible success of the route including the nodes
    with the given precursors.

    :param precursors: The list of precursors to be composed.
    :param exclude_small: If True, precursors of size ``min_mol_size`` or
        smaller are excluded from the composition (unless all precursors are
        small, in which case the full list is kept).
    :param min_mol_size: The size threshold used with ``exclude_small``.
    :return: A composed molecule as a MoleculeContainer object, or None when
        no precursors are given.
    """
    # Explicit guard: the original crashed with TypeError on None and
    # silently returned None on an empty list.
    if not precursors:
        return None
    if len(precursors) == 1:
        return precursors[0].molecule

    if exclude_small:
        big_precursors = [
            precursor
            for precursor in precursors
            if len(precursor.molecule) > min_mol_size
        ]
        # Keep the original list if filtering would leave nothing to compose.
        if big_precursors:
            precursors = big_precursors

    composed = precursors[0].molecule.copy()
    for precursor in precursors[1:]:
        # Fresh mapping per molecule: old atom number -> number in `composed`.
        transition_mapping = {}
        for number, atom in precursor.molecule.atoms():
            transition_mapping[number] = composed.add_atom(atom.atomic_symbol)
        for atom_a, atom_b, bond in precursor.molecule.bonds():
            composed.add_bond(
                transition_mapping[atom_a], transition_mapping[atom_b], bond
            )
    return composed
synplan/chem/reaction.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Module containing classes and functions for manipulating reactions and reaction
2
+ rules."""
3
+
4
+ from typing import Any, Iterator, List, Optional
5
+
6
+ from CGRtools.containers import MoleculeContainer, ReactionContainer
7
+ from CGRtools.exceptions import InvalidAromaticRing
8
+ from CGRtools.reactor import Reactor
9
+
10
+
11
class Reaction(ReactionContainer):
    """Reaction class used for a general representation of reaction.

    Behaves exactly like :class:`CGRtools.containers.ReactionContainer`;
    the subclass exists as a project-level extension point. The former
    ``__init__`` override only delegated to ``super().__init__`` and was
    removed as a no-op.
    """
16
+
17
+
18
def add_small_mols(
    big_mol: MoleculeContainer, small_molecules: Optional[Any] = None
) -> List[MoleculeContainer]:
    """Merge each small molecule into a copy of the big molecule.

    :param big_mol: A molecule.
    :param small_molecules: A list of small molecules to be added to the
        molecule; when falsy, ``big_mol`` is returned unchanged in a list.
    :return: The connected components of the merged molecule, or
        ``[big_mol]`` when there is nothing to add.
    """
    if not small_molecules:
        return [big_mol]

    merged = big_mol.copy()
    for small_mol in small_molecules:
        # Map atom numbers of the small molecule onto fresh numbers in `merged`.
        mapping = {}
        for number, atom in small_mol.atoms():
            mapping[number] = merged.add_atom(atom.atomic_symbol)
        for left, right, bond in small_mol.bonds():
            merged.add_bond(mapping[left], mapping[right], bond)
    return merged.split()
47
+
48
+
49
def apply_reaction_rule(
    molecule: MoleculeContainer,
    reaction_rule: Reactor,
    sort_reactions: bool = False,
    top_reactions_num: int = 3,
    validate_products: bool = True,
    rebuild_with_cgr: bool = False,
) -> Iterator[List[MoleculeContainer,]]:
    """Applies a reaction rule to a given molecule.

    :param molecule: A molecule to which reaction rule will be applied.
    :param reaction_rule: A reaction rule to be applied.
    :param sort_reactions: If True, the generated reactions are ranked by the
        number of large (> 6 atom) products before the top-N cut; otherwise
        the first ``top_reactions_num`` reactions are taken as generated.
    :param top_reactions_num: The maximum amount of reactions after the application of
        reaction rule.
    :param validate_products: If True, validates the final products.
    :param rebuild_with_cgr: If True, the products are extracted from CGR decomposition.
    :return: An iterator yielding the products of reaction rule application.
        NOTE(review): when validation fails this generator yields None for the
        offending molecule and *still* yields the reactants list afterwards —
        callers appear to filter out None; confirm before changing.
    """

    reactants = add_small_mols(molecule, small_molecules=False)

    try:
        if sort_reactions:
            unsorted_reactions = list(reaction_rule(reactants))
            sorted_reactions = sorted(
                unsorted_reactions,
                key=lambda react: len(
                    list(filter(lambda mol: len(mol) > 6, react.products))
                ),
                reverse=True,
            )

            # take top-N reactions from reactor
            reactions = sorted_reactions[:top_reactions_num]
        else:
            # lazy path: stop pulling from the reactor once we have enough
            reactions = []
            for reaction in reaction_rule(reactants):
                reactions.append(reaction)
                if len(reactions) == top_reactions_num:
                    break
    except IndexError:
        # the reactor may fail on some inputs; treat as "no reactions"
        reactions = []

    for reaction in reactions:

        # temporary solution - incorrect leaving groups:
        # skip reactions where more atoms leave than remain in the products
        reactant_atom_nums = []
        for i in reaction.reactants:
            reactant_atom_nums.extend(i.atoms_numbers)
        product_atom_nums = []
        for i in reaction.products:
            product_atom_nums.extend(i.atoms_numbers)
        leaving_atom_nums = set(reactant_atom_nums) - set(product_atom_nums)
        if len(leaving_atom_nums) > len(product_atom_nums):
            continue

        # check reaction
        if rebuild_with_cgr:
            # rebuild product set from the CGR decomposition
            cgr = reaction.compose()
            reactants = cgr.decompose()[1].split()
        else:
            reactants = reaction.products  # reactants are products in retro reaction
        reactants = [mol for mol in reactants if len(mol) > 0]

        # validate products
        if validate_products:
            for mol in reactants:
                try:
                    # kekule/thiele round-trip detects broken aromatic systems
                    mol.kekule()
                    if mol.check_valence():
                        yield None
                    mol.thiele()
                except InvalidAromaticRing:
                    yield None

        yield reactants
synplan/chem/reaction_routes/__init__.py ADDED
File without changes
synplan/chem/reaction_routes/clustering.py ADDED
@@ -0,0 +1,859 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import defaultdict
2
+
3
+ from pathlib import Path
4
+ import pickle
5
+ import re
6
+
7
+ from CGRtools.containers import ReactionContainer, CGRContainer
8
+ from CGRtools.containers.bonds import DynamicBond
9
+
10
+ from synplan.chem.reaction_routes.leaving_groups import *
11
+ from synplan.chem.reaction_routes.visualisation import *
12
+ from synplan.chem.reaction_routes.route_cgr import *
13
+ from synplan.chem.reaction_routes.io import (
14
+ read_routes_csv,
15
+ read_routes_json,
16
+ make_dict,
17
+ make_json,
18
+ )
19
+ from synplan.utils.visualisation import (
20
+ routes_clustering_report,
21
+ routes_subclustering_report,
22
+ )
23
+
24
+
25
def run_cluster_cli(
    routes_file: str,
    cluster_results_dir: str,
    perform_subcluster: bool = False,
    subcluster_results_dir: Path = None,
):
    """
    Read routes from a CSV or JSON file, perform clustering, and optionally subclustering.

    Args:
        routes_file: Path to the input routes file (.csv or .json).
        cluster_results_dir: Directory where clustering results are stored.
        perform_subcluster: Whether to run subclustering on each cluster.
        subcluster_results_dir: Subdirectory for subclustering results (if enabled).

    Raises:
        ValueError: If the numeric index cannot be parsed from the filename,
            or the file extension is neither .csv nor .json.
    """
    import click  # local import: only needed for CLI echo output

    routes_file = Path(routes_file)
    # File names are expected to look like "<prefix>_<index>.<ext>".
    match = re.search(r"_(\d+)\.", routes_file.name)
    if not match:
        raise ValueError(f"Could not extract index from filename: {routes_file.name}")
    file_index = int(match.group(1))
    ext = routes_file.suffix.lower()
    if ext == ".csv":
        routes_dict = read_routes_csv(str(routes_file))
        routes_json = make_json(routes_dict)
    elif ext == ".json":
        routes_json = read_routes_json(str(routes_file))
        routes_dict = make_dict(routes_json)
    else:
        raise ValueError(f"Unsupported file type: {ext}")

    # Compose condensed graph representations
    route_cgrs = compose_all_route_cgrs(routes_dict)
    # fixed: dropped pointless f-string prefixes on constant messages (F541)
    click.echo("Generating RouteCGR")
    reduced_cgrs = compose_all_sb_cgrs(route_cgrs)
    click.echo("Generating ReducedRouteCGR")

    # Perform clustering
    click.echo("\nClustering")
    clusters = cluster_routes(reduced_cgrs, use_strat=False)

    click.echo(f"Total number of routes: {len(routes_dict)}")
    click.echo(f"Found number of clusters: {len(clusters)} ({list(clusters.keys())})")

    # Ensure output directory exists
    cluster_results_dir = Path(cluster_results_dir)
    cluster_results_dir.mkdir(parents=True, exist_ok=True)

    # Save clusters to pickle
    with open(cluster_results_dir / f"clusters_{file_index}.pickle", "wb") as f:
        pickle.dump(clusters, f)

    # Generate HTML reports for each cluster
    for idx in clusters:
        report_path = cluster_results_dir / f"{file_index}_cluster_{idx}.html"
        routes_clustering_report(
            routes_json, clusters, idx, reduced_cgrs, html_path=str(report_path)
        )

    # Optional subclustering (Under development)
    if perform_subcluster and subcluster_results_dir:
        click.echo("\nSubClustering")
        sub_dir = cluster_results_dir / subcluster_results_dir
        sub_dir.mkdir(parents=True, exist_ok=True)

        subclusters = subcluster_all_clusters(clusters, reduced_cgrs, route_cgrs)
        for cluster_idx, sub in subclusters.items():
            click.echo(f"Cluster {cluster_idx} has {len(sub)} subclusters")
            for sub_idx, subcluster in sub.items():
                subreport_path = (
                    sub_dir / f"{file_index}_subcluster_{cluster_idx}.{sub_idx}.html"
                )
                routes_subclustering_report(
                    routes_json,
                    subcluster,
                    cluster_idx,
                    sub_idx,
                    reduced_cgrs,
                    aam=False,
                    html_path=str(subreport_path),
                )
107
+
108
+
109
def cluster_route_from_csv(routes_file: str):
    """
    Read retrosynthetic routes from a CSV file and cluster them.

    Pipeline: load the routes, build a Condensed Graph of Reactions (RouteCGR)
    per route, reduce each to its strategic-bond form (ReducedRouteCGR), and
    group the routes by those reduced representations via ``cluster_routes``
    (called with ``use_strat=False``, i.e. grouping by the reduced CGR
    signature).

    Args:
        routes_file (str): Path to the CSV file with retrosynthetic routes.

    Returns:
        object: The clustering result as produced by ``cluster_routes``.
    """
    routes = read_routes_csv(routes_file)
    route_cgrs = compose_all_route_cgrs(routes)
    reduced_cgrs = compose_all_sb_cgrs(route_cgrs)
    return cluster_routes(reduced_cgrs, use_strat=False)
135
+
136
+
137
def cluster_route_from_json(routes_file: str):
    """
    Read retrosynthetic routes from a JSON file and cluster them.

    Same pipeline as :func:`cluster_route_from_csv`, but the routes are loaded
    from JSON and first converted into the internal dictionary format before
    the RouteCGR / ReducedRouteCGR construction and clustering steps.

    Args:
        routes_file (str): Path to the JSON file with retrosynthetic routes.

    Returns:
        object: The clustering result as produced by ``cluster_routes``.
    """
    loaded_json = read_routes_json(routes_file)
    routes = make_dict(loaded_json)
    route_cgrs = compose_all_route_cgrs(routes)
    reduced_cgrs = compose_all_sb_cgrs(route_cgrs)
    return cluster_routes(reduced_cgrs, use_strat=False)
163
+
164
+
165
def extract_strat_bonds(target_cgr: CGRContainer):
    """
    Collect strategic bonds from a CGRContainer.

    A bond is strategic when it is absent in the reactants
    (``bond.order is None``) but formed in the products
    (``bond.p_order is not None``).

    Args:
        target_cgr (CGRContainer): The CGR to inspect.

    Returns:
        list: Sorted list of unique ``(atom1, atom2)`` tuples (atom1 < atom2)
        representing the strategic bonds.
    """
    strategic = []
    seen = set()
    for atom_a, neighbours in target_cgr._bonds.items():
        for atom_b, bond in neighbours.items():
            # Adjacency stores each bond twice; keep one orientation only.
            if atom_a >= atom_b:
                continue
            if bond.order is not None or bond.p_order is None:
                continue
            key = (atom_a, atom_b)
            if key not in seen:
                seen.add(key)
                strategic.append(key)
    return sorted(strategic)
197
+
198
+
199
def cluster_routes(sb_cgrs: dict, use_strat=False):
    """
    Cluster route objects based on their strategic bonds or on the
    CGRContainer string signature (the latter does not avoid mapping effects).

    Args:
        sb_cgrs: Dictionary mapping node_id to sb_cgr objects.
        use_strat: When truthy, group by the tuple of strategic bonds;
            otherwise group by the CGR string signature.

    Returns:
        Dictionary with groups keyed by '{num_bonds}.{index}' containing
        'sb_cgr', 'node_ids', 'strat_bonds' and 'group_size'.
    """
    temp_groups = defaultdict(
        lambda: {"node_ids": [], "sb_cgr": None, "strat_bonds": None}
    )

    # 1. Initial grouping based on the content of strategic bonds
    for node_id, sb_cgr in sb_cgrs.items():
        strat_bonds_list = extract_strat_bonds(sb_cgr)
        # fixed: plain truthiness instead of the `use_strat == True` comparison
        if use_strat:
            group_key = tuple(strat_bonds_list)
        else:
            group_key = str(sb_cgr)

        if not temp_groups[group_key]["node_ids"]:  # First time seeing this group
            # Store the first CGR and its bond list as the representatives.
            temp_groups[group_key]["sb_cgr"] = sb_cgr
            temp_groups[group_key]["strat_bonds"] = strat_bonds_list

        temp_groups[group_key]["node_ids"].append(node_id)
        temp_groups[group_key]["node_ids"].sort()  # keep node_ids sorted

    for group_data in temp_groups.values():
        group_data["group_size"] = len(group_data["node_ids"])

    # 2. Format the output dictionary with desired keys '{num_bonds}.{index}'
    final_grouped_results = {}
    group_indices = defaultdict(int)  # index counter per bond count

    # Sorting by key length then by the key itself gives a deterministic
    # order for assigning the per-length indices.
    sorted_groups = sorted(
        temp_groups.items(), key=lambda item: (len(item[0]), item[0])
    )

    for _group_key, group_data in sorted_groups:
        num_bonds = len(group_data["strat_bonds"])
        group_indices[num_bonds] += 1  # 1-based index for this bond count
        final_key = f"{num_bonds}.{group_indices[num_bonds]}"
        final_grouped_results[final_key] = group_data

    return final_grouped_results
256
+
257
+
258
def lg_process_reset(lg_cgr: CGRContainer, atom_num: int):
    """
    Normalize bonds of an extracted leaving-group (X) fragment and flag the
    attachment atom as a radical.

    Every bond in ``lg_cgr`` that has a defined ``order`` but an undefined
    ``p_order`` is replaced by a ``DynamicBond`` carrying the same integer
    order on both sides; afterwards the atom at ``atom_num`` is marked as a
    radical.

    Parameters
    ----------
    lg_cgr : CGRContainer
        The CGR representing the isolated leaving-group fragment.
    atom_num : int
        Index of the attachment atom to mark as a radical.

    Returns
    -------
    CGRContainer
        The modified ``lg_cgr``.
    """
    # Snapshot the adjacency (outer once, inner per-atom) before mutating it.
    for atom_a, neighbours in list(lg_cgr._bonds.items()):
        for atom_b, bond in list(neighbours.items()):
            if bond.p_order is not None or bond.order is None:
                continue
            same_order = int(bond.order)
            lg_cgr.delete_bond(atom_a, atom_b)
            lg_cgr.add_bond(atom_a, atom_b, DynamicBond(same_order, same_order))
    lg_cgr._atoms[atom_num].is_radical = True
    return lg_cgr
289
+
290
+
291
def lg_replacer(route_cgr: CGRContainer):
    """
    Extract dynamic leaving-groups from a CGR and mark attachment points.

    Scans the input CGRContainer for bonds lacking explicit p_order (i.e., leaving-group attachments),
    severs those bonds, captures each leaving-group as its own CGRContainer, and inserts DynamicX
    markers at the attachment sites. Finally, reindexes the markers to ensure unique labels.

    Parameters
    ----------
    route_cgr : CGRContainer
        A CGR representing the full synthetic route.

    Returns
    -------
    synthon_cgr : CGRContainer
        The core synthon CGR with DynamicX atoms marking each former leaving-group site.
    lg_groups : dict[int, tuple[CGRContainer, int]]
        Mapping from each marker label to a tuple of:
        - the extracted leaving-group CGRContainer
        - the atom index where it was attached.
    """
    lg_groups = {}

    # Work on the first connected component (the product side of the route CGR).
    cgr_prods = [route_cgr.substructure(c) for c in route_cgr.connected_components]
    target_cgr = cgr_prods[0]

    # Snapshot of the adjacency before the destructive edits below.
    bond_items = list(target_cgr._bonds.items())
    reaction = ReactionContainer.from_cgr(target_cgr)
    target_mol = reaction.products[0]
    # Atom indices above this value belong to leaving groups, not the target.
    max_in_target_mol = max(target_mol._atoms)

    k = 1  # running marker label for extracted leaving groups
    atom_nums = []
    checked_atoms = set()  # bond pairs already processed (stored sorted)

    for atom1, bond_set in bond_items:
        bond_set_items = list(bond_set.items())
        for atom2, bond in bond_set_items:
            # A broken-only bond (order set, p_order None) marks an attachment.
            if bond.p_order is None and bond.order is not None and tuple(sorted([atom1, atom2])) not in checked_atoms:
                if atom1 <= max_in_target_mol:
                    lg = DynamicX()
                    lg.mark = k
                    lg.isotope = k
                    order = bond.order
                    p_order = bond.p_order
                    # Sever the attachment; the fragment splits off as a new component.
                    target_cgr.delete_bond(atom1, atom2)
                    lg_cgrs = [
                        target_cgr.substructure(c)
                        for c in target_cgr.connected_components
                    ]
                    checked_atoms.add(tuple(sorted([atom1, atom2])))
                    if len(lg_cgrs) == 2:
                        lg_cgr = lg_cgrs[1]
                        lg_cgr = lg_process_reset(lg_cgr, atom2)
                        lg_cgr.clean2d()
                    else:
                        # Deleting the bond did not split the graph (ring bond);
                        # nothing to extract here.
                        continue
                    lg_groups[k] = (lg_cgr, atom2)
                    # Keep only the core component, then reinsert a DynamicX
                    # placeholder at the old attachment position.
                    target_cgr = [
                        target_cgr.substructure(c)
                        for c in target_cgr.connected_components
                    ][0]
                    target_cgr.add_atom(lg, atom2)
                    # NOTE(review): `p_order == None` should be `is None`;
                    # order 4 (aromatic) is downgraded to a single bond here —
                    # confirm this is the intended placeholder bond order.
                    if order == 4 and p_order == None:
                        order = 1
                    target_cgr.add_bond(atom1, atom2, DynamicBond(order, p_order))
                    target_cgr = [
                        target_cgr.substructure(c)
                        for c in target_cgr.connected_components
                    ][0]
                    k += 1
                    atom_nums.append(atom2)

    synthon_cgr = [target_cgr.substructure(c) for c in target_cgr.connected_components][
        0
    ]
    reaction = ReactionContainer.from_cgr(synthon_cgr)
    reactants = reaction.reactants

    # Relabel markers in reactant order so labels are unique and sequential.
    atom_mark_map = {}  # To map atom numbers to their new marks
    g = 1
    for n, r in enumerate(reactants):
        for atom_num in atom_nums:
            if atom_num in r._atoms:
                synthon_cgr._atoms[atom_num].mark = g
                atom_mark_map[atom_num] = g
                g += 1

    # Rekey lg_groups by the new marks; groups whose attachment atom did not
    # receive a new mark are dropped.
    new_lg_groups = {}
    for original_mark in lg_groups:
        cgr_obj, a_num = lg_groups[original_mark]
        new_mark = atom_mark_map.get(a_num)
        if new_mark is not None:
            new_lg_groups[new_mark] = (cgr_obj, a_num)
    lg_groups = new_lg_groups

    return synthon_cgr, lg_groups
389
+
390
+
391
def lg_reaction_replacer(
    synthon_reaction: ReactionContainer, lg_groups: dict, max_in_target_mol: int
):
    """
    Swap leaving-group placeholder atoms for marked X atoms in reactants.

    Walks every reactant of ``synthon_reaction``; any atom whose index
    exceeds ``max_in_target_mol`` is treated as a placeholder.  When that
    index matches the attachment point recorded in ``lg_groups``, the
    placeholder is replaced by a ``MarkedAt`` atom carrying the
    leaving-group label, and the original bond is re-created.

    Parameters
    ----------
    synthon_reaction : ReactionContainer
        Reaction containing reactants with X placeholders.
    lg_groups : dict[int, tuple[CGRContainer, int]]
        Mapping from X label to (X CGR, attachment atom index).
    max_in_target_mol : int
        Highest atom index of the core product; any atom index above this
        is a placeholder.

    Returns
    -------
    List[Molecule]
        Reactant molecules with ``MarkedAt`` atoms reinserted at X
        attachment sites.
    """
    replaced = []
    for mol in synthon_reaction.reactants:
        for idx in list(mol._atoms.keys()):
            if idx <= max_in_target_mol:
                continue  # core-product atom, not a placeholder
            for label, entry in lg_groups.items():
                marker = MarkedAt()
                if idx != entry[1]:
                    continue
                marker.mark = label
                marker.isotope = label
                # Placeholder atoms have a single neighbour: the attachment point.
                partner = next(iter(mol._bonds[idx]))
                saved_bond = mol._bonds[idx][partner]
                mol.delete_bond(partner, idx)
                mol.delete_atom(idx)
                mol.add_atom(marker, idx)
                mol.add_bond(partner, idx, saved_bond)
        replaced.append(mol)
    return replaced
434
+
435
+
436
class SubclusterError(Exception):
    """Signals that ``subcluster_one_cluster`` could not finish for a cluster."""
438
+
439
+
440
def subcluster_one_cluster(group, sb_cgrs_dict, route_cgrs_dict):
    """
    Generate synthon data for each route in a single cluster.

    For each route (node ID) in ``group['node_ids']``, replaces RouteCGRs with
    SynthonCGR, builds ReactionContainers before and after X replacement,
    and collects relevant data.

    Parameters
    ----------
    group : dict
        Must include ``'node_ids'``, a list of node identifiers.
    sb_cgrs_dict : dict
        Maps node IDs to their ReducedRouteCGR.
    route_cgrs_dict : dict
        Maps node IDs to their RouteCGR.

    Returns
    -------
    dict
        Maps each ``node_id`` to a tuple
        ``(sb_cgr, original_reaction, synthon_cgr, new_reaction, lg_groups)``.

    Raises
    ------
    SubclusterError
        If ``node_ids`` is not a list/tuple, or any step (LG replacement,
        CGR parsing, X reaction replacement) fails for a node.
    """
    node_ids = group.get("node_ids")
    if not isinstance(node_ids, (list, tuple)):
        raise SubclusterError(
            f"'node_ids' must be a list or tuple, got {type(node_ids).__name__}"
        )

    result = {}
    for node_id in node_ids:
        sb_cgr = sb_cgrs_dict[node_id]
        route_cgr = route_cgrs_dict[node_id]

        # 1) Replace leaving groups in RouteCGR
        try:
            synthon_cgr, lg_groups = lg_replacer(route_cgr)
        except (KeyError, ValueError) as e:
            raise SubclusterError(f"LG replacement failed for node {node_id}") from e

        # 2) Build ReactionContainer for Abstracted RouteCGR
        # BUG FIX: the original used a bare `except:` without binding `e`,
        # so the `raise ... from e` below crashed with NameError instead of
        # reporting the real failure.
        try:
            synthon_rxn = ReactionContainer.from_cgr(synthon_cgr)
        except Exception as e:
            raise SubclusterError(
                f"Failed to parse synthon CGR for node {node_id}"
            ) from e

        # 3) Prepare for X-based reaction replacement
        try:
            old_reactants = synthon_rxn.reactants
            target_mol = synthon_rxn.products[0]
            max_atom_idx = max(target_mol._atoms)
            new_reactants = lg_reaction_replacer(synthon_rxn, lg_groups, max_atom_idx)
            new_rxn = ReactionContainer(reactants=new_reactants, products=[target_mol])
        except (IndexError, TypeError) as e:
            raise SubclusterError(
                f"Leaving group (X) reaction replacement failed for node {node_id}"
            ) from e

        result[node_id] = (
            sb_cgr,
            ReactionContainer(reactants=old_reactants, products=[target_mol]),
            synthon_cgr,
            new_rxn,
            lg_groups,
        )

    return result
513
+
514
+
515
def group_nodes_by_synthon_detail(data_dict: dict):
    """
    Group nodes that share identical synthon CGRs and reactions.

    Nodes whose result tuples agree on the first four elements
    (sb_cgr, unlabeled_reaction, synthon_cgr, synthon_reaction) are merged
    into one group; the fifth element (per-node leaving-group data) is kept
    per node under ``nodes_data``.

    Args:
        data_dict: Dictionary ``{node_id: (sb_cgr, unlabeled_reaction,
            synthon_cgr, synthon_reaction, lg_groups, ...)}``.

    Returns:
        Dictionary ``{group_index: {'sb_cgr': ..., 'unlabeled_reaction': ...,
        'synthon_cgr': ..., 'synthon_reaction': ...,
        'nodes_data': {node_id: lg_groups, ...}, 'post_processed': False}}``.
    """
    temp_groups = defaultdict(list)

    for node_id, result_list in data_dict.items():
        if len(result_list) < 4:
            # BUG FIX: pad to a 4-tuple.  The original built a 2-tuple key
            # here, which crashed the 4-way unpacking below with ValueError.
            group_key = (result_list[0], None, None, None)
        else:
            group_key = tuple(result_list[:4])

        # BUG FIX: hashability is only checked when the key is inserted into
        # the dict; the original wrapped tuple construction (which cannot
        # raise TypeError), so the warning branch was unreachable.
        try:
            temp_groups[group_key].append(node_id)
        except TypeError:
            print(
                f"Warning: Skipping node {node_id} because reaction data is not hashable: {type(result_list[1])}"
            )
            continue

    # Format the output dictionary with sequential integer keys and include
    # the node-specific data (result[4]) in a sub-dictionary.
    final_grouped_results = {}
    group_index = 1

    # Sort by grouped node-id lists for deterministic group numbering.
    sorted_temp_groups = sorted(temp_groups.items(), key=lambda item: item[1])
    for group_key, node_ids in sorted_temp_groups:
        sb_cgr, unlabeled_reaction, synthon_cgr, synthon_reaction = group_key
        nodes_data_dict = {}

        for node_id in sorted(node_ids):  # sort node IDs for consistent dict order
            original_result = data_dict.get(node_id, [])
            # Element 4 holds the per-node leaving-group table, if present.
            node_specific_data = (
                original_result[4] if len(original_result) > 4 else None
            )
            nodes_data_dict[node_id] = node_specific_data

        final_grouped_results[group_index] = {
            "sb_cgr": sb_cgr,
            "unlabeled_reaction": unlabeled_reaction,
            "synthon_cgr": synthon_cgr,
            "synthon_reaction": synthon_reaction,
            "nodes_data": nodes_data_dict,
            "post_processed": False,
        }
        group_index += 1

    return final_grouped_results
581
+
582
+
583
def subcluster_all_clusters(groups, sb_cgrs_dict, route_cgrs_dict):
    """
    Subdivide every reaction cluster into synthon-based subgroups.

    Applies ``subcluster_one_cluster`` to each cluster in ``groups`` and
    organizes the resulting synthons with ``group_nodes_by_synthon_detail``.

    Parameters
    ----------
    groups : dict
        Mapping of cluster indices to cluster data.
    sb_cgrs_dict : dict
        Dictionary of ReducedRouteCGRs keyed by node ID.
    route_cgrs_dict : dict
        Dictionary of RouteCGRs keyed by node ID.

    Returns
    -------
    dict or None
        Mapping of each cluster index to its subgroups dict, or None if a
        cluster fails to subcluster.
    """
    subgrouped = {}
    for cluster_idx, cluster in groups.items():
        synthons = subcluster_one_cluster(cluster, sb_cgrs_dict, route_cgrs_dict)
        # Defensive: subcluster_one_cluster raises on failure, but the
        # original None contract is preserved here.
        if synthons is None:
            return None
        subgrouped[cluster_idx] = group_nodes_by_synthon_detail(synthons)
    return subgrouped
614
+
615
+
616
def all_lg_collect(subgroup):
    """
    Gather all distinct leaving-group containers per marker index.

    Scans ``subgroup['nodes_data']`` and, for every marker index found in
    any route, collects the leaving-group containers attached at that
    index.  Duplicates are filtered by string representation.

    Parameters
    ----------
    subgroup : dict
        Must contain 'nodes_data', a dict mapping pathway keys to
        dicts of {node_index: (CGRContainer, ...)}.

    Returns
    -------
    dict[int, list[CGRContainer]]
        For each marker index, the list of unique CGRContainer objects.
    """
    route_tables = list(subgroup["nodes_data"].values())

    # Phase 1: discover every marker index used by any route.
    indices = set()
    for table in route_tables:
        indices |= set(table)

    unique = {idx: [] for idx in indices}
    seen_repr = {idx: set() for idx in indices}

    # Phase 2: collect containers, deduplicating by their string form.
    for table in route_tables:
        for idx, entry in table.items():
            container = entry[0]
            key = str(container)
            if key in seen_repr[idx]:
                continue
            seen_repr[idx].add(key)
            unique[idx].append(container)
    return unique
652
+
653
+
654
def replace_leaving_groups_in_synthon(subgroup, to_remove):  # Under development
    """
    Replace specified leaving groups (LG) in a synthon CGR with new fragments and return the updated CGR
    along with a mapping from adjusted LG marks to their atom indices.

    Parameters:
        subgroup (dict): Must contain:
            - 'synthon_cgr': the CGR object representing the synthon graph
            - 'nodes_data': mapping of node indices to LG replacement data
        to_remove (List[int]): List of LG marks to remove and replace.

    Returns:
        Tuple[CGR, Dict[int, int]]:
            - The updated CGR with replacements
            - A dict mapping new LG marks to their atom indices in the updated CGR
    """
    # Extract the original CGR and leaving group replacement table.
    # NOTE(review): uses an arbitrary route's table (first value of
    # nodes_data) as the source of replacement fragments — TODO confirm all
    # routes agree for the marks in `to_remove`.
    original_cgr = subgroup["synthon_cgr"]
    lg_table = next(iter(subgroup["nodes_data"].values()))

    # NOTE(review): this aliases (does not copy) the stored CGR; later
    # `union` calls rebind the name, but in-place mutations before that
    # affect subgroup['synthon_cgr'] as well — confirm this is intended.
    updated_cgr = original_cgr

    removed_count = 0
    new_lgs = {}

    # Iterate through all atoms (index, atom_obj) in the CGR
    for atom_idx, atom_obj in list(updated_cgr.atoms()):
        # Skip non-X atoms
        if atom_obj.__class__.__name__ != "DynamicX":
            continue

        current_mark = atom_obj.mark
        if current_mark in to_remove:
            # Remove old LG (X): delete bond and atom
            neighbors = list(updated_cgr._bonds[atom_idx].keys())
            if neighbors:
                neighbor_idx = neighbors[0]
                bond = updated_cgr._bonds[atom_idx][neighbor_idx]
                updated_cgr.delete_bond(atom_idx, neighbor_idx)
            updated_cgr.delete_atom(atom_idx)

            # Attach new LG(X) fragment from the table.
            # NOTE(review): if `neighbors` was empty above, `neighbor_idx`
            # and `bond` are unbound here and add_bond raises NameError —
            # this assumes every X atom has exactly one neighbour; confirm.
            # NOTE(review): add_bond at atom_idx assumes `union` numbers the
            # fragment's attachment atom identically to the removed X atom —
            # TODO confirm against CGRtools union semantics.
            lg_fragment = lg_table[current_mark][0]
            updated_cgr = updated_cgr.union(lg_fragment)
            # Reset radical flag on the new atom and restore the bond
            updated_cgr._atoms[atom_idx].is_radical = False
            updated_cgr.add_bond(atom_idx, neighbor_idx, bond)

            removed_count += 1
        else:
            # Adjust the marks of remaining LGs to account for removed ones
            atom_obj.mark -= removed_count
            new_lgs[atom_obj.mark] = atom_idx

    # Reorder atoms dict and update 2D coordinates for depiction
    updated_cgr._atoms = dict(sorted(updated_cgr._atoms.items()))

    return updated_cgr, new_lgs
712
+
713
+
714
def new_lg_reaction_replacer(synthon_reaction, new_lgs, max_in_target_mol):
    """
    Reinsert marked leaving-group atoms at placeholder positions.

    For every reactant in ``synthon_reaction``, any atom index greater than
    ``max_in_target_mol`` is a leaving-group placeholder.  When that index
    matches an attachment point recorded in ``new_lgs``, the placeholder is
    swapped for a ``MarkedAt`` atom labelled with the leaving-group key,
    and the bond to its single neighbour is restored.

    Parameters
    ----------
    synthon_reaction : ReactionContainer
        Reaction whose reactants contain dummy atoms (by index) marking
        where leaving-groups were removed.
    new_lgs : dict[int, int]
        Mapping from leaving-group label to the atom index that should be
        replaced.
    max_in_target_mol : int
        Highest atom index used by the core product; larger indices are
        placeholders.

    Returns
    -------
    List[Molecule]
        Reactants with ``MarkedAt`` atoms (``.mark``/``.isotope`` set to the
        leaving-group label) substituted in and original bonds reattached.
    """
    rebuilt = []
    for mol in synthon_reaction.reactants:
        for idx in list(mol._atoms.keys()):
            if idx <= max_in_target_mol:
                continue  # belongs to the core product
            for label, attach_idx in new_lgs.items():
                marker = MarkedAt()
                if idx != attach_idx:
                    continue
                marker.mark = label
                marker.isotope = label
                # Placeholders are terminal: exactly one bonded neighbour.
                partner = next(iter(mol._bonds[idx]))
                saved_bond = mol._bonds[idx][partner]
                mol.delete_bond(partner, idx)
                mol.delete_atom(idx)
                mol.add_atom(marker, idx)
                mol.add_bond(partner, idx, saved_bond)
        rebuilt.append(mol)

    return rebuilt
762
+
763
+
764
def post_process_subgroup(
    subgroup,
):  # Under development: Error in replace_leaving_groups_in_synthon , 'cuz synthon_reaction.clean2d crashes
    """
    Drop leaving-groups common to all pathways and rebuild a minimal synthon.

    Scans the subgroup for leaving-groups present in every route, removes those
    from the CGR, re-assembles a clean ReactionContainer with the original core,
    updates `nodes_data`, and flags the dict as processed.

    Parameters
    ----------
    subgroup : dict
        One subgroup dict as produced by group_nodes_by_synthon_detail
        (keys 'synthon_cgr', 'nodes_data', 'post_processed', ...). If
        already post_processed, returns immediately.

    Returns
    -------
    dict
        The same dict, mutated in place, now with:
        - `'synthon_reaction'`: cleaned ReactionContainer
        - `'nodes_data'`: filtered node table
        - `'post_processed'`: True
        - `'group_lgs'`: routes grouped by identical leaving-group sets
    """
    # Idempotence guard: never process a subgroup twice.
    if "post_processed" in subgroup.keys() and subgroup["post_processed"] == True:
        return subgroup
    result = all_lg_collect(subgroup)
    # to find constant lg that need to be removed:
    # an index with exactly one distinct container means every route uses
    # the same leaving group there.
    to_remove = [ind for ind, cgr_set in result.items() if len(cgr_set) == 1]
    new_synthon_cgr, new_lgs = replace_leaving_groups_in_synthon(subgroup, to_remove)
    synthon_reaction = ReactionContainer.from_cgr(new_synthon_cgr)
    synthon_reaction.clean2d()
    # Rebuild a second time to keep an untouched copy of the reactants.
    old_reactants = ReactionContainer.from_cgr(new_synthon_cgr).reactants
    target_mol = synthon_reaction.products[0]  # TO DO: target_mol might be non 0
    max_in_target_mol = max(target_mol._atoms)
    new_reactants = new_lg_reaction_replacer(
        synthon_reaction, new_lgs, max_in_target_mol
    )
    new_synthon_reaction = ReactionContainer(
        reactants=new_reactants, products=[target_mol]
    )
    new_synthon_reaction.clean2d()
    subgroup["synthon_reaction"] = new_synthon_reaction
    # NOTE(review): remove_and_shift is not defined in this module's visible
    # code — presumably it renumbers the remaining LG indices after the
    # constant ones are dropped; confirm its import/definition.
    subgroup["nodes_data"] = remove_and_shift(subgroup["nodes_data"], to_remove)
    subgroup["post_processed"] = True
    subgroup["group_lgs"] = group_by_identical_values(subgroup["nodes_data"])
    return subgroup
812
+
813
+
814
def group_by_identical_values(nodes_data):  # Under development
    """
    Collapse routes whose leaving-group assignments are identical.

    Routes (outer keys) whose inner dictionaries carry the same sequence of
    leaving groups — comparing only the first element of each value tuple,
    ordered by subkey — are merged into a single entry.

    Args:
        nodes_data (dict): Maps outer keys to inner dictionaries; each inner
            dictionary maps subkeys to a tuple ``(value_obj, other_info)``.
            ``value_obj`` drives the grouping, ``other_info`` is ignored.
            Example: {'route_1': {'pos_a': (1, 'infoA'), 'pos_b': (2, 'infoB')}, ...}

    Returns:
        dict: Keys are tuples of the grouped outer keys; values map each
        subkey to the ``value_obj`` taken from the group's first outer key.
        Sorted descending by group size.
        Example: {('route_1', 'route_2'): {'pos_a': 1, 'pos_b': 2}, ...}
    """
    # Step 1: fingerprint each route by the ordered tuple of its value_objs.
    buckets = {}
    for route_id, positions in nodes_data.items():
        # Order by subkey so equal assignments produce equal fingerprints.
        ordered = sorted(positions.items(), key=lambda kv: kv[0])
        fingerprint = tuple(pair[1][0] for pair in ordered)
        buckets.setdefault(fingerprint, []).append(route_id)

    # Step 2: one representative value map per fingerprint bucket.
    collapsed = {}
    for route_ids in buckets.values():
        representative = nodes_data[route_ids[0]]
        collapsed[tuple(route_ids)] = {
            subkey: entry[0] for subkey, entry in representative.items()
        }

    # Largest groups first.
    return dict(sorted(collapsed.items(), key=lambda kv: len(kv[0]), reverse=True))
synplan/chem/reaction_routes/io.py ADDED
@@ -0,0 +1,286 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import csv
2
+ import json
3
+ import pickle
4
+ import os
5
+
6
+ from CGRtools import smiles as read_smiles
7
+ from synplan.mcts.tree import Tree
8
+
9
+
10
def make_dict(routes_json):
    """
    Convert JSON route trees into per-route dictionaries of reactions.

    Args:
        routes_json: Either a list of tree-dicts as produced by make_json(),
            or a dict mapping route indices to tree-dicts.

    Returns:
        dict: Maps each route index (int) to a sub-dict
        {step_id: ReactionContainer}, where step IDs run from the earliest
        reaction (0) up to the final (max), in leaf-to-root order.
    """
    # The original duplicated the whole traversal for the dict and list
    # cases; a single iterator unifies both shapes.
    if isinstance(routes_json, dict):
        items = routes_json.items()
    else:
        items = enumerate(routes_json)

    routes_dict = {}
    for route_idx, tree in items:
        rxn_list = []

        def _postorder(node):
            # First dive into any children, then record this reaction;
            # mol-nodes simply recurse (no record).
            for child in node.get("children", []):
                _postorder(child)
            if node["type"] == "reaction":
                rxn_list.append(read_smiles(node["smiles"]))

        # Collect all reactions in leaf -> root order.
        _postorder(tree)

        # Assign step IDs 0, 1, 2, ... in that order.
        routes_dict[int(route_idx)] = dict(enumerate(rxn_list))

    return routes_dict
59
+
60
+
61
def read_routes_json(file_path="routes.csv", to_dict=False):
    """Load routes from a JSON file, optionally converting via make_dict.

    NOTE(review): the default file name carries a ``.csv`` extension even
    though the content is JSON — kept for backward compatibility.
    """
    with open(file_path) as fh:
        loaded = json.load(fh)
    return make_dict(loaded) if to_dict else loaded
67
+
68
+
69
def read_routes_csv(file_path="routes.csv"):
    """
    Load routes from a CSV with columns: route_id, step_id, smiles, meta.

    Returns a nested dict mapping route_id (int) -> step_id (int) ->
    ReactionContainer.  The meta column is currently ignored.
    """
    routes = {}
    with open(file_path, newline="") as fh:
        for record in csv.DictReader(fh):
            rid = int(record["route_id"])
            sid = int(record["step_id"])
            # read_smiles parses the reaction SMILES into a container.
            routes.setdefault(rid, {})[sid] = read_smiles(record["smiles"])
    return routes
87
+
88
+
89
def make_json(routes_dict, keep_ids=True):
    """
    Convert routes into a nested JSON tree of reaction and molecule nodes.

    Args:
        routes_dict (dict[int, dict[int, Reaction]]): Mapping route IDs to steps (step_id -> Reaction).
        keep_ids (bool): If True, returns a dict mapping route IDs to trees;
            otherwise returns a list of route trees.  (The code below builds
            a dict when keep_ids is True — the original docstring stated
            the opposite.)

    Returns:
        dict or list: JSON-like tree(s) of routes.
    """
    # Prepare output container: dict keyed by route id, or a flat list.
    all_routes = {} if keep_ids else []

    for route_id, steps in routes_dict.items():
        if not steps:
            continue

        # Determine target molecule atoms from the final step of this route
        final_step = max(steps)
        target = steps[final_step].products[0]
        atom_nums = set(target._atoms.keys())

        # Precompute canonical SMILES and producer mapping for all products.
        # NOTE: kekule/implicify_hydrogens/thiele mutate each product in
        # place; str(prod) afterwards yields the canonicalised SMILES.
        prod_map = {}  # smiles -> list of step_ids
        for sid, rxn in steps.items():
            for prod in rxn.products:
                prod.kekule()
                prod.implicify_hydrogens()
                prod.thiele()
                s = str(prod)
                prod_map.setdefault(s, []).append(sid)

        def transform(mol):
            # Canonicalise a molecule in place and return its SMILES.
            mol.kekule()
            mol.implicify_hydrogens()
            mol.thiele()
            return str(mol)

        def build_mol_node(sid):
            """Find the product with any overlap to target atoms and recurse into its reaction."""
            rxn = steps[sid]
            for p in rxn.products:
                if atom_nums & set(p._atoms.keys()):
                    smiles = str(p)
                    return {
                        "type": "mol",
                        "smiles": smiles,
                        "children": [build_reaction_node(sid)],
                        "in_stock": False,
                    }
            # Shouldn't reach here if tree is consistent
            return None

        def build_reaction_node(sid):
            """Build reaction node and recurse into reactant molecule nodes."""
            rxn = steps[sid]
            node = {"type": "reaction", "smiles": format(rxn, "m"), "children": []}

            for react in rxn.reactants:
                r_smi = transform(react)
                # Look up any prior step producing this reactant
                prior = [ps for ps in prod_map.get(r_smi, []) if ps < sid]
                if prior:
                    # Reactant made by an earlier step: recurse into it.
                    node["children"].append(build_mol_node(max(prior)))
                else:
                    # Leaf reactant: treated as a purchasable building block.
                    node["children"].append(
                        {"type": "mol", "smiles": r_smi, "in_stock": True}
                    )

            return node

        # Build route tree and store
        tree = build_mol_node(final_step)
        if keep_ids:
            all_routes[int(route_id)] = tree
        else:
            all_routes.append(tree)

    return all_routes
169
+
170
+
171
def write_routes_json(routes_dict, file_path):
    """Serialize reaction routes to a pretty-printed JSON file."""
    payload = make_json(routes_dict)
    with open(file_path, "w") as fh:
        json.dump(payload, fh, indent=2)
176
+
177
+
178
def write_routes_csv(routes_dict, file_path="routes.csv"):
    """
    Write a nested routes dict to CSV.

    Input shape: { route_id: { step_id: reaction_obj, ... }, ... }.
    Columns are route_id, step_id, smiles, meta; smiles is
    format(reaction, 'm') and meta is left blank for now.  Routes and steps
    are sorted for deterministic output.
    """
    with open(file_path, "w", newline="") as out:
        writer = csv.writer(out)
        writer.writerow(["route_id", "step_id", "smiles", "meta"])
        for rid in sorted(routes_dict):
            steps = routes_dict[rid]
            for sid in sorted(steps):
                # meta column reserved for future use (reaction.meta).
                writer.writerow([rid, sid, format(steps[sid], "m"), ""])
197
+
198
+
199
class TreeWrapper:
    """Picklable wrapper around a search Tree.

    Strips non-picklable attributes (progress bar, neural networks) before
    serialization and restores the Tree on load.  Files are stored as
    ``<path>/tree_<mol_id>_<config>.pkl``.
    """

    def __init__(self, tree, mol_id=1, config=1, path="planning_results/forest"):
        """Initializes the TreeWrapper.

        Args:
            tree: The Tree object to wrap.
            mol_id: Molecule identifier used in the file name.
            config: Configuration identifier used in the file name.
            path: Directory where the pickle file is stored (created if missing).
        """
        self.tree = tree
        self.mol_id = mol_id
        self.config = config
        self.path = path
        # Ensure the directory exists before creating the filename
        os.makedirs(self.path, exist_ok=True)
        self.filename = os.path.join(self.path, f"tree_{mol_id}_{config}.pkl")

    def __getstate__(self):
        """Return a picklable state with non-serializable tree parts removed."""
        state = self.__dict__.copy()
        tree_state = self.tree.__dict__.copy()
        # Reset or remove non-pickleable attributes (e.g., _tqdm, policy_network, value_network)
        if "_tqdm" in tree_state:
            tree_state["_tqdm"] = True  # Reset to a simple flag
        for attr in ["policy_network", "value_network"]:
            if attr in tree_state:
                tree_state[attr] = None
        state["tree_state"] = tree_state
        del state["tree"]
        return state

    def __setstate__(self, state):
        """Rebuild the wrapped Tree from the saved state (networks stay None)."""
        tree_state = state.pop("tree_state")
        self.__dict__.update(state)
        # Bypass Tree.__init__ — the saved __dict__ carries the full state.
        new_tree = Tree.__new__(Tree)
        new_tree.__dict__.update(tree_state)
        self.tree = new_tree

    def save_tree(self):
        """Saves the TreeWrapper instance (including the tree state) to a file."""
        try:
            with open(self.filename, "wb") as f:
                pickle.dump(self, f)
            print(
                f"Tree wrapper for mol_id '{self.mol_id}', config '{self.config}' saved to '{self.filename}'."
            )
        except Exception as e:
            print(f"Error saving tree to {self.filename}: {e}")

    @classmethod
    def load_tree_from_id(cls, mol_id, config=1, path="planning_results/forest"):
        """
        Loads a Tree object from a saved file using mol_id and config.

        Args:
            mol_id: The molecule ID used for saving.
            config: The configuration used for saving.
            path: The directory where the file is located

        Returns:
            The loaded Tree object, or None if loading fails.
        """
        filename = os.path.join(path, f"tree_{mol_id}_{config}.pkl")
        # BUG FIX: these f-strings printed the literal "(unknown)" instead of
        # interpolating the computed file name.
        print(f"Attempting to load tree from: {filename}")
        try:
            # Ensure the 'Tree' class is defined in the current scope
            if "Tree" not in globals() and "Tree" not in locals():
                raise NameError(
                    "The 'Tree' class definition is required to load the object."
                )

            with open(filename, "rb") as f:
                loaded_wrapper = pickle.load(f)  # This implicitly calls __setstate__

            print(
                f"Tree object for mol_id '{mol_id}', config '{config}' successfully loaded from '{filename}'."
            )
            # The __setstate__ method already reconstructed the tree inside the wrapper
            return loaded_wrapper.tree

        except FileNotFoundError:
            print(f"Error: File not found at {filename}")
            return None
        except (pickle.UnpicklingError, EOFError) as e:
            print(
                f"Error: Could not unpickle file {filename}. It might be corrupted or empty. Details: {e}"
            )
            return None
        except NameError as e:
            print(f"Error during loading: {e}. Ensure 'Tree' class is defined.")
            return None
        except Exception as e:
            print(f"An unexpected error occurred loading tree from {filename}: {e}")
            return None
synplan/chem/reaction_routes/leaving_groups.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from CGRtools.periodictable import Core, At, DynamicElement
2
+ from typing import Optional
3
+
4
+
5
class Marked(Core):
    """Mixin for CGRtools atoms that adds user-settable mark/isotope bookkeeping.

    The mark labels which leaving group an atom belongs to; the isotope is
    mirrored to the same value so the label survives SMILES round-trips.
    """

    # NOTE: "__mark" is name-mangled to "_Marked__mark" inside this class body.
    __slots__ = "__mark", "_isotope"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.__mark = None
        self._isotope = 0  # Make sure this exists

    @property
    def mark(self):
        # Leaving-group label; None until explicitly assigned.
        return self.__mark

    @mark.setter
    def mark(self, mark):
        self.__mark = mark

    @property
    def isotope(self):
        return getattr(self, "_isotope", 0)  # Always returns int

    @isotope.setter
    def isotope(self, value):
        self._isotope = int(value)

    def __repr__(self):
        return f"{self.symbol}({self.isotope})"

    @property
    def atomic_symbol(self) -> str:
        # Assumes subclasses are named "Marked<Element>": "MarkedAt"[6:] == "At".
        return self.__class__.__name__[6:]

    @property
    def symbol(self) -> str:
        return "X"  # For human-readable representation

    def __len__(self):
        # Pure pass-through to Core.__len__; kept for explicitness.
        return super().__len__()
42
+
43
+
44
class MarkedAt(Marked, At):
    """Astatine-based placeholder atom rendered as "X".

    Combines the Marked mixin (mark/isotope bookkeeping) with CGRtools' At
    element so the atom is accepted wherever a periodic-table element is
    required, while displaying as the generic leaving-group symbol "X".
    """

    atomic_number = At.atomic_number

    @property
    def atomic_symbol(self):
        # Real element identity, used by CGRtools internals.
        return "At"

    @property
    def symbol(self):
        # Human-facing placeholder symbol.
        return "X"

    def __repr__(self):
        return f"X({self.isotope})"

    def __str__(self):
        return f"X({self.isotope})"

    def __hash__(self):
        # NOTE(review): __hash__ is defined here without a matching __eq__;
        # equality presumably comes from a base class — confirm both agree
        # on the (isotope, atomic_number, charge, is_radical) fields.
        return hash(
            (
                self.isotope,
                getattr(self, "atomic_number", 0),
                getattr(self, "charge", 0),
                getattr(self, "is_radical", False),
            )
        )
70
+
71
+
72
class DynamicX(DynamicElement):
    """Dynamic placeholder element ("X") used to mark leaving groups in CGRs.

    Reuses astatine's atomic number (85) so CGRtools treats it as a valid
    element, while rendering as the neutral symbol "X".
    """

    __slots__ = ("_mark", "_isotope")

    # Element-table attributes required by DynamicElement subclasses.
    atomic_number = 85
    mass = 0.0
    group = 0
    period = 0
    isotopes_distribution = list(range(20))
    atomic_radius = 0.5
    # NOTE(review): CGRtools elements normally expose a mapping of isotope
    # masses; a scalar 0 here may break mass-dependent code — TODO confirm.
    isotopes_masses = 0

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._isotope = None
        self._mark = None

    @property
    def mark(self):
        # Marker label assigned when the leaving group is extracted.
        return getattr(self, "_mark", None)

    @mark.setter
    def mark(self, value):
        self._mark = value

    @property
    def isotope(self):
        return getattr(self, "_isotope", None)

    @isotope.setter
    def isotope(self, value):
        self._isotope = value

    @property
    def symbol(self) -> str:
        return "X"

    def valence_rules(
        self, charge: int = 0, is_radical: bool = False, valence: int = 0
    ) -> tuple:
        # NOTE(review): all three branches return an empty tuple, i.e. the
        # placeholder imposes no valence constraints — the split is redundant.
        if charge == 0 and not is_radical and (valence == 1):
            return tuple()
        elif charge == 0 and not is_radical and valence == 0:
            return tuple()
        else:
            return tuple()

    def __repr__(self):
        return f"Dynamic{self.symbol}()"

    @property
    def p_charge(self) -> int:
        # Product-side attributes mirror the reactant side: the placeholder
        # is treated as unchanged across the reaction.
        return self.charge

    @property
    def p_is_radical(self) -> bool:
        return self.is_radical

    @property
    def p_hybridization(self) -> Optional[int]:
        return self.hybridization
synplan/chem/reaction_routes/route_cgr.py ADDED
@@ -0,0 +1,570 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from CGRtools.containers.bonds import DynamicBond
2
+ from CGRtools.containers import ReactionContainer, CGRContainer, MoleculeContainer
3
+ from synplan.mcts.tree import Tree
4
+
5
+
6
def find_next_atom_num(reactions: list):
    """
    Find the next available atom number across a list of reactions.

    Composes each reaction into its Condensed Graph of Reaction (CGR),
    takes the maximum atom index used in each CGR, and returns one more
    than the overall maximum, giving a globally unused atom number.

    Args:
        reactions (list): A list of ReactionContainer objects.

    Returns:
        int: The next available atom number — one greater than the largest
        atom index found in any of the reaction CGRs. Returns 1 when
        ``reactions`` is empty or every CGR has no atoms.
    """
    max_num = 0
    for reaction in reactions:
        cgr = reaction.compose()
        # default=0 keeps an atom-free CGR from raising ValueError on max().
        max_num = max(max_num, max(cgr._atoms.keys(), default=0))
    return max_num + 1
28
+
29
+
30
+ def get_clean_mapping(
31
+ curr_prod: MoleculeContainer, prod: MoleculeContainer, reverse: bool = False
32
+ ):
33
+ """
34
+ Get a 'clean' atom mapping between two molecules, avoiding conflicts.
35
+
36
+ This function attempts to establish a mapping between the atoms of two
37
+ MoleculeContainer objects (`curr_prod` and `prod`). It uses an internal
38
+ mapping mechanism and then filters the result to create a "clean" mapping.
39
+ The cleaning process specifically avoids adding entries to the mapping
40
+ where the source and target indices are the same, or where the target
41
+ index already exists as a source in the mapping with a different target.
42
+ It also checks for potential conflicts based on the atom keys present
43
+ in the original molecules.
44
+
45
+ Args:
46
+ curr_prod (MoleculeContainer): The first MoleculeContainer object.
47
+ prod (MoleculeContainer): The second MoleculeContainer object.
48
+ reverse (bool, optional): If True, the mapping is generated in the
49
+ reverse direction (from `prod` to `curr_prod`).
50
+ Defaults to False (mapping from `curr_prod` to `prod`).
51
+
52
+ Returns:
53
+ dict: A dictionary representing the clean atom mapping. Keys are atom
54
+ indices from the source molecule, and values are the corresponding
55
+ atom indices in the target molecule. Returns an empty dictionary
56
+ if no mapping is found or if the initial mapping is empty.
57
+ """
58
+ dict_map = {}
59
+ m = list(curr_prod.get_mapping(prod))
60
+
61
+ if len(m) == 0:
62
+ return dict_map
63
+
64
+ curr_atoms = set(curr_prod._atoms.keys())
65
+ prod_atoms = set(prod._atoms.keys())
66
+
67
+ rr = m[0]
68
+
69
+ # Build mapping while checking for conflicts
70
+ for key, value in rr.items():
71
+ if key != value:
72
+ if value in rr.keys() and rr[value] != key:
73
+ continue
74
+
75
+ source = value if reverse else key
76
+ target = key if reverse else value
77
+
78
+ if reverse and target in curr_atoms:
79
+ continue
80
+ if not reverse and target in prod_atoms:
81
+ continue
82
+
83
+ dict_map[source] = target
84
+
85
+ return dict_map
86
+
87
+
88
+ def validate_molecule_components(curr_mol: MoleculeContainer, node_id: int):
89
+ """
90
+ Validate that a molecule has only one connected component.
91
+
92
+ This function checks if a given MoleculeContainer object represents a
93
+ single connected molecule or multiple disconnected fragments. It extracts
94
+ the connected components and prints an error message if more than one
95
+ component is found, indicating a potential issue with the molecule
96
+ representation within a specific tree node.
97
+
98
+ Args:
99
+ curr_mol (MoleculeContainer): The MoleculeContainer object to validate.
100
+ node_id (int): The ID of the tree node associated with this molecule,
101
+ used for reporting purposes in the error message.
102
+ """
103
+ new_rmol = [curr_mol.substructure(c) for c in curr_mol.connected_components]
104
+ if len(new_rmol) > 1:
105
+ print(f"Error tree {node_id}: We have more than one molecule in one node")
106
+
107
+
108
def get_leaving_groups(products: list):
    """
    Extract leaving-group atom numbers from a list of reaction products.

    The first molecule in ``products`` is assumed to be the main product;
    every subsequent molecule is treated as a leaving group, and its atom
    indices are collected.

    Args:
        products (list): MoleculeContainer objects; the first element is
            assumed to be the main product.

    Returns:
        list: Atom indices belonging to the leaving-group molecules.
    """
    lg_atom_nums = []
    # Everything after the first (main) product is a leaving group;
    # slicing replaces the enumerate/`i != 0` dance.
    for prod in products[1:]:
        lg_atom_nums.extend(prod._atoms.keys())
    return lg_atom_nums
133
+
134
+
135
def process_first_reaction(first_react: ReactionContainer, tree: Tree, node_id: int):
    """
    Process the first reaction in a retrosynthetic route and initialize the building block set.

    Iterates over the reaction's reactants, validates that each one is a
    single connected component, and collects atom indices of reactants that
    qualify as building blocks. A reactant qualifies when its size is at most
    ``tree.config.min_mol_size`` or its SMILES string is present in
    ``tree.building_blocks``.

    Args:
        first_react (ReactionContainer): The first reaction in the route.
        tree (Tree): The search tree, providing ``config.min_mol_size`` and
            ``building_blocks``.
        node_id (int): ID of the tree node, used for validation reporting.

    Returns:
        set: Atom indices of all building-block reactants in the first reaction.
    """
    bb_set = set()

    for curr_mol in first_react.reactants:
        react_key_set = set(curr_mol._atoms)

        if (
            len(curr_mol) <= tree.config.min_mol_size
            or str(curr_mol) in tree.building_blocks
        ):
            # Accumulate via union instead of assignment: the previous
            # `bb_set = react_key_set` dropped atoms of earlier qualifying
            # reactants, inconsistent with update_reaction_dict's union.
            bb_set |= react_key_set

        validate_molecule_components(curr_mol, node_id)

    return bb_set
173
+
174
+
175
def update_reaction_dict(
    reaction: ReactionContainer,
    node_id: int,
    mapping: dict,
    react_dict: dict,
    tree: Tree,
    bb_set: set,
    prev_remap: dict = None,
):
    """
    Update a reaction dictionary with atom mappings and identify building blocks.

    For every reactant of ``reaction``: validates it is a single connected
    component, extends ``bb_set`` with its atom indices when it qualifies as a
    building block (small enough, or known in ``tree.building_blocks``), and
    stores in ``react_dict`` the subset of ``mapping`` (optionally overlaid
    with ``prev_remap``) restricted to that reactant's atoms.

    Args:
        reaction (ReactionContainer): Reaction whose reactants are processed.
        node_id (int): Route/node ID, used for validation reporting.
        mapping (dict): Primary atom mapping to filter per reactant.
        react_dict (dict): Dictionary to fill; keyed by a tuple of each
            reactant's atom indices.
        tree (Tree): Provides ``config.min_mol_size`` and ``building_blocks``.
        bb_set (set): Building-block atom indices accumulated so far.
        prev_remap (dict, optional): Earlier remapping whose entries (also
            restricted to the reactant's atoms) override ``mapping``.

    Returns:
        tuple: ``(react_dict, bb_set)`` — the updated dictionary and the
        (possibly enlarged) building-block set.
    """
    min_size = tree.config.min_mol_size
    known_blocks = tree.building_blocks

    for reactant in reaction.reactants:
        atom_key = tuple(reactant._atoms)
        atom_key_set = set(atom_key)

        validate_molecule_components(reactant, node_id)

        if len(reactant) <= min_size or str(reactant) in known_blocks:
            bb_set = bb_set | atom_key_set

        # Restrict the mapping to this reactant's atoms; previous remap
        # entries (same restriction) take precedence.
        scoped = {src: dst for src, dst in mapping.items() if src in atom_key_set}
        if prev_remap:
            scoped.update(
                {src: dst for src, dst in prev_remap.items() if src in atom_key_set}
            )
        react_dict[atom_key] = scoped

    return react_dict, bb_set
237
+
238
+
239
+ def process_target_blocks(
240
+ curr_products: list,
241
+ curr_prod: MoleculeContainer,
242
+ lg_atom_nums: list,
243
+ curr_lg_atom_nums: list,
244
+ bb_set: set,
245
+ ):
246
+ """
247
+ Identifies and collects atom indices for target blocks based on leaving groups and building blocks.
248
+
249
+ This function iterates through a list of current product molecules, compares their atoms
250
+ to a reference molecule (`curr_prod`), and collects the indices of atoms that correspond
251
+ to atoms in the provided leaving group lists (`lg_atom_nums`, `curr_lg_atom_nums`) or
252
+ the building block set (`bb_set`). This is typically used to identify parts of molecules
253
+ that should be treated as 'target blocks' during a remapping or analysis process.
254
+
255
+ Args:
256
+ curr_products (list): A list of MoleculeContainer objects representing the current products.
257
+ curr_prod (MoleculeContainer): A reference MoleculeContainer object, likely the main product,
258
+ used for mapping atom indices.
259
+ lg_atom_nums (list): A list of integer atom indices identified as leaving group atoms
260
+ in a relevant context.
261
+ curr_lg_atom_nums (list): Another list of integer atom indices identified as leaving
262
+ group atoms, potentially from a different context than `lg_atom_nums`.
263
+ bb_set (set): A set of integer atom indices identified as building block atoms.
264
+
265
+ Returns:
266
+ list: A list of integer atom indices that are identified as 'target blocks' based on
267
+ their presence in the leaving group lists or building block set after mapping
268
+ to the reference molecule.
269
+ """
270
+ target_block = []
271
+ if len(curr_products) > 1:
272
+ for prod in curr_products:
273
+ dict_map = get_clean_mapping(curr_prod, prod)
274
+ if prod._atoms.keys() != curr_prod._atoms.keys():
275
+ for key in list(prod._atoms.keys()):
276
+ if key in lg_atom_nums or key in curr_lg_atom_nums:
277
+ target_block.append(key)
278
+ if key in bb_set:
279
+ target_block.append(key)
280
+ return target_block
281
+
282
+
283
def compose_route_cgr(tree_or_routes, node_id):
    """
    Process a single synthesis route maintaining consistent state.

    Parameters
    ----------
    tree_or_routes : synplan.mcts.tree.Tree
        or dict mapping route_id -> {step_id: ReactionContainer}
    node_id : int
        the route index (in the Tree’s winning_nodes, or the dict’s keys)

    Returns
    -------
    dict or None
        - if successful: { 'cgr': <composed CGR>, 'reactions_dict': {step: ReactionContainer,…} }
        - on error: None
    """
    # ----------- dict-based branch ------------
    # Pre-mapped routes: no atom renumbering is attempted, the step CGRs are
    # simply composed from the final step backwards.
    if isinstance(tree_or_routes, dict):
        routes_dict = tree_or_routes
        if node_id not in routes_dict:
            raise KeyError(f"Route {node_id} not in provided dict.")
        # grab and sort the ReactionContainers in chronological order
        step_map = routes_dict[node_id]
        sorted_ids = sorted(step_map)
        reactions = [step_map[i] for i in sorted_ids]

        # start from the last (final) reaction
        accum_cgr = reactions[-1].compose()
        reactions_dict = {len(reactions) - 1: reactions[-1]}
        # now fold backwards through the earlier steps
        for idx in range(len(reactions) - 2, -1, -1):
            rxn = reactions[idx]
            curr_cgr = rxn.compose()
            accum_cgr = curr_cgr.compose(accum_cgr)
            reactions_dict[idx] = rxn

        return {"cgr": accum_cgr, "reactions_dict": reactions_dict}

    # ----------- tree-based branch ------------
    # Tree routes need atom-number reconciliation between steps; the helpers
    # below (process_first_reaction, get_leaving_groups, ...) track building
    # blocks and leaving groups so overlapping atom indices can be remapped.
    tree = tree_or_routes
    try:
        # original tree-based logic:
        reactions = tree.synthesis_route(node_id)

        first_react = reactions[-1]
        reactions_dict = {len(reactions) - 1: first_react}

        accum_cgr = first_react.compose()
        bb_set = process_first_reaction(first_react, tree, node_id)
        # react_dict is keyed by a tuple of each reactant's atom indices and
        # stores the per-reactant remapping to apply when that reactant shows
        # up as a product of an earlier step.
        react_dict = {}
        max_num = find_next_atom_num(reactions)

        for step in range(len(reactions) - 2, -1, -1):
            reaction = reactions[step]
            curr_cgr = reaction.compose()
            curr_prod = reaction.products[0]

            # decompose()[1] is the product side of the accumulated CGR.
            accum_products = accum_cgr.decompose()[1].split()
            lg_atom_nums = get_leaving_groups(accum_products)
            curr_products = curr_cgr.decompose()[1].split()

            tuple_atoms = tuple(curr_prod._atoms)
            prev_remap = react_dict.get(tuple_atoms, {})

            if prev_remap:
                curr_cgr = curr_cgr.remap(prev_remap, copy=True)

            # identify new atom‐numbers for any overlap
            target_block = process_target_blocks(
                curr_products,
                curr_prod,
                lg_atom_nums,
                [list(p._atoms.keys()) for p in curr_products[1:]],
                bb_set,
            )
            # Assign fresh, globally unused atom numbers to overlapping atoms.
            mapping = {}
            for atom_num in sorted(target_block):
                if atom_num in accum_cgr._atoms and atom_num not in mapping:
                    mapping[atom_num] = max_num
                    max_num += 1

            # carry forward any clean remap on the product itself
            dict_map = {}
            for ap in accum_products:
                clean_map = get_clean_mapping(curr_prod, ap, reverse=True)
                if clean_map:
                    dict_map = clean_map
                    break
            if dict_map:
                curr_cgr = curr_cgr.remap(dict_map, copy=False)

            # update our react_dict & bb_set
            react_dict, bb_set = update_reaction_dict(
                reaction, node_id, mapping, react_dict, tree, bb_set, prev_remap
            )

            # apply the new overlap‐mapping
            if mapping:
                curr_cgr = curr_cgr.remap(mapping, copy=False)

            reactions_dict[step] = ReactionContainer.from_cgr(curr_cgr)
            accum_cgr = curr_cgr.compose(accum_cgr)

        return {"cgr": accum_cgr, "reactions_dict": reactions_dict}

    # NOTE(review): any failure in the tree branch is swallowed and reported
    # via print; callers must handle the None return.
    except Exception as e:
        print(f"Error processing node {node_id}: {e}")
        return None
392
+
393
+
394
def compose_all_route_cgrs(tree_or_routes, node_id=None):
    """
    Process routes (reassign atom mappings) to compose RouteCGR.

    Parameters
    ----------
    tree_or_routes : synplan.mcts.tree.Tree
        or dict mapping route_id -> {step_id: ReactionContainer}
    node_id : int or None
        if None, do *all* winning routes (or all keys of the dict);
        otherwise only that specific route.

    Returns
    -------
    dict or None
        - if node_id is None: {route_id: CGR, …}
        - if node_id is given: {node_id: CGR}
        - returns None on error
    """
    # dict-based branch
    if isinstance(tree_or_routes, dict):
        routes = tree_or_routes

        def _cgr_for(route_id):
            # None when route composition failed.
            outcome = compose_route_cgr(routes, route_id)
            return outcome["cgr"] if outcome else None

        if node_id is None:
            # all routes, in sorted key order
            return {route_id: _cgr_for(route_id) for route_id in sorted(routes)}

        if node_id not in routes:
            raise KeyError(f"Route {node_id} not in provided dict.")
        return {node_id: _cgr_for(node_id)}

    # tree-based branch
    tree = tree_or_routes

    if node_id is not None:
        outcome = compose_route_cgr(tree, node_id)
        if not outcome:
            return None
        return {node_id: outcome["cgr"]}

    collected = {}
    for route_id in sorted(set(tree.winning_nodes)):
        outcome = compose_route_cgr(tree, route_id)
        if outcome:
            collected[route_id] = outcome["cgr"]
    return collected
448
+
449
+
450
def extract_reactions(tree: Tree, node_id=None):
    """
    Collect mapped reaction sequences from a synthesis tree.

    Traverses one branch (when ``node_id`` is given) or every winning route,
    composing the RouteCGR for each, and returns each route's step-indexed
    reactions with conflict-free atom numbering.

    Parameters
    ----------
    tree : ReactionTree
        Retrosynthetic tree exposing ``winning_nodes`` and usable with
        ``compose_route_cgr``.
    node_id : hashable, optional
        Restrict extraction to this single node/route.

    Returns
    -------
    dict[node_id, dict]
        Terminal node ID -> ``reactions_dict`` (as produced by
        ``compose_route_cgr``), sorted by node ID. ``None`` when the
        requested ``node_id`` fails to produce valid reactions.
    """
    if node_id is not None:
        outcome = compose_route_cgr(tree, node_id)
        if not outcome:
            return None
        return {node_id: outcome["reactions_dict"]}

    collected = {}
    for terminal in set(tree.winning_nodes):
        outcome = compose_route_cgr(tree, terminal)
        if outcome:
            collected[terminal] = outcome["reactions_dict"]

    return dict(sorted(collected.items()))
488
+
489
+
490
def compose_sb_cgr(route_cgr: CGRContainer):
    """
    Reduces a Routes Condensed Graph of reaction (RouteCGR) by performing the following steps:

    1. Extracts substructures corresponding to connected components from the input RouteCGR.
    2. Selects the first substructure as the target to work on.
    3. Iterates over all bonds in the target RouteCGR:
       - If a bond is identified as a "leaving group" (its primary order is None while its original order is defined),
         the bond is removed.
       - If a bond has a modified order (both primary and original orders are integers and differ),
         the bond is deleted and then re-added with a new dynamic bond using the primary order (this updates the bond to the reduced form).
    4. After bond modifications, re-extracts the substructure from the target RouteCGR (now called the reduced RouteCGR or ReducedRouteCGR).
    5. If the charge distributions (_p_charges vs. _charges) differ, neutralizes the charges by setting them to zero.

    Args:
        route_cgr: The input RouteCGR object to be reduced.

    Returns:
        The reduced RouteCGR object.
    """
    # Get all connected components of the RouteCGR as separate substructures.
    cgr_prods = [route_cgr.substructure(c) for c in route_cgr.connected_components]
    target_cgr = cgr_prods[
        0
    ]  # Choose the first substructure (main product) for further reduction.

    # Iterate over each bond in the target RouteCGR.
    # The outer snapshot is taken once; the inner snapshot is taken lazily per
    # atom, so a bond deleted while visiting its first endpoint is already
    # gone from the second endpoint's snapshot (no double delete).
    bond_items = list(target_cgr._bonds.items())
    for atom1, bond_set in bond_items:
        bond_set_items = list(bond_set.items())
        for atom2, bond in bond_set_items:

            # Removing bonds corresponding to leaving groups:
            # If product bond order is None (indicating a leaving group) but an original bond order exists,
            # delete the bond.
            if bond.p_order is None and bond.order is not None:
                target_cgr.delete_bond(atom1, atom2)

            # For bonds whose order changed during the route (both orders are
            # integers and differ): replace with a static DynamicBond carrying
            # the product order on both sides.
            elif (
                type(bond.p_order) is int
                and type(bond.order) is int
                and bond.p_order != bond.order
            ):
                p_order = int(bond.p_order)
                target_cgr.delete_bond(atom1, atom2)
                target_cgr.add_bond(atom1, atom2, DynamicBond(p_order, p_order))

    # After modifying bonds, extract the reduced RouteCGR from the target's connected components.
    reduced_route_cgr = [
        target_cgr.substructure(c) for c in target_cgr.connected_components
    ][0]

    # Neutralize charges if the primary charges and current charges differ.
    if reduced_route_cgr._p_charges != reduced_route_cgr._charges:
        for num, charge in reduced_route_cgr._charges.items():
            if charge != 0:
                reduced_route_cgr._atoms[num].charge = 0

    return reduced_route_cgr
551
+
552
+
553
def compose_all_sb_cgrs(route_cgrs_dict: dict):
    """
    Processes a collection (dictionary) of RouteCGRs to generate their reduced forms (ReducedRouteCGRs).

    Applies ``compose_sb_cgr`` to every RouteCGR in the input mapping.

    Args:
        route_cgrs_dict (dict): A dictionary where keys are identifiers
            (e.g. route numbers) and values are RouteCGR objects.

    Returns:
        dict: Mapping of each original identifier to its ReducedRouteCGR.
    """
    # Dict comprehension replaces the manual build-up loop (same result).
    return {num: compose_sb_cgr(cgr) for num, cgr in route_cgrs_dict.items()}
synplan/chem/reaction_routes/visualisation.py ADDED
@@ -0,0 +1,903 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from CGRtools.algorithms.depict import (
2
+ Depict,
3
+ DepictMolecule,
4
+ DepictCGR,
5
+ rotate_vector,
6
+ _render_charge,
7
+ )
8
+ from CGRtools.containers import ReactionContainer, MoleculeContainer, CGRContainer
9
+
10
+ from collections import defaultdict
11
+ from uuid import uuid4
12
+ from math import hypot
13
+ from functools import partial
14
+
15
+
16
+ class WideBondDepictCGR(DepictCGR):
17
+ """
18
+ Like DepictCGR, but all DynamicBonds
19
+ are drawn 2.5× wider than the standard bond width.
20
+ """
21
+
22
+ __slots__ = ()
23
+
24
+ def _render_bonds(self):
25
+ """
26
+ Renders the bonds of the CGR as SVG lines, with DynamicBonds drawn wider.
27
+
28
+ This method overrides the base `_render_bonds` to apply a wider stroke
29
+ to DynamicBonds, highlighting changes in bond order during a reaction.
30
+ It iterates through all bonds, calculates their positions based on
31
+ 2D coordinates, and generates SVG `<line>` elements with appropriate
32
+ styles (color, width, dash array) based on the bond's original (`order`)
33
+ and primary (`p_order`) states. Aromatic bonds are handled separately
34
+ using a helper method.
35
+
36
+ Returns:
37
+ list: A list of strings, where each string is an SVG element
38
+ representing a bond.
39
+ """
40
+ plane = self._plane
41
+ config = self._render_config
42
+
43
+ # get the normal width (default 1.0) and compute a 4× wide stroke
44
+ normal_width = config.get("bond_width", 0.02)
45
+ wide_width = normal_width * 2.5
46
+
47
+ broken = config["broken_color"]
48
+ formed = config["formed_color"]
49
+ dash1, dash2 = config["dashes"]
50
+ double_space = config["double_space"]
51
+ triple_space = config["triple_space"]
52
+
53
+ svg = []
54
+ ar_bond_colors = defaultdict(dict)
55
+
56
+ for n, m, bond in self.bonds():
57
+ order, p_order = bond.order, bond.p_order
58
+ nx, ny = plane[n]
59
+ mx, my = plane[m]
60
+ # invert Y for SVG
61
+ ny, my = -ny, -my
62
+ rv = partial(rotate_vector, 0, x2=mx - nx, y2=ny - my)
63
+ if order == 1:
64
+ if p_order == 1:
65
+ svg.append(
66
+ f' <line x1="{nx:.2f}" y1="{ny:.2f}" x2="{mx:.2f}" y2="{my:.2f}"/>'
67
+ )
68
+ elif p_order == 4:
69
+ ar_bond_colors[n][m] = ar_bond_colors[m][n] = formed
70
+ svg.append(
71
+ f' <line x1="{nx:.2f}" y1="{ny:.2f}" x2="{mx:.2f}" y2="{my:.2f}"/>'
72
+ )
73
+ elif p_order == 2:
74
+ dx, dy = rv(double_space)
75
+ svg.append(
76
+ f' <line x1="{nx + dx:.2f}" y1="{ny - dy:.2f}" '
77
+ f'x2="{mx + dx:.2f}" y2="{my - dy:.2f}"/>'
78
+ )
79
+ svg.append(
80
+ f' <line x1="{nx - dx:.2f}" y1="{ny + dy:.2f}" '
81
+ f'x2="{mx - dx:.2f}" y2="{my + dy:.2f}" stroke="{formed}" stroke-width="{wide_width:.2f}"/>'
82
+ )
83
+ elif p_order == 3:
84
+ dx, dy = rv(triple_space)
85
+ svg.append(
86
+ f' <line x1="{nx:.2f}" y1="{ny:.2f}" x2="{mx:.2f}" y2="{my:.2f}"'
87
+ f' stroke="{formed}" stroke-width="{wide_width:.2f}"/>'
88
+ )
89
+ svg.append(
90
+ f' <line x1="{nx + dx:.2f}" y1="{ny - dy:.2f}" '
91
+ f'x2="{mx + dx:.2f}" y2="{my - dy:.2f}" stroke-width="{wide_width:.2f}"/>'
92
+ )
93
+ svg.append(
94
+ f' <line x1="{nx - dx:.2f}" y1="{ny + dy:.2f}" '
95
+ f'x2="{mx - dx:.2f}" y2="{my + dy:.2f}" stroke="{formed}" stroke-width="{wide_width:.2f}"/>'
96
+ )
97
+ elif p_order is None:
98
+ svg.append(
99
+ f' <line x1="{nx:.2f}" y1="{ny:.2f}" x2="{mx:.2f}" y2="{my:.2f}"'
100
+ f' stroke="{broken}" stroke-width="{wide_width:.2f}"/>'
101
+ )
102
+ else:
103
+ dx, dy = rv(double_space)
104
+ svg.append(
105
+ f' <line x1="{nx + dx:.2f}" y1="{ny - dy:.2f}" x2="{mx + dx:.2f}"'
106
+ f' y2="{my - dy:.2f}" stroke-dasharray="{dash1:.2f} {dash2:.2f}" stroke="{formed}" stroke-width="{wide_width:.2f}"/>'
107
+ )
108
+ svg.append(
109
+ f' <line x1="{nx - dx:.2f}" y1="{ny + dy:.2f}" '
110
+ f'x2="{mx - dx:.2f}" y2="{my + dy:.2f}" stroke="{broken}" stroke-width="{wide_width:.2f}"/>'
111
+ )
112
+ elif order == 4:
113
+ if p_order == 4:
114
+ svg.append(
115
+ f' <line x1="{nx:.2f}" y1="{ny:.2f}" x2="{mx:.2f}" y2="{my:.2f}"/>'
116
+ )
117
+ elif p_order == 1:
118
+ ar_bond_colors[n][m] = ar_bond_colors[m][n] = broken
119
+ svg.append(
120
+ f' <line x1="{nx:.2f}" y1="{ny:.2f}" x2="{mx:.2f}" y2="{my:.2f}"/>'
121
+ )
122
+ elif p_order == 2:
123
+ ar_bond_colors[n][m] = ar_bond_colors[m][n] = broken
124
+ dx, dy = rv(double_space)
125
+ svg.append(
126
+ f' <line x1="{nx + dx:.2f}" y1="{ny - dy:.2f}" '
127
+ f'x2="{mx + dx:.2f}" y2="{my - dy:.2f}"/>'
128
+ )
129
+ svg.append(
130
+ f' <line x1="{nx - dx:.2f}" y1="{ny + dy:.2f}" '
131
+ f'x2="{mx - dx:.2f}" y2="{my + dy:.2f}" stroke="{formed}" stroke-width="{wide_width:.2f}"/>'
132
+ )
133
+ elif p_order == 3:
134
+ ar_bond_colors[n][m] = ar_bond_colors[m][n] = broken
135
+ dx, dy = rv(triple_space)
136
+ svg.append(
137
+ f' <line x1="{nx + dx:.2f}" y1="{ny - dy:.2f}" '
138
+ f'x2="{mx + dx:.2f}" y2="{my - dy:.2f}" stroke="{formed}" stroke-width="{wide_width:.2f}"/>'
139
+ )
140
+ svg.append(
141
+ f' <line x1="{nx:.2f}" y1="{ny:.2f}" x2="{mx:.2f}" y2="{my:.2f}"/>'
142
+ )
143
+ svg.append(
144
+ f' <line x1="{nx - dx:.2f}" y1="{ny + dy:.2f}" '
145
+ f'x2="{mx - dx:.2f}" y2="{my + dy:.2f}" stroke="{formed}" stroke-width="{wide_width:.2f}"/>'
146
+ )
147
+ elif p_order is None:
148
+ ar_bond_colors[n][m] = ar_bond_colors[m][n] = broken
149
+ svg.append(
150
+ f' <line x1="{nx:.2f}" y1="{ny:.2f}" x2="{mx:.2f}" y2="{my:.2f}"'
151
+ f' stroke="{broken}" stroke-width="{wide_width:.2f}"/>'
152
+ )
153
+ else:
154
+ ar_bond_colors[n][m] = ar_bond_colors[m][n] = None
155
+ svg.append(
156
+ f' <line x1="{nx:.2f}" y1="{ny:.2f}" x2="{mx:.2f}" y2="{my:.2f}"'
157
+ f' stroke="{broken}" stroke-width="{wide_width:.2f}"/>'
158
+ )
159
+ elif order == 2:
160
+ if p_order == 2:
161
+ dx, dy = rv(double_space)
162
+ svg.append(
163
+ f' <line x1="{nx + dx:.2f}" y1="{ny - dy:.2f}" '
164
+ f'x2="{mx + dx:.2f}" y2="{my - dy:.2f}"/>'
165
+ )
166
+ svg.append(
167
+ f' <line x1="{nx - dx:.2f}" y1="{ny + dy:.2f}" '
168
+ f'x2="{mx - dx:.2f}" y2="{my + dy:.2f}"/>'
169
+ )
170
+ elif p_order == 1:
171
+ dx, dy = rv(double_space)
172
+ svg.append(
173
+ f' <line x1="{nx + dx:.2f}" y1="{ny - dy:.2f}" '
174
+ f'x2="{mx + dx:.2f}" y2="{my - dy:.2f}"/>'
175
+ )
176
+ svg.append(
177
+ f' <line x1="{nx - dx:.2f}" y1="{ny + dy:.2f}" '
178
+ f'x2="{mx - dx:.2f}" y2="{my + dy:.2f}" stroke="{broken}" stroke-width="{wide_width:.2f}"/>'
179
+ )
180
+ elif p_order == 4:
181
+ ar_bond_colors[n][m] = ar_bond_colors[m][n] = formed
182
+ dx, dy = rv(double_space)
183
+ svg.append(
184
+ f' <line x1="{nx + dx:.2f}" y1="{ny - dy:.2f}" '
185
+ f'x2="{mx + dx:.2f}" y2="{my - dy:.2f}"/>'
186
+ )
187
+ svg.append(
188
+ f' <line x1="{nx - dx:.2f}" y1="{ny + dy:.2f}" '
189
+ f'x2="{mx - dx:.2f}" y2="{my + dy:.2f}" stroke="{broken}" stroke-width="{wide_width:.2f}"/>'
190
+ )
191
+ elif p_order == 3:
192
+ dx, dy = rv(triple_space)
193
+ svg.append(
194
+ f' <line x1="{nx + dx:.2f}" y1="{ny - dy:.2f}" '
195
+ f'x2="{mx + dx:.2f}" y2="{my - dy:.2f}"/>'
196
+ )
197
+ svg.append(
198
+ f' <line x1="{nx:.2f}" y1="{ny:.2f}" x2="{mx:.2f}" y2="{my:.2f}"/>'
199
+ )
200
+ svg.append(
201
+ f' <line x1="{nx - dx:.2f}" y1="{ny + dy:.2f}" '
202
+ f'x2="{mx - dx:.2f}" y2="{my + dy:.2f}" stroke="{formed} stroke-width="{wide_width:.2f}""/>'
203
+ )
204
+ elif p_order is None:
205
+ dx, dy = rv(double_space)
206
+ svg.append(
207
+ f' <line x1="{nx + dx:.2f}" y1="{ny - dy:.2f}" '
208
+ f'x2="{mx + dx:.2f}" y2="{my - dy:.2f}" stroke="{broken}" stroke-width="{wide_width:.2f}"/>'
209
+ )
210
+ svg.append(
211
+ f' <line x1="{nx - dx:.2f}" y1="{ny + dy:.2f}" '
212
+ f'x2="{mx - dx:.2f}" y2="{my + dy:.2f}" stroke="{broken}" stroke-width="{wide_width:.2f}"/>'
213
+ )
214
+ else:
215
+ dx, dy = rv(triple_space)
216
+ svg.append(
217
+ f' <line x1="{nx + dx:.2f}" y1="{ny - dy:.2f}" x2="{mx + dx:.2f}"'
218
+ f' y2="{my - dy:.2f}" stroke-dasharray="{dash1:.2f} {dash2:.2f}" stroke="{formed}" stroke-width="{wide_width:.2f}"/>'
219
+ )
220
+ svg.append(
221
+ f' <line x1="{nx:.2f}" y1="{ny:.2f}" x2="{mx:.2f}" y2="{my:.2f}"'
222
+ f' stroke="{broken}" stroke-width="{wide_width:.2f}"/>'
223
+ )
224
+ svg.append(
225
+ f' <line x1="{nx - dx:.2f}" y1="{ny + dy:.2f}" '
226
+ f'x2="{mx - dx:.2f}" y2="{my + dy:.2f}" stroke="{broken}" stroke-width="{wide_width:.2f}"/>'
227
+ )
228
+ elif order == 3:
229
+ if p_order == 3:
230
+ dx, dy = rv(triple_space)
231
+ svg.append(
232
+ f' <line x1="{nx + dx:.2f}" y1="{ny - dy:.2f}" '
233
+ f'x2="{mx + dx:.2f}" y2="{my - dy:.2f}"/>'
234
+ )
235
+ svg.append(
236
+ f' <line x1="{nx:.2f}" y1="{ny:.2f}" x2="{mx:.2f}" y2="{my:.2f}"/>'
237
+ )
238
+ svg.append(
239
+ f' <line x1="{nx - dx:.2f}" y1="{ny + dy:.2f}" '
240
+ f'x2="{mx - dx:.2f}" y2="{my + dy:.2f}"/>'
241
+ )
242
+ elif p_order == 1:
243
+ dx, dy = rv(triple_space)
244
+ svg.append(
245
+ f' <line x1="{nx + dx:.2f}" y1="{ny - dy:.2f}" '
246
+ f'x2="{mx + dx:.2f}" y2="{my - dy:.2f}"/>'
247
+ )
248
+ svg.append(
249
+ f' <line x1="{nx:.2f}" y1="{ny:.2f}" x2="{mx:.2f}" y2="{my:.2f}"'
250
+ f' stroke="{broken}" stroke-width="{wide_width:.2f}"/>'
251
+ )
252
+ svg.append(
253
+ f' <line x1="{nx - dx:.2f}" y1="{ny + dy:.2f}" '
254
+ f'x2="{mx - dx:.2f}" y2="{my + dy:.2f}" '
255
+ f'stroke="{broken}" stroke-width="{wide_width:.2f}"/>'
256
+ )
257
+ elif p_order == 4:
258
+ ar_bond_colors[n][m] = ar_bond_colors[m][n] = formed
259
+ dx, dy = rv(triple_space)
260
+ svg.append(
261
+ f' <line x1="{nx + dx:.2f}" y1="{ny - dy:.2f}" x2="{mx + dx:.2f}" '
262
+ f'y2="{my - dy:.2f}" stroke="{broken}" stroke-width="{wide_width:.2f}"/>'
263
+ )
264
+ svg.append(
265
+ f' <line x1="{nx:.2f}" y1="{ny:.2f}" x2="{mx:.2f}" y2="{my:.2f}"/>'
266
+ )
267
+ svg.append(
268
+ f' <line x1="{nx - dx:.2f}" y1="{ny + dy:.2f}" x2="{mx - dx:.2f}" '
269
+ f'y2="{my + dy:.2f}" stroke="{broken}" stroke-width="{wide_width:.2f}"/>'
270
+ )
271
+ elif p_order == 2:
272
+ dx, dy = rv(triple_space)
273
+ svg.append(
274
+ f' <line x1="{nx + dx:.2f}" y1="{ny - dy:.2f}" '
275
+ f'x2="{mx + dx:.2f}" y2="{my - dy:.2f}"/>'
276
+ )
277
+ svg.append(
278
+ f' <line x1="{nx:.2f}" y1="{ny:.2f}" x2="{mx:.2f}" y2="{my:.2f}"/>'
279
+ )
280
+ svg.append(
281
+ f' <line x1="{nx - dx:.2f}" y1="{ny + dy:.2f}" '
282
+ f'x2="{mx - dx:.2f}" y2="{my + dy:.2f}" stroke="{broken}" stroke-width="{wide_width:.2f}"/>'
283
+ )
284
+ elif p_order is None:
285
+ dx, dy = rv(triple_space)
286
+ svg.append(
287
+ f' <line x1="{nx + dx:.2f}" y1="{ny - dy:.2f}" '
288
+ f'x2="{mx + dx:.2f}" y2="{my - dy:.2f}" stroke="{broken}" stroke-width="{wide_width:.2f}"/>'
289
+ )
290
+ svg.append(
291
+ f' <line x1="{nx:.2f}" y1="{ny:.2f}" '
292
+ f'x2="{mx:.2f}" y2="{my:.2f}" stroke="{broken}" stroke-width="{wide_width:.2f}"/>'
293
+ )
294
+ svg.append(
295
+ f' <line x1="{nx - dx:.2f}" y1="{ny + dy:.2f}" '
296
+ f'x2="{mx - dx:.2f}" y2="{my + dy:.2f}" stroke="{broken}" stroke-width="{wide_width:.2f}"/>'
297
+ )
298
+ else:
299
+ dx, dy = rv(double_space)
300
+ dx3 = 3 * dx
301
+ dy3 = 3 * dy
302
+ svg.append(
303
+ f' <line x1="{nx + dx3:.2f}" y1="{ny - dy3:.2f}" x2="{mx + dx3:.2f}" '
304
+ f'y2="{my - dy3:.2f}" stroke-dasharray="{dash1:.2f} {dash2:.2f}" stroke="{formed}" stroke-width="{wide_width:.2f}"/>'
305
+ )
306
+ svg.append(
307
+ f' <line x1="{nx + dx:.2f}" y1="{ny - dy:.2f}" '
308
+ f'x2="{mx + dx:.2f}" y2="{my - dy:.2f}" stroke="{broken}" stroke-width="{wide_width:.2f}"/>'
309
+ )
310
+ svg.append(
311
+ f' <line x1="{nx - dx:.2f}" y1="{ny + dy:.2f}" '
312
+ f'x2="{mx - dx:.2f}" y2="{my + dy:.2f}" stroke="{broken}" stroke-width="{wide_width:.2f}"/>'
313
+ )
314
+ svg.append(
315
+ f' <line x1="{nx - dx3:.2f}" y1="{ny + dy3:.2f}" x2="{mx - dx3:.2f}" '
316
+ f'y2="{my + dy3:.2f}" stroke="{broken}" stroke-width="{wide_width:.2f}"/>'
317
+ )
318
+ elif order is None:
319
+ if p_order == 1:
320
+ svg.append(
321
+ f' <line x1="{nx:.2f}" y1="{ny:.2f}" x2="{mx:.2f}" y2="{my:.2f}"'
322
+ f' stroke="{formed}" stroke-width="{wide_width:.2f}"/>'
323
+ )
324
+ elif p_order == 4:
325
+ ar_bond_colors[n][m] = ar_bond_colors[m][n] = formed
326
+ svg.append(
327
+ f' <line x1="{nx:.2f}" y1="{ny:.2f}" x2="{mx:.2f}" y2="{my:.2f}"'
328
+ f' stroke="{formed}" stroke-width="{wide_width:.2f}"/>'
329
+ )
330
+ elif p_order == 2:
331
+ dx, dy = rv(double_space)
332
+ # dx = dx // 1.4
333
+ # dy = dy // 1.4
334
+ svg.append(
335
+ f' <line x1="{nx + dx:.2f}" y1="{ny - dy:.2f}" x2="{mx + dx:.2f}" '
336
+ f'y2="{my - dy:.2f}" stroke="{formed}" stroke-width="{wide_width:.2f}"/>'
337
+ )
338
+ svg.append(
339
+ f' <line x1="{nx - dx:.2f}" y1="{ny + dy:.2f}" x2="{mx - dx:.2f}" '
340
+ f'y2="{my + dy:.2f}" stroke="{formed}" stroke-width="{wide_width:.2f}"/>'
341
+ )
342
+ elif p_order == 3:
343
+ dx, dy = rv(triple_space)
344
+ svg.append(
345
+ f' <line x1="{nx + dx:.2f}" y1="{ny - dy:.2f}" '
346
+ f'x2="{mx + dx:.2f}" y2="{my - dy:.2f}" stroke="{formed}" stroke-width="{wide_width:.2f}"/>'
347
+ )
348
+ svg.append(
349
+ f' <line x1="{nx:.2f}" y1="{ny:.2f}" x2="{mx:.2f}" y2="{my:.2f}"'
350
+ f' stroke="{formed}" stroke-width="{wide_width:.2f}"/>'
351
+ )
352
+ svg.append(
353
+ f' <line x1="{nx - dx:.2f}" y1="{ny + dy:.2f}" '
354
+ f'x2="{mx - dx:.2f}" y2="{my + dy:.2f}" stroke="{formed}" stroke-width="{wide_width:.2f}"/>'
355
+ )
356
+ else:
357
+ svg.append(
358
+ f' <line x1="{nx:.2f}" y1="{ny:.2f}" x2="{mx:.2f}" y2="{my:.2f}" '
359
+ f'stroke-dasharray="{dash1:.2f} {dash2:.2f}" stroke="{formed}" stroke-width="{wide_width:.2f}"/>'
360
+ )
361
+ else:
362
+ if p_order == 8:
363
+ svg.append(
364
+ f' <line x1="{nx:.2f}" y1="{ny:.2f}" x2="{mx:.2f}" y2="{my:.2f}" '
365
+ f'stroke-dasharray="{dash1:.2f} {dash2:.2f}"/>'
366
+ )
367
+ elif p_order == 1:
368
+ dx, dy = rv(double_space)
369
+ svg.append(
370
+ f' <line x1="{nx + dx:.2f}" y1="{ny - dy:.2f}" x2="{mx + dx:.2f}"'
371
+ f' y2="{my - dy:.2f}" stroke-dasharray="{dash1:.2f} {dash2:.2f}" stroke="{broken}" stroke-width="{wide_width:.2f}"/>'
372
+ )
373
+ svg.append(
374
+ f' <line x1="{nx - dx:.2f}" y1="{ny + dy:.2f}" '
375
+ f'x2="{mx - dx:.2f}" y2="{my + dy:.2f}" stroke="{formed}" stroke-width="{wide_width:.2f}"/>'
376
+ )
377
+ elif p_order == 4:
378
+ ar_bond_colors[n][m] = ar_bond_colors[m][n] = None
379
+ svg.append(
380
+ f' <line x1="{nx:.2f}" y1="{ny:.2f}" x2="{mx:.2f}" y2="{my:.2f}"'
381
+ f' stroke="{formed}" stroke-width="{wide_width:.2f}"/>'
382
+ )
383
+ elif p_order == 2:
384
+ dx, dy = rv(triple_space)
385
+ svg.append(
386
+ f' <line x1="{nx + dx:.2f}" y1="{ny - dy:.2f}" x2="{mx + dx:.2f}"'
387
+ f' y2="{my - dy:.2f}" stroke-dasharray="{dash1:.2f} {dash2:.2f}" stroke="{broken}" stroke-width="{wide_width:.2f}"/>'
388
+ )
389
+ svg.append(
390
+ f' <line x1="{nx:.2f}" y1="{ny:.2f}" x2="{mx:.2f}" y2="{my:.2f}"'
391
+ f' stroke="{formed}" stroke-width="{wide_width:.2f}"/>'
392
+ )
393
+ svg.append(
394
+ f' <line x1="{nx - dx:.2f}" y1="{ny + dy:.2f}" '
395
+ f'x2="{mx - dx:.2f}" y2="{my + dy:.2f}" stroke="{formed}" stroke-width="{wide_width:.2f}"/>'
396
+ )
397
+ elif p_order == 3:
398
+ dx, dy = rv(double_space)
399
+ dx3 = 3 * dx
400
+ dy3 = 3 * dy
401
+ svg.append(
402
+ f' <line x1="{nx + dx3:.2f}" y1="{ny - dy3:.2f}" x2="{mx + dx3:.2f}" '
403
+ f'y2="{my - dy3:.2f}" stroke-dasharray="{dash1:.2f} {dash2:.2f}" stroke="{broken}" stroke-width="{wide_width:.2f}"/>'
404
+ )
405
+ svg.append(
406
+ f' <line x1="{nx + dx:.2f}" y1="{ny - dy:.2f}" '
407
+ f'x2="{mx + dx:.2f}" y2="{my - dy:.2f}" stroke="{formed}" stroke-width="{wide_width:.2f}"/>'
408
+ )
409
+ svg.append(
410
+ f' <line x1="{nx - dx:.2f}" y1="{ny + dy:.2f}" '
411
+ f'x2="{mx - dx:.2f}" y2="{my + dy:.2f}" stroke="{formed}" stroke-width="{wide_width:.2f}"/>'
412
+ )
413
+ svg.append(
414
+ f' <line x1="{nx - dx3:.2f}" y1="{ny + dy3:.2f}" '
415
+ f'x2="{mx - dx3:.2f}" y2="{my + dy3:.2f}" stroke="{formed}" stroke-width="{wide_width:.2f}"/>'
416
+ )
417
+ else:
418
+ svg.append(
419
+ f' <line x1="{nx:.2f}" y1="{ny:.2f}" x2="{mx:.2f}" y2="{my:.2f}" '
420
+ f'stroke-dasharray="{dash1:.2f} {dash2:.2f}" stroke="{broken}" stroke-width="{wide_width:.2f}"/>'
421
+ )
422
+
423
+ # aromatic rings - unchanged
424
+ for ring in self.aromatic_rings:
425
+ cx = sum(plane[x][0] for x in ring) / len(ring)
426
+ cy = sum(plane[x][1] for x in ring) / len(ring)
427
+
428
+ for n, m in zip(ring, ring[1:]):
429
+ nx, ny = plane[n]
430
+ mx, my = plane[m]
431
+ aromatic = self.__render_aromatic_bond(
432
+ nx, ny, mx, my, cx, cy, ar_bond_colors[n].get(m)
433
+ )
434
+ if aromatic:
435
+ svg.append(aromatic)
436
+
437
+ n, m = ring[-1], ring[0]
438
+ nx, ny = plane[n]
439
+ mx, my = plane[m]
440
+ aromatic = self.__render_aromatic_bond(
441
+ nx, ny, mx, my, cx, cy, ar_bond_colors[n].get(m)
442
+ )
443
+ if aromatic:
444
+ svg.append(aromatic)
445
+ return svg
446
+
447
    def __render_aromatic_bond(self, n_x, n_y, m_x, m_y, c_x, c_y, color):
        """Render the inner dashed line of one aromatic-ring edge as SVG.

        (n_x, n_y)-(m_x, m_y) is the ring edge; (c_x, c_y) is the ring
        centroid, used to offset the dashed line toward the ring interior.

        :param color: stroke color for a dynamically changed aromatic bond;
            ``None`` means "unchanged" and falls back to the plain dashed
            style with normal width.
        :return: an SVG ``<line>`` fragment, or implicitly ``None`` when
            there is no room for the inner line (callers must check for a
            falsy result before appending).
        """
        config = self._render_config

        dash1, dash2 = config["dashes"]
        dash3, dash4 = config["aromatic_dashes"]
        aromatic_space = config["cgr_aromatic_space"]

        # dynamic bonds are drawn at twice the normal stroke width
        normal_width = config.get("bond_width", 0.02)
        wide_width = normal_width * 2

        # translate so that n is the origin ("n aligned" coordinates)
        mn_x, mn_y, cn_x, cn_y = m_x - n_x, m_y - n_y, c_x - n_x, c_y - n_y

        # rotate so that the n->m edge lies on the positive x axis;
        # mr_y is always 0 by construction (kept for symmetry of naming)
        mr_x, mr_y = hypot(mn_x, mn_y), 0
        cr_x, cr_y = rotate_vector(cn_x, cn_y, mn_x, -mn_y)

        # only draw the inner line if the centroid is far enough from the
        # edge for the offset line to fit (ratio heuristic < 0.65)
        if cr_y and aromatic_space / cr_y < 0.65:
            if cr_y > 0:
                r_y = aromatic_space
            else:
                r_y = -aromatic_space
                cr_y = -cr_y

            # endpoints of the shortened inner line in rotated coordinates
            ar_x = aromatic_space * cr_x / cr_y
            br_x = mr_x - aromatic_space * (mr_x - cr_x) / cr_y

            # rotate back into the molecule plane ("backward reorienting")
            an_x, an_y = rotate_vector(ar_x, r_y, mn_x, mn_y)
            bn_x, bn_y = rotate_vector(br_x, r_y, mn_x, mn_y)

            if color:
                # changed aromatic bond: colored, wide, aromatic dash pattern
                return (
                    f'        <line x1="{an_x + n_x:.2f}" y1="{-an_y - n_y:.2f}" x2="{bn_x + n_x:.2f}" '
                    f'y2="{-bn_y - n_y:.2f}" stroke-dasharray="{dash3:.2f} {dash4:.2f}" stroke="{color}" stroke-width="{wide_width:.2f}"/>'
                )
            elif color is None:
                # unchanged aromatic bond: plain dashes, default stroke
                dash3, dash4 = dash1, dash2
                return (
                    f'        <line x1="{an_x + n_x:.2f}" y1="{-an_y - n_y:.2f}"'
                    f' x2="{bn_x + n_x:.2f}" y2="{-bn_y - n_y:.2f}" stroke-dasharray="{dash3:.2f} {dash4:.2f}"/>'
                )
489
+ )
490
+
491
+
492
def cgr_display(cgr: CGRContainer) -> str:
    """Render a CGR as an SVG string with widened DynamicBonds.

    The bond-rendering hooks of ``CGRContainer`` are patched with the
    implementations from ``WideBondDepictCGR`` (which draw dynamic bonds
    with a doubled stroke width), the 2D coordinates of the input CGR are
    recomputed, and its ``depict()`` output is returned.

    NOTE: the patch is applied to the ``CGRContainer`` class itself and is
    not reverted afterwards — every later depiction uses the wide style.

    Args:
        cgr (CGRContainer): The CGR to depict.

    Returns:
        str: SVG markup of the CGR with wider DynamicBonds.
    """
    wide_aromatic = WideBondDepictCGR._WideBondDepictCGR__render_aromatic_bond
    # install the wide renderer under both name-mangled attribute names so
    # lookups from either class resolve to it
    for hook_name in (
        "_CGRContainer__render_aromatic_bond",
        "_WideBondDepictCGR__render_aromatic_bond",
    ):
        setattr(CGRContainer, hook_name, wide_aromatic)
    CGRContainer._render_bonds = WideBondDepictCGR._render_bonds

    cgr.clean2d()
    return cgr.depict()
519
+
520
+
521
class CustomDepictMolecule(DepictMolecule):
    """
    Custom molecule depiction class that uses atom.symbol for rendering.

    Overrides ``_render_atoms`` so that the label drawn for each atom comes
    from ``atom.symbol`` when available (falling back to ``atom.atomic_symbol``),
    and adds defensive handling for attributes that may be missing on
    non-standard atom/container implementations.
    """

    def _render_atoms(self):
        """Build the SVG fragments and masks for all atom labels.

        Returns a tuple ``(svg, mask)`` where ``svg`` is a list of SVG text
        fragments and ``mask`` maps mask layer names ("center", "symbols",
        "span", "other", "aam") to lists of SVG fragments, mirroring the
        contract of ``DepictMolecule._render_atoms``.
        """
        bonds = self._bonds
        plane = self._plane
        charges = self._charges
        radicals = self._radicals
        hydrogens = self._hydrogens
        config = self._render_config

        carbon = config["carbon"]
        mapping = config["mapping"]
        span_size = config["span_size"]
        font_size = config["font_size"]
        monochrome = config["monochrome"]
        other_size = config["other_size"]
        atoms_colors = config["atoms_colors"]
        mapping_font = config["mapping_size"]
        dx_m, dy_m = config["dx_m"], config["dy_m"]
        dx_ci, dy_ci = config["dx_ci"], config["dy_ci"]
        symbols_font_style = config["symbols_font_style"]

        # for cumulenes: middle atoms of cumulated chains must always be drawn
        try:
            # Check if _cumulenes method exists and handle potential errors
            cumulenes = {
                y
                for x in self._cumulenes(heteroatoms=True)
                if len(x) > 2
                for y in x[1:-1]
            }
        except AttributeError:
            cumulenes = set()  # Fallback if _cumulenes is not available or fails

        if monochrome:
            map_fill = other_fill = "black"
        else:
            map_fill = config["mapping_color"]
            other_fill = config["other_color"]

        svg = []
        maps = []
        others = []
        # precomputed font-size fractions used for label positioning below
        font2 = 0.2 * font_size
        font3 = 0.3 * font_size
        font4 = 0.4 * font_size
        font5 = 0.5 * font_size
        font6 = 0.6 * font_size
        font7 = 0.7 * font_size
        font15 = 0.15 * font_size
        font25 = 0.25 * font_size
        mask = defaultdict(list)
        for n, atom in self._atoms.items():
            x, y = plane[n]
            y = -y  # SVG y axis points down; depiction plane points up

            # --- KEY CHANGE HERE ---
            # Use atom.symbol if it exists, otherwise fallback to atomic_symbol
            try:
                symbol = atom.symbol
            except AttributeError:
                symbol = atom.atomic_symbol  # Fallback if .symbol doesn't exist
            # --- END KEY CHANGE ---

            # draw the symbol only when it carries information: heteroatom,
            # explicit carbon mode, charged/radical/isotopic atom, isolated
            # atom (no bonds), or cumulene middle atom
            if (
                not bonds.get(n)
                or symbol != "C"
                or carbon
                or atom.charge
                or atom.is_radical
                or atom.isotope
                or n in cumulenes
            ):  # Added bonds.get(n) check for single atoms
                # Calculate hydrogens if the attribute exists, otherwise default to 0
                try:
                    h = hydrogens[n]
                except (KeyError, AttributeError):
                    h = 0  # Default if _hydrogens is missing or key n is not present

                if h == 1:
                    h_str = "H"
                    span = ""
                elif h and h > 1:  # Check if h is not None and greater than 1
                    span = f'<tspan dy="{config["span_dy"]:.2f}" font-size="{span_size:.2f}">{h}</tspan>'
                    h_str = "H"
                else:
                    h_str = ""
                    span = ""

                # Handle charges and radicals safely
                charge_val = charges.get(n, 0)
                is_radical = radicals.get(n, False)

                if charge_val:
                    t = f'{_render_charge.get(charge_val, "")}{"↑" if is_radical else ""}'  # Use .get for safety
                    if t:  # Only add if charge text is generated
                        others.append(
                            f'        <text x="{x:.2f}" y="{y:.2f}" dx="{dx_ci:.2f}" dy="-{dy_ci:.2f}">'
                            f"{t}</text>"
                        )
                        mask["other"].append(
                            f'            <text x="{x:.2f}" y="{y:.2f}" dx="{dx_ci:.2f}" dy="-{dy_ci:.2f}">'
                            f"{t}</text>"
                        )
                elif is_radical:
                    others.append(
                        f'        <text x="{x:.2f}" y="{y:.2f}" dx="{dx_ci:.2f}" dy="-{dy_ci:.2f}">↑</text>'
                    )
                    mask["other"].append(
                        f'            <text x="{x:.2f}" y="{y:.2f}" dx="{dx_ci:.2f}"'
                        f' dy="-{dy_ci:.2f}">↑</text>'
                    )

                # Handle isotope safely
                try:
                    iso = atom.isotope
                    if iso:
                        t = iso
                        others.append(
                            f'        <text x="{x:.2f}" y="{y:.2f}" dx="-{dx_ci:.2f}" dy="-{dy_ci:.2f}" '
                            f'text-anchor="end">{t}</text>'
                        )
                        mask["other"].append(
                            f'            <text x="{x:.2f}" y="{y:.2f}" dx="-{dx_ci:.2f}"'
                            f' dy="-{dy_ci:.2f}" text-anchor="end">{t}</text>'
                        )
                except AttributeError:
                    pass  # Atom might not have isotope attribute

                # Determine atom color based on atomic_number, default to black if monochrome or not found
                atom_color = "black"
                if not monochrome:
                    try:
                        an = atom.atomic_number
                        if 0 < an <= len(atoms_colors):
                            atom_color = atoms_colors[an - 1]
                        else:
                            atom_color = atoms_colors[
                                5
                            ]  # Default to Carbon color if out of range
                    except AttributeError:
                        atom_color = atoms_colors[
                            5
                        ]  # Default to Carbon color if no atomic_number

                svg.append(
                    f'      <g fill="{atom_color}" '
                    f'font-family="{symbols_font_style }">'
                )

                # Adjust dx based on symbol length for better centering
                if len(symbol) > 1:
                    dx = font7
                    dx_mm = dx_m + font5
                    if symbol[-1].lower() in (
                        "l",
                        "i",
                        "r",
                        "t",
                    ):  # Heuristic for narrow last letters
                        rx = font6
                        ax = font25
                    else:
                        rx = font7
                        ax = font15
                    mask["center"].append(
                        f'          <ellipse cx="{x - ax:.2f}" cy="{y:.2f}" rx="{rx}" ry="{font4}"/>'
                    )
                else:
                    if symbol == "I":  # Special case for 'I'
                        dx = font15
                        dx_mm = dx_m
                    else:  # Single character
                        dx = font4
                        dx_mm = dx_m + font2
                    mask["center"].append(
                        f'          <circle cx="{x:.2f}" cy="{y:.2f}" r="{font4:.2f}"/>'
                    )

                svg.append(
                    f'        <text x="{x:.2f}" y="{y:.2f}" dx="-{dx:.2f}" dy="{font4:.2f}" '
                    f'font-size="{font_size:.2f}">{symbol}{h_str}{span}</text>'
                )
                mask["symbols"].append(
                    f'            <text x="{x:.2f}" y="{y:.2f}" dx="-{dx:.2f}" '
                    f'dy="{font4:.2f}">{symbol}{h_str}</text>'
                )
                if span:
                    mask["span"].append(
                        f'            <text x="{x:.2f}" y="{y:.2f}" dx="-{dx:.2f}" dy="{font4:.2f}">'
                        f"{symbol}{h_str}{span}</text>"
                    )
                svg.append("      </g>")

                if mapping:
                    maps.append(
                        f'        <text x="{x:.2f}" y="{y:.2f}" dx="-{dx_mm:.2f}" dy="{dy_m + font3:.2f}" '
                        f'text-anchor="end">{n}</text>'
                    )
                    mask["aam"].append(
                        f'            <text x="{x:.2f}" y="{y:.2f}" dx="-{dx_mm:.2f}" '
                        f'dy="{dy_m + font3:.2f}" text-anchor="end">{n}</text>'
                    )

            elif mapping:
                # Determine dx_mm for mapping based on symbol length even if atom itself isn't drawn
                if len(symbol) > 1:
                    dx_mm = dx_m + font5
                else:
                    dx_mm = dx_m + font2 if symbol != "I" else dx_m

                maps.append(
                    f'        <text x="{x:.2f}" y="{y:.2f}" dx="-{dx_mm:.2f}" dy="{dy_m:.2f}" '
                    f'text-anchor="end">{n}</text>'
                )
                mask["aam"].append(
                    f'            <text x="{x:.2f}" y="{y:.2f}" dx="-{dx_mm:.2f}" dy="{dy_m:.2f}" '
                    f'text-anchor="end">{n}</text>'
                )
        if others:
            svg.append(
                f'      <g font-family="{config["other_font_style"]}" fill="{other_fill}" '
                f'font-size="{other_size:.2f}">'
            )
            svg.extend(others)
            svg.append("      </g>")
        if mapping:
            svg.append(f'      <g fill="{map_fill}" font-size="{mapping_font:.2f}">')
            svg.extend(maps)
            svg.append("      </g>")
        return svg, mask
755
+
756
+
757
def depict_custom_reaction(reaction: ReactionContainer):
    """
    Depicts a ReactionContainer using custom atom rendering logic (replace At to X).

    Temporarily rebinds the class of each molecule in the reaction to a
    dynamically created subclass whose MRO puts ``CustomDepictMolecule``
    first, so its ``_render_atoms`` (which uses ``atom.symbol``) takes
    precedence during ``depict()``. The original classes are restored in a
    ``finally`` block even if depiction fails. Individual molecule
    depictions, the reaction arrow and the "+" signs are then combined
    into a single SVG document.

    Args:
        reaction (ReactionContainer): The reaction to depict.

    Returns:
        str: SVG markup of the reaction with custom atom rendering.
    """
    if not reaction._arrow:
        reaction.fix_positions()  # Ensure positions are calculated

    r_atoms = []
    r_bonds = []
    r_masks = []
    r_max_x = r_max_y = r_min_y = 0
    original_classes = {}  # Store original classes to restore later

    try:
        # Temporarily change the class of molecules to use the custom depiction
        for mol in reaction.molecules():
            if isinstance(mol, (MoleculeContainer, CGRContainer)):
                original_classes[mol] = mol.__class__
                custom_class_name = (
                    f"TempCustom_{mol.__class__.__name__}_{uuid4().hex}"  # Unique name
                )
                # Combine custom depiction with original class methods
                # Ensure the custom _render_atoms takes precedence
                new_bases = (CustomDepictMolecule,) + original_classes[mol].__bases__
                # Filter out DepictMolecule if it's already a base to avoid MRO issues
                new_bases = tuple(b for b in new_bases if b is not DepictMolecule)
                # If DepictMolecule wasn't a direct base, ensure its methods are accessible
                if CustomDepictMolecule not in original_classes[mol].__mro__:
                    # Prioritize CustomDepictMolecule's methods
                    new_bases = (CustomDepictMolecule, original_classes[mol])
                else:
                    # If DepictMolecule was a base, CustomDepictMolecule is already first
                    new_bases = (CustomDepictMolecule,) + tuple(
                        b
                        for b in original_classes[mol].__bases__
                        if b is not DepictMolecule
                    )

                # Create the temporary class
                mol.__class__ = type(custom_class_name, new_bases, {})

            # Depict using the (potentially) modified class
            atoms, bonds, masks, min_x, min_y, max_x, max_y = mol.depict(embedding=True)
            r_atoms.append(atoms)
            r_bonds.append(bonds)
            r_masks.append(masks)
            # track the overall bounding box across all molecules
            if max_x > r_max_x:
                r_max_x = max_x
            if max_y > r_max_y:
                r_max_y = max_y
            if min_y < r_min_y:
                r_min_y = min_y

    finally:
        # Restore original classes
        for mol, original_class in original_classes.items():
            mol.__class__ = original_class

    config = DepictMolecule._render_config  # Access via the imported class

    font_size = config["font_size"]
    font125 = 1.25 * font_size
    # overall canvas size with margins proportional to the font size
    width = r_max_x + 3.0 * font_size
    height = r_max_y - r_min_y + 2.5 * font_size
    viewbox_x = -font125
    viewbox_y = -r_max_y - font125

    svg = [
        f'<svg width="{width:.2f}cm" height="{height:.2f}cm" '
        f'viewBox="{viewbox_x:.2f} {viewbox_y:.2f} {width:.2f} '
        f'{height:.2f}" xmlns="http://www.w3.org/2000/svg" version="1.1">\n'
        '  <defs>\n    <marker id="arrow" markerWidth="10" markerHeight="10" '
        'refX="0" refY="3" orient="auto">\n      <path d="M0,0 L0,6 L9,3"/>\n    </marker>\n  </defs>\n'
        f'  <line x1="{reaction._arrow[0]:.2f}" y1="0" x2="{reaction._arrow[1]:.2f}" y2="0" '
        'fill="none" stroke="black" stroke-width=".04" marker-end="url(#arrow)"/>'
    ]

    sings_plus = reaction._signs
    if sings_plus:
        svg.append(f'  <g fill="none" stroke="black" stroke-width=".04">')
        # draw a "+" between adjacent molecules at each sign position
        for x in sings_plus:
            svg.append(
                f'    <line x1="{x + .35:.2f}" y1="0" x2="{x + .65:.2f}" y2="0"/>'
            )
            svg.append(
                f'    <line x1="{x + .5:.2f}" y1="0.15" x2="{x + .5:.2f}" y2="-0.15"/>'
            )
        svg.append("  </g>")

    for atoms, bonds, masks in zip(r_atoms, r_bonds, r_masks):
        # Use the static method from Depict directly
        svg.extend(
            Depict._graph_svg(atoms, bonds, masks, viewbox_x, viewbox_y, width, height)
        )
    svg.append("</svg>")
    return "\n".join(svg)
870
+
871
+
872
def remove_and_shift(nested_dict, to_remove):  # Under development
    """
    Removes specified inner keys from a nested dictionary and renumbers the rest.

    For each inner dictionary of ``nested_dict``, keys listed in ``to_remove``
    are dropped, and every surviving (integer) key is shifted down by the
    number of removed keys smaller than it, closing the gaps left by the
    removals. The input dictionary is not modified.

    Args:
        nested_dict (dict): The input nested dictionary (dict of dicts);
            inner keys must be mutually comparable (typically ints).
        to_remove (list): Keys to remove from the inner dictionaries; keys
            absent from an inner dict still contribute to the shift.

    Returns:
        dict: A new nested dictionary with the specified keys removed and
        the remaining inner keys renumbered.
    """
    from bisect import bisect_left

    # sort the removals once so the per-key shift is an O(log R) lookup
    # instead of rescanning the whole removal set for every key
    removed_sorted = sorted(set(to_remove))
    rem_set = set(removed_sorted)

    result = {}
    for outer_k, inner in nested_dict.items():
        new_inner = {}
        for old_k, v in inner.items():
            if old_k in rem_set:
                continue
            # bisect_left == number of removed keys strictly below old_k
            new_inner[old_k - bisect_left(removed_sorted, old_k)] = v
        result[outer_k] = new_inner
    return result
synplan/chem/reaction_rules/__init__.py ADDED
File without changes
synplan/chem/reaction_rules/extraction.py ADDED
@@ -0,0 +1,744 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Module containing functions for protocol of reaction rules extraction."""
2
+
3
+ import logging
4
+ import pickle
5
+ from collections import defaultdict
6
+ from itertools import islice
7
+ from os.path import splitext
8
+ from typing import Dict, List, Set, Tuple
9
+
10
+ import ray
11
+ from chython import smarts
12
+ from chython import QueryContainer as QueryContainerChython
13
+ from CGRtools.containers.cgr import CGRContainer
14
+ from CGRtools.containers.molecule import MoleculeContainer
15
+ from CGRtools.containers.query import QueryContainer
16
+ from CGRtools.containers.reaction import ReactionContainer
17
+ from CGRtools.exceptions import InvalidAromaticRing
18
+ from CGRtools.reactor import Reactor
19
+ from tqdm import tqdm
20
+
21
+ from synplan.chem.data.standardizing import RemoveReagentsStandardizer
22
+ from synplan.chem.utils import (
23
+ reverse_reaction,
24
+ cgrtools_to_chython_molecule,
25
+ chython_query_to_cgrtools,
26
+ )
27
+ from synplan.utils.config import RuleExtractionConfig
28
+ from synplan.utils.files import ReactionReader
29
+
30
+
31
def add_environment_atoms(
    cgr: CGRContainer, center_atoms: Set[int], environment_atom_count: int
) -> Set[int]:
    """
    Extends the reaction-center atom set with its surrounding environment.

    :param cgr: A condensed graph representation of a reaction
        (CGRContainer object).
    :param center_atoms: Atom ids of the reaction center.
    :param environment_atom_count: Depth of the environment to include
        around the reaction center: 0 keeps only the center itself, 1 adds
        the first shell of neighbours, and so on.

    :return: Atom ids of the center plus all environment atoms up to the
        requested depth; the original set is returned unchanged when the
        depth is 0.
    """
    if not environment_atom_count:
        # depth 0: the rule is restricted to the reaction centre itself
        return center_atoms

    # the augmented substructure contains the centre plus `deep` shells
    neighbourhood = cgr.augmented_substructure(
        center_atoms, deep=environment_atom_count
    )
    return center_atoms | set(neighbourhood)
58
+
59
+
60
def add_functional_groups(
    reaction: ReactionContainer,
    center_atoms: Set[int],
    func_groups_list: List[QueryContainerChython],
) -> Set[int]:
    """
    Augments the set of reaction rule atoms with matched functional groups.

    Each query in ``func_groups_list`` is searched in every molecule of the
    reaction; any match that overlaps the reaction center contributes all of
    its atoms to the rule.

    :param reaction: The reaction object (ReactionContainer) from which
        molecules are extracted.
    :param center_atoms: A set of atom ids corresponding to the center atoms
        of the reaction.
    :param func_groups_list: Functional-group query objects (chython
        QueryContainer) defining the substructures to include.

    :return: A set of atom ids including the center atoms plus the atoms of
        every matched functional group that intersects the center.

    NOTE(review): the queries in ``func_groups_list`` are remapped in place
    during matching and remapped back afterwards — this function is not safe
    to call concurrently with shared query objects.
    """

    rule_atoms = center_atoms.copy()
    # iterate over each molecule in the reaction
    for molecule in reaction.molecules():
        # matching is done in chython space, so convert the CGRtools molecule
        molecule_chython = cgrtools_to_chython_molecule(molecule)
        # for each functional group specified in the list
        for func_group in func_groups_list:
            # find mappings of the functional group in the molecule
            for mapping in func_group.get_mapping(molecule_chython):
                # remap the functional group based on the found mapping
                func_group.remap(mapping)
                # if the functional group intersects with center atoms, include it
                if set(func_group.atoms_numbers) & center_atoms:
                    rule_atoms |= set(func_group.atoms_numbers)
                # reset the mapping to its original state for the next iteration
                func_group.remap({v: k for k, v in mapping.items()})
    return rule_atoms
98
+
99
+
100
def add_ring_structures(cgr: CGRContainer, rule_atoms: Set[int]) -> Set[int]:
    """
    Absorbs whole rings into the rule atoms when they touch the reaction center.

    Every ring of the CGR's smallest set of smallest rings that shares at
    least one atom with ``rule_atoms`` is included in full.

    :param cgr: A condensed graph representation of a reaction
        (CGRContainer object).
    :param rule_atoms: Atom ids of the reaction center; this set is updated
        in place and also returned.

    :return: The rule atoms together with all intersecting ring atoms.
    """
    for ring in cgr.sssr:
        ring_atoms = set(ring)
        # a ring that shares any atom with the rule is absorbed entirely
        if ring_atoms & rule_atoms:
            rule_atoms |= ring_atoms
    return rule_atoms
119
+
120
+
121
def add_leaving_incoming_groups(
    reaction: ReactionContainer,
    rule_atoms: Set[int],
    keep_leaving_groups: bool,
    keep_incoming_groups: bool,
) -> Tuple[Set[int], Dict[str, Set]]:
    """
    Identifies and includes leaving and incoming groups in the rule atoms based
    on the specified flags.

    :param reaction: The reaction object (ReactionContainer) from which leaving
        and incoming groups are extracted.
    :param rule_atoms: A set of atom ids corresponding to the current rule
        atoms. The set is not modified (the previous implementation mutated it
        in place).
    :param keep_leaving_groups: Whether to include leaving groups (reactant
        atoms absent from the products) in the rule.
    :param keep_incoming_groups: Whether to include incoming groups (product
        atoms absent from the reactants) in the rule.

    :return: A tuple of the extended set of rule atoms and a metadata dict with
        keys "leaving" and "incoming" listing only the atoms newly added here.
    """
    meta_debug = {"leaving": set(), "incoming": set()}

    # work on a copy so the caller's set is never mutated in place
    rule_atoms = set(rule_atoms)

    # collect all atoms appearing on each side of the reaction
    reactant_atoms = {atom for reactant in reaction.reactants for atom in reactant}
    product_atoms = {atom for product in reaction.products for atom in product}

    if keep_leaving_groups:
        # leaving groups: reactant atoms that do not appear in the products;
        # record only the ones not already part of the rule, then merge
        leaving_atoms = reactant_atoms - product_atoms
        meta_debug["leaving"] |= leaving_atoms - rule_atoms
        rule_atoms |= leaving_atoms

    if keep_incoming_groups:
        # incoming groups: product atoms that do not appear in the reactants
        incoming_atoms = product_atoms - reactant_atoms
        meta_debug["incoming"] |= incoming_atoms - rule_atoms
        rule_atoms |= incoming_atoms

    return rule_atoms, meta_debug
170
+
171
+
172
def clean_molecules(
    rule_molecules: List[MoleculeContainer],
    reaction_molecules: Tuple[MoleculeContainer],
    reaction_center_atoms: Set[int],
    atom_retention_details: Dict[str, Dict[str, bool]],
) -> List[QueryContainer]:
    """
    Cleans rule molecules by removing selected atom query marks according to the
    retention details provided.

    Each rule molecule is matched (by atom-number inclusion) against the first
    reaction molecule that contains it; the rule substructure is re-created as a
    query from that reaction molecule and then cleaned atom by atom.

    :param rule_molecules: Substructures representing the rule molecules.
    :param reaction_molecules: Molecule containers involved in the reaction.
    :param reaction_center_atoms: Atom ids belonging to the reaction center.
    :param atom_retention_details: A dictionary with two keys,
        "reaction_center" and "environment", each mapping to a dict of atom
        attribute names ("neighbors", "hybridization", "implicit_hydrogens",
        "ring_sizes") to booleans. True keeps the attribute, False removes it
        from the atom.

    :return: A list of QueryContainer objects representing the cleaned rule
        molecules.
    """
    cleaned_rule_molecules = []

    for rule_molecule in rule_molecules:
        for reaction_molecule in reaction_molecules:
            # the reaction molecule that fully contains the rule molecule's atoms
            if set(rule_molecule.atoms_numbers) <= set(reaction_molecule.atoms_numbers):
                # convert the whole reaction molecule to a query first so the
                # rule substructure inherits query marks from the full context
                query_reaction_molecule = reaction_molecule.substructure(
                    reaction_molecule, as_query=True
                )
                query_rule_molecule = query_reaction_molecule.substructure(
                    rule_molecule
                )

                # clean reaction center atoms
                if not all(
                    atom_retention_details["reaction_center"].values()
                ):  # if everything True, we keep all marks
                    local_reaction_center_atoms = (
                        set(rule_molecule.atoms_numbers) & reaction_center_atoms
                    )
                    for atom_number in local_reaction_center_atoms:
                        query_rule_molecule = clean_atom(
                            query_rule_molecule,
                            atom_retention_details["reaction_center"],
                            atom_number,
                        )

                # clean environment atoms
                if not all(
                    atom_retention_details["environment"].values()
                ):  # if everything True, we keep all marks
                    local_environment_atoms = (
                        set(rule_molecule.atoms_numbers) - reaction_center_atoms
                    )
                    for atom_number in local_environment_atoms:
                        query_rule_molecule = clean_atom(
                            query_rule_molecule,
                            atom_retention_details["environment"],
                            atom_number,
                        )

                cleaned_rule_molecules.append(query_rule_molecule)
                # only the first containing reaction molecule is used
                break

    return cleaned_rule_molecules
240
+
241
+
242
def clean_atom(
    query_molecule: QueryContainer,
    attributes_to_keep: Dict[str, bool],
    atom_number: int,
) -> QueryContainer:
    """
    Strips selected query marks from a single atom of a query molecule.

    :param query_molecule: The QueryContainer holding the atom to modify.
    :param attributes_to_keep: Mapping from attribute name to a boolean flag.
        Expected keys: "neighbors", "hybridization", "implicit_hydrogens",
        "ring_sizes". True retains the mark, False erases it (sets it to None)
        on the atom.
    :param atom_number: The number of the atom to modify.

    :return: The same query molecule, modified in place.
    """
    atom = query_molecule.atom(atom_number)

    # erase every mark whose retention flag is False
    for attribute in ("neighbors", "hybridization", "implicit_hydrogens", "ring_sizes"):
        if not attributes_to_keep[attribute]:
            setattr(atom, attribute, None)

    return query_molecule
275
+
276
+
277
def create_substructures_and_reagents(
    reaction: ReactionContainer,
    rule_atoms: Set[int],
    as_query_container: bool,
    keep_reagents: bool,
) -> Tuple[List[MoleculeContainer], List[MoleculeContainer], List]:
    """
    Builds reactant and product substructures restricted to the rule atoms and,
    optionally, collects the reaction's reagents.

    :param reaction: The reaction object (ReactionContainer) to extract
        substructures from.
    :param rule_atoms: Atom ids of the rule; only molecules containing at least
        one of these atoms contribute a substructure.
    :param as_query_container: If True, reagents are converted to query
        containers (reactant/product substructures are converted elsewhere).
    :param keep_reagents: If True, reagents are included in the result;
        otherwise an empty list is returned for them.

    :return: A tuple of (reactant substructures, product substructures,
        reagents).
    """

    def _rule_parts(molecules):
        # substructure of each molecule restricted to the rule atoms it contains
        parts = []
        for molecule in molecules:
            selected = rule_atoms.intersection(molecule.atoms_numbers)
            if selected:
                parts.append(molecule.substructure(selected))
        return parts

    reactant_substructures = _rule_parts(reaction.reactants)
    product_substructures = _rule_parts(reaction.products)

    if not keep_reagents:
        reagents = []
    elif as_query_container:
        reagents = [
            reagent.substructure(reagent, as_query=True)
            for reagent in reaction.reagents
        ]
    else:
        reagents = reaction.reagents

    return reactant_substructures, product_substructures, reagents
328
+
329
+
330
def assemble_final_rule(
    reactant_substructures: List[QueryContainer],
    product_substructures: List[QueryContainer],
    reagents: List,
    meta_debug: Dict[str, Set],
    keep_metadata: bool,
    reaction: ReactionContainer,
) -> ReactionContainer:
    """
    Assembles the final reaction rule from the provided substructures and
    metadata.

    :param reactant_substructures: Rule-relevant substructures of the
        reactants.
    :param product_substructures: Rule-relevant substructures of the products.
    :param reagents: Reagents to attach to the rule (possibly empty).
    :param meta_debug: Additional metadata about the reaction, such as leaving
        and incoming groups.
    :param keep_metadata: Whether to attach ``meta_debug``, the reaction's meta
        and its name to the rule.
    :param reaction: The original reaction object (ReactionContainer) the rule
        is created from.

    :return: A ReactionContainer representing the assembled rule with its cache
        flushed.
    """
    if keep_metadata:
        # merge on a fresh dict: the previous implementation aliased meta_debug
        # and mutated the caller's dictionary when merging reaction.meta into it
        rule_metadata = {**meta_debug, **reaction.meta}
    else:
        rule_metadata = {}

    rule = ReactionContainer(
        reactant_substructures, product_substructures, reagents, rule_metadata
    )

    if keep_metadata:
        rule.name = reaction.name

    # drop cached derived data so the rule is re-evaluated from its new content
    rule.flush_cache()
    return rule
378
+
379
+
380
def validate_rule(rule: ReactionContainer, reaction: ReactionContainer) -> bool:
    """
    Validates a reaction rule by re-applying it to the original reactants.

    A Reactor is built from the rule and applied to the reaction's reactants.
    Generated products that cannot be kekulized or that have valence errors are
    discarded; the surviving products are compared with the products of the
    original reaction. The rule is valid if some generated reaction reproduces
    the original product set with the same number of products.

    :param rule: The reaction rule (ReactionContainer) to be validated.
    :param reaction: The original reaction (ReactionContainer) against which
        the rule is validated.

    :return: True if the rule correctly regenerates the products from the
        reactants, False otherwise (including when the reactor raises
        KeyError or IndexError).
    """

    # create a reactor with the given rule
    reactor = Reactor(rule)
    try:
        for result_reaction in reactor(reaction.reactants):
            result_products = []
            for result_product in result_reaction.products:
                # validate a copy so the generated product itself stays intact
                tmp = result_product.copy()
                try:
                    tmp.kekule()
                    # a truthy check_valence() result reports valence errors:
                    # discard this product
                    if tmp.check_valence():
                        continue
                except InvalidAromaticRing:
                    # non-kekulizable aromatic system: discard this product
                    continue
                result_products.append(result_product)
            # require equality both as sets and in count of products
            if set(reaction.products) == set(result_products) and len(
                reaction.products
            ) == len(result_products):
                return True

    except (KeyError, IndexError):
        # KeyError - iteration over reactor is finished and products are different from the original reaction
        # IndexError - mistake in __contract_ions, possibly problems with charges in reaction rule
        return False

    return False
428
+
429
+
430
def create_rule(
    config: RuleExtractionConfig, reaction: ReactionContainer
) -> ReactionContainer:
    """
    Creates a reaction rule from a given reaction based on the specified
    configuration. The function handles the inclusion of environmental atoms,
    functional groups, ring structures, and leaving/incoming groups, builds
    substructures for reactants, products and reagents, cleans atom query marks
    if required, and optionally validates the rule with a reactor.

    :param config: Configuration (RuleExtractionConfig) controlling rule
        creation: environment atom count, inclusion of functional groups,
        rings, leaving/incoming groups, query conversion, metadata handling,
        rule reversal and reactor validation.
    :param reaction: The reaction object (ReactionContainer) to create the rule
        from.

    :return: A ReactionContainer representing the extracted reaction rule; when
        reactor validation is enabled, its meta contains a
        "reactor_validation" entry ("passed" or "failed").
    """

    # 1. create reaction CGR (~ composes the reaction into a CGRContainer)
    cgr = ~reaction
    center_atoms = set(cgr.center_atoms)

    # 2. add atoms of reaction environment based on config settings
    center_atoms = add_environment_atoms(
        cgr, center_atoms, config.environment_atom_count
    )

    # 3. include functional groups in the rule if specified in config
    if config.include_func_groups and config.func_groups_list:
        rule_atoms = add_functional_groups(
            reaction, center_atoms, config.func_groups_list
        )
    else:
        rule_atoms = center_atoms.copy()

    # 4. include ring structures in the rule if specified in config
    if config.include_rings:
        rule_atoms = add_ring_structures(cgr, rule_atoms)

    # 5. add leaving and incoming groups to the rule based on config settings
    rule_atoms, meta_debug = add_leaving_incoming_groups(
        reaction, rule_atoms, config.keep_leaving_groups, config.keep_incoming_groups
    )

    # 6. create substructures for reactants, products, and reagents
    reactant_substructures, product_substructures, reagents = (
        create_substructures_and_reagents(
            reaction, rule_atoms, config.as_query_container, config.keep_reagents
        )
    )
    # 7. clean atom marks in the molecules if they are being converted to query containers
    if config.as_query_container:
        reactant_substructures = clean_molecules(
            reactant_substructures,
            reaction.reactants,
            center_atoms,
            config.atom_info_retention,
        )

        product_substructures = clean_molecules(
            product_substructures,
            reaction.products,
            center_atoms,
            config.atom_info_retention,
        )

    # 8. assemble the final rule including metadata if specified
    rule = assemble_final_rule(
        reactant_substructures,
        product_substructures,
        reagents,
        meta_debug,
        config.keep_metadata,
        reaction,
    )

    # 9. reverse extracted reaction rule and reaction (retro direction);
    # the reaction is reversed too so validation below stays consistent
    if config.reverse_rule:
        rule = reverse_reaction(rule)
        reaction = reverse_reaction(reaction)

    # 10. validate the rule using a reactor if validation is enabled in config
    if config.reactor_validation:
        if validate_rule(rule, reaction):
            rule.meta["reactor_validation"] = "passed"
        else:
            rule.meta["reactor_validation"] = "failed"

    return rule
525
+
526
+
527
def extract_rules(
    config: RuleExtractionConfig, reaction: ReactionContainer
) -> List[ReactionContainer]:
    """
    Extracts reaction rules from a given reaction based on the specified
    configuration.

    :param config: Configuration (RuleExtractionConfig) for rule extraction:
        multicenter handling, functional groups, rings, leaving/incoming
        groups, etc.
    :param reaction: The reaction object (ReactionContainer) to extract rules
        from.

    :return: A list of ReactionContainer rules. With
        ``config.multicenter_rules`` a single rule covering all reaction
        centers is returned; otherwise one rule per distinct reaction center,
        limited to the first 15 centers.
    """
    # reagents are removed here; they are re-attached later only if they are
    # part of the reaction rule specification
    reaction = RemoveReagentsStandardizer()(reaction)

    if config.multicenter_rules:
        # a single rule encompassing all reaction centers
        return [create_rule(config, reaction)]

    # one rule per reaction center (at most 15), with duplicates collapsed
    unique_rules = {
        create_rule(config, single_center_reaction)
        for single_center_reaction in islice(reaction.enumerate_centers(), 15)
    }
    return list(unique_rules)
564
+
565
+
566
@ray.remote
def process_reaction_batch(
    batch: List[Tuple[int, ReactionContainer]], config: RuleExtractionConfig
) -> List[Tuple[int, List[ReactionContainer]]]:
    """
    Ray remote task: extracts reaction rules for every reaction in a batch.

    Each reaction is paired with its index in the larger dataset so results can
    be traced back after parallel processing. Reactions that fail extraction
    are logged at debug level and silently skipped, so one bad reaction cannot
    abort the whole batch.

    :param batch: A list of (index, ReactionContainer) tuples; the index keeps
        track of the reaction's position in the dataset.
    :param config: Settings (RuleExtractionConfig) for the rule extraction
        process.

    :return: A list of (index, extracted rules) tuples, one entry per reaction
        that was processed successfully.
    """

    extracted_rules_list = []
    for index, reaction in batch:
        try:
            extracted_rules = extract_rules(config, reaction)
            extracted_rules_list.append((index, extracted_rules))
        except Exception as e:
            # best-effort batch processing: record the failure and move on
            logging.debug(e)
            continue
    return extracted_rules_list
600
+
601
+
602
def process_completed_batch(
    futures: Dict,
    rules_statistics: Dict,
) -> None:
    """
    Blocks until one pending Ray batch finishes and merges its results into the
    rule statistics.

    For every rule extracted from the completed batch, the index of the
    reaction that produced it is appended to that rule's statistics list. The
    first time a rule is encountered, the originating reaction index is also
    stored in the rule's metadata. The completed future is removed from
    ``futures``.

    :param futures: A dictionary whose keys are pending Ray object refs of
        batch-processing tasks.
    :param rules_statistics: Mapping from rule to the list of reaction indices
        that produced it (a defaultdict(list) in the caller).
    :return: None
    """

    ready_id, running_id = ray.wait(list(futures.keys()), num_returns=1)
    completed_batch = ray.get(ready_id[0])
    for index, extracted_rules in completed_batch:
        for rule in extracted_rules:
            prev_stats_len = len(rules_statistics)
            rules_statistics[rule].append(index)
            # growth of the mapping means this rule was seen for the first time
            if len(rules_statistics) != prev_stats_len:
                rule.meta["first_reaction_index"] = index

    del futures[ready_id[0]]
629
+
630
+
631
def sort_rules(
    rules_stats: Dict, min_popularity: int, single_reactant_only: bool
) -> List[Tuple[ReactionContainer, List[int]]]:
    """
    Filters and sorts reaction rules by popularity.

    A rule is kept only if it was applied at least ``min_popularity`` times,
    passed reactor validation, and — when ``single_reactant_only`` is set —
    has exactly one reactant. The surviving rules are ordered by decreasing
    number of applications.

    :param rules_stats: Mapping from rule to the list of reaction indices where
        the rule was applied.
    :param min_popularity: Minimum number of applications for a rule to be
        kept.
    :param single_reactant_only: If True, keep only rules with a single
        reactant.

    :return: A list of (rule, indices) tuples sorted by descending popularity.
    """

    def _is_selected(item):
        # short-circuit in the same order as the filtering criteria above
        rule, applications = item
        if len(applications) < min_popularity:
            return False
        if rule.meta["reactor_validation"] != "passed":
            return False
        return not single_reactant_only or len(rule.reactants) == 1

    selected = [item for item in rules_stats.items() if _is_selected(item)]
    selected.sort(key=lambda item: len(item[1]), reverse=True)
    return selected
666
+
667
+
668
def extract_rules_from_reactions(
    config: RuleExtractionConfig,
    reaction_data_path: str,
    reaction_rules_path: str,
    num_cpus: int,
    batch_size: int,
) -> None:
    """
    Extracts reaction rules from a reaction database using Ray for parallel
    batch processing. Reactions are streamed from the input file, grouped into
    batches and dispatched to remote workers; completed batches feed the rule
    statistics. Finally the rules are filtered/sorted by popularity and dumped
    to a pickle file.

    :param config: Configuration (RuleExtractionConfig) for rule extraction,
        including min_popularity and single_reactant_only used for sorting.
    :param reaction_data_path: Path to the file containing the reaction
        database.
    :param reaction_rules_path: Output path; its extension is replaced with
        ".pickle" for the statistics dump.
    :param num_cpus: Number of CPU cores to use for processing.
    :param batch_size: Number of reactions to process in each batch.
    :return: None
    """

    ray.init(num_cpus=num_cpus, ignore_reinit_error=True, logging_level=logging.ERROR)

    # strip the extension; ".pickle" is appended for the output file below
    reaction_rules_path, _ = splitext(reaction_rules_path)
    with ReactionReader(reaction_data_path) as reactions:

        futures = {}
        batch = []
        # bound the number of in-flight batches to the number of workers
        max_concurrent_batches = num_cpus
        extracted_rules_and_statistics = defaultdict(list)

        for index, reaction in tqdm(
            enumerate(reactions),
            desc="Number of reactions processed: ",
            bar_format="{desc}{n} [{elapsed}]",
        ):

            # reaction ready to use
            batch.append((index, reaction))
            if len(batch) == batch_size:
                future = process_reaction_batch.remote(batch, config)

                futures[future] = None
                batch = []

                # throttle submission: drain completions while the pipeline is full
                while len(futures) >= max_concurrent_batches:
                    process_completed_batch(
                        futures,
                        extracted_rules_and_statistics,
                    )

        # submit the final, possibly smaller-than-batch_size batch
        if batch:
            future = process_reaction_batch.remote(batch, config)
            futures[future] = None

        # drain all remaining in-flight batches
        while futures:
            process_completed_batch(
                futures,
                extracted_rules_and_statistics,
            )

        sorted_rules = sort_rules(
            extracted_rules_and_statistics,
            min_popularity=config.min_popularity,
            single_reactant_only=config.single_reactant_only,
        )

        ray.shutdown()

    with open(f"{reaction_rules_path}.pickle", "wb") as statistics_file:
        pickle.dump(sorted_rules, statistics_file)

    print(f"Number of extracted reaction rules: {len(sorted_rules)}")
synplan/chem/reaction_rules/manual/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
# Aggregates the manually curated ("hardcoded") reaction rules from the two
# rule modules of this package.
from .decompositions import rules as d_rules
from .transformations import rules as t_rules

# transformation rules come first, then decomposition rules
hardcoded_rules = t_rules + d_rules

__all__ = ["hardcoded_rules"]
synplan/chem/reaction_rules/manual/decompositions.py ADDED
@@ -0,0 +1,413 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Module containing hardcoded decomposition reaction rules."""
2
+
3
+ from CGRtools import QueryContainer, ReactionContainer
4
+ from CGRtools.periodictable import ListElement
5
+
6
+ rules = []
7
+
8
+
9
def prepare():
    """Allocates one reactant query and two product queries, registers the
    corresponding one-reactant/two-product rule in the module-level ``rules``
    list, and returns the three empty containers for population by the
    caller."""
    reactant_query = QueryContainer()
    first_product = QueryContainer()
    second_product = QueryContainer()
    rules.append(ReactionContainer((reactant_query,), (first_product, second_product)))

    return reactant_query, first_product, second_product
18
+
19
+
20
# Each section below populates one decomposition rule created by prepare():
# q is the single reactant query, p1/p2 are the two product queries. Atom map
# numbers are shared across q/p1/p2 to align corresponding atoms.
# NOTE(review): bond order 4 appears to encode aromatic bonds (CGRtools
# convention) — confirm against the CGRtools documentation.

# R-amide/ester formation
# [C](-[N,O;D23;Zs])(-[C])=[O]>>[A].[C]-[C](-[O])=[O]
q, p1, p2 = prepare()
q.add_atom("C")
q.add_atom("C")
q.add_atom("O")
q.add_atom(ListElement(["N", "O"]), hybridization=1, neighbors=(2, 3))
q.add_bond(1, 2, 1)
q.add_bond(2, 3, 2)
q.add_bond(2, 4, 1)

p1.add_atom("C")
p1.add_atom("C")
p1.add_atom("O")
p1.add_atom("O", _map=5)
p1.add_bond(1, 2, 1)
p1.add_bond(2, 3, 2)
p1.add_bond(2, 5, 1)

p2.add_atom("A", _map=4)

# acyl group addition with aromatic carbon's case (Friedel-Crafts)
# [C;Za]-[C](-[C])=[O]>>[C].[C]-[C](-[Cl])=[O]
q, p1, p2 = prepare()
q.add_atom("C")
q.add_atom("C")
q.add_atom("O")
q.add_atom("C", hybridization=4)
q.add_bond(1, 2, 1)
q.add_bond(2, 3, 2)
q.add_bond(2, 4, 1)

p1.add_atom("C")
p1.add_atom("C")
p1.add_atom("O")
p1.add_atom("Cl", _map=5)
p1.add_bond(1, 2, 1)
p1.add_bond(2, 3, 2)
p1.add_bond(2, 5, 1)

p2.add_atom("C", _map=4)

# Williamson reaction
# [C;Za]-[O]-[C;Zs;W0]>>[C]-[Br].[C]-[O]
q, p1, p2 = prepare()
q.add_atom("C", hybridization=4)
q.add_atom("O")
q.add_atom("C", hybridization=1, heteroatoms=1)
q.add_bond(1, 2, 1)
q.add_bond(2, 3, 1)

p1.add_atom("C")
p1.add_atom("O")
p1.add_bond(1, 2, 1)

p2.add_atom("C", _map=3)
p2.add_atom("Br")
p2.add_bond(3, 4, 1)

# Buchwald-Hartwig amination
# [N;D23;Zs;W0]-[C;Za]>>[C]-[Br].[N]
q, p1, p2 = prepare()
q.add_atom("N", heteroatoms=0, hybridization=1, neighbors=(2, 3))
q.add_atom("C", hybridization=4)
q.add_bond(1, 2, 1)

p1.add_atom("C", _map=2)
p1.add_atom("Br")
p1.add_bond(2, 3, 1)

p2.add_atom("N")

# imidazole imine atom's alkylation
# [C;r5](:[N;r5]-[C;Zs;W1]):[N;D2;r5]>>[C]-[Br].[N]:[C]:[N]
q, p1, p2 = prepare()
q.add_atom("N", rings_sizes=5)
q.add_atom("C", rings_sizes=5)
q.add_atom("N", rings_sizes=5, neighbors=2)
q.add_atom("C", hybridization=1, heteroatoms=(1, 2))
q.add_bond(1, 2, 4)
q.add_bond(2, 3, 4)
q.add_bond(1, 4, 1)

p1.add_atom("N")
p1.add_atom("C")
p1.add_atom("N")
p1.add_bond(1, 2, 4)
p1.add_bond(2, 3, 4)

p2.add_atom("C", _map=4)
p2.add_atom("Br")
p2.add_bond(4, 5, 1)

# Knoevenagel condensation (nitrile and carboxyl case)
# [C]=[C](-[C]#[N])-[C](-[O])=[O]>>[C]=[O].[C](-[C]#[N])-[C](-[O])=[O]
q, p1, p2 = prepare()
q.add_atom("C")
q.add_atom("C")
q.add_atom("C")
q.add_atom("N")
q.add_atom("C")
q.add_atom("O")
q.add_atom("O")
q.add_bond(1, 2, 2)
q.add_bond(2, 3, 1)
q.add_bond(3, 4, 3)
q.add_bond(2, 5, 1)
q.add_bond(5, 6, 2)
q.add_bond(5, 7, 1)

p1.add_atom("C", _map=2)
p1.add_atom("C")
p1.add_atom("N")
p1.add_atom("C")
p1.add_atom("O")
p1.add_atom("O")
p1.add_bond(2, 3, 1)
p1.add_bond(3, 4, 3)
p1.add_bond(2, 5, 1)
p1.add_bond(5, 6, 2)
p1.add_bond(5, 7, 1)

p2.add_atom("C", _map=1)
p2.add_atom("O", _map=8)
p2.add_bond(1, 8, 2)

# Knoevenagel condensation (double nitrile case)
# [C]=[C](-[C]#[N])-[C]#[N]>>[C]=[O].[C](-[C]#[N])-[C]#[N]
q, p1, p2 = prepare()
q.add_atom("C")
q.add_atom("C")
q.add_atom("C")
q.add_atom("N")
q.add_atom("C")
q.add_atom("N")
q.add_bond(1, 2, 2)
q.add_bond(2, 3, 1)
q.add_bond(3, 4, 3)
q.add_bond(2, 5, 1)
q.add_bond(5, 6, 3)

p1.add_atom("C", _map=2)
p1.add_atom("C")
p1.add_atom("N")
p1.add_atom("C")
p1.add_atom("N")
p1.add_bond(2, 3, 1)
p1.add_bond(3, 4, 3)
p1.add_bond(2, 5, 1)
p1.add_bond(5, 6, 3)

p2.add_atom("C", _map=1)
p2.add_atom("O", _map=8)
p2.add_bond(1, 8, 2)

# Knoevenagel condensation (double carboxyl case)
# [C]=[C](-[C](-[O])=[O])-[C](-[O])=[O]>>[C]=[O].[C](-[C](-[O])=[O])-[C](-[O])=[O]
q, p1, p2 = prepare()
q.add_atom("C")
q.add_atom("C")
q.add_atom("C")
q.add_atom("O")
q.add_atom("O")
q.add_atom("C")
q.add_atom("O")
q.add_atom("O")
q.add_bond(1, 2, 2)
q.add_bond(2, 3, 1)
q.add_bond(3, 4, 2)
q.add_bond(3, 5, 1)
q.add_bond(2, 6, 1)
q.add_bond(6, 7, 2)
q.add_bond(6, 8, 1)

p1.add_atom("C", _map=2)
p1.add_atom("C")
p1.add_atom("O")
p1.add_atom("O")
p1.add_atom("C")
p1.add_atom("O")
p1.add_atom("O")
p1.add_bond(2, 3, 1)
p1.add_bond(3, 4, 2)
p1.add_bond(3, 5, 1)
p1.add_bond(2, 6, 1)
p1.add_bond(6, 7, 2)
p1.add_bond(6, 8, 1)

p2.add_atom("C", _map=1)
p2.add_atom("O", _map=9)
p2.add_bond(1, 9, 2)
211
+
212
+ # heterocyclization with guanidine
213
+ # [c]((-[N;W0;Zs])@[n]@[c](-[N;D1])@[c;W0])@[n]@[c]-[O; D1]>>[C](-[N])(=[N])-[N].[C](#[N])-[C]-[C](-[O])=[O]
214
+ q, p1, p2 = prepare()
215
+ q.add_atom("C")
216
+ q.add_atom("N", heteroatoms=0, hybridization=1)
217
+ q.add_atom("N")
218
+ q.add_atom("C")
219
+ q.add_atom("N", neighbors=1)
220
+ q.add_atom("C", heteroatoms=0)
221
+ q.add_atom("N")
222
+ q.add_atom("C")
223
+ q.add_atom("O", neighbors=1)
224
+ q.add_bond(1, 2, 1)
225
+ q.add_bond(1, 3, 4)
226
+ q.add_bond(3, 4, 4)
227
+ q.add_bond(4, 5, 1)
228
+ q.add_bond(4, 6, 4)
229
+ q.add_bond(1, 7, 4)
230
+ q.add_bond(7, 8, 4)
231
+ q.add_bond(8, 9, 1)
232
+
233
+ p1.add_atom("C")
234
+ p1.add_atom("N")
235
+ p1.add_atom("N")
236
+ p1.add_atom("N", _map=7)
237
+ p1.add_bond(1, 2, 1)
238
+ p1.add_bond(1, 3, 2)
239
+ p1.add_bond(1, 7, 1)
240
+
241
+ p2.add_atom("C", _map=4)
242
+ p2.add_atom("N")
243
+ p2.add_atom("C")
244
+ p2.add_atom("C", _map=8)
245
+ p2.add_atom("O", _map=9)
246
+ p2.add_atom("O")
247
+ p2.add_bond(4, 5, 3)
248
+ p2.add_bond(4, 6, 1)
249
+ p2.add_bond(6, 8, 1)
250
+ p2.add_bond(8, 9, 2)
251
+ p2.add_bond(8, 10, 1)
252
+
253
+ # alkylation of amine
254
+ # [C]-[N]-[C]>>[C]-[N].[C]-[Br]
255
+ q, p1, p2 = prepare()
256
+ q.add_atom("C")
257
+ q.add_atom("N")
258
+ q.add_atom("C")
259
+ q.add_atom("C")
260
+ q.add_bond(1, 2, 1)
261
+ q.add_bond(2, 3, 1)
262
+ q.add_bond(2, 4, 1)
263
+
264
+ p1.add_atom("C")
265
+ p1.add_atom("N")
266
+ p1.add_atom("C")
267
+ p1.add_bond(1, 2, 1)
268
+ p1.add_bond(2, 3, 1)
269
+
270
+ p2.add_atom("C", _map=4)
271
+ p2.add_atom("Cl")
272
+ p2.add_bond(4, 5, 1)
273
+
274
+ # Synthesis of guanidines
275
+ #
276
+ q, p1, p2 = prepare()
277
+ q.add_atom("N")
278
+ q.add_atom("C")
279
+ q.add_atom("N", hybridization=1)
280
+ q.add_atom("N", hybridization=1)
281
+ q.add_bond(1, 2, 2)
282
+ q.add_bond(2, 3, 1)
283
+ q.add_bond(2, 4, 1)
284
+
285
+ p1.add_atom("N")
286
+ p1.add_atom("C")
287
+ p1.add_atom("N")
288
+ p1.add_bond(1, 2, 3)
289
+ p1.add_bond(2, 3, 1)
290
+
291
+ p2.add_atom("N", _map=4)
292
+
293
+ # Grignard reaction with nitrile
294
+ #
295
+ q, p1, p2 = prepare()
296
+ q.add_atom("C")
297
+ q.add_atom("C")
298
+ q.add_atom("O")
299
+ q.add_atom("C")
300
+ q.add_bond(1, 2, 1)
301
+ q.add_bond(2, 3, 2)
302
+ q.add_bond(2, 4, 1)
303
+
304
+ p1.add_atom("C")
305
+ p1.add_atom("C")
306
+ p1.add_atom("N")
307
+ p1.add_bond(1, 2, 1)
308
+ p1.add_bond(2, 3, 3)
309
+
310
+ p2.add_atom("C", _map=4)
311
+ p2.add_atom("Br")
312
+ p2.add_bond(4, 5, 1)
313
+
314
+ # Alkylation of alpha-carbon atom of nitrile
315
+ #
316
+ q, p1, p2 = prepare()
317
+ q.add_atom("N")
318
+ q.add_atom("C")
319
+ q.add_atom("C", neighbors=(3, 4))
320
+ q.add_atom("C", hybridization=1)
321
+ q.add_bond(1, 2, 3)
322
+ q.add_bond(2, 3, 1)
323
+ q.add_bond(3, 4, 1)
324
+
325
+ p1.add_atom("N")
326
+ p1.add_atom("C")
327
+ p1.add_atom("C")
328
+ p1.add_bond(1, 2, 3)
329
+ p1.add_bond(2, 3, 1)
330
+
331
+ p2.add_atom("C", _map=4)
332
+ p2.add_atom("Cl")
333
+ p2.add_bond(4, 5, 1)
334
+
335
+ # Gomberg-Bachmann reaction
336
+ #
337
+ q, p1, p2 = prepare()
338
+ q.add_atom("C", hybridization=4, heteroatoms=0)
339
+ q.add_atom("C", hybridization=4, heteroatoms=0)
340
+ q.add_bond(1, 2, 1)
341
+
342
+ p1.add_atom("C")
343
+ p1.add_atom("N", _map=3)
344
+ p1.add_bond(1, 3, 1)
345
+
346
+ p2.add_atom("C", _map=2)
347
+
348
+ # Cyclocondensation
349
+ #
350
+ q, p1, p2 = prepare()
351
+ q.add_atom("N", neighbors=2)
352
+ q.add_atom("C")
353
+ q.add_atom("C")
354
+ q.add_atom("C")
355
+ q.add_atom("N")
356
+ q.add_atom("C")
357
+ q.add_atom("C")
358
+ q.add_atom("O", neighbors=1)
359
+ q.add_bond(1, 2, 1)
360
+ q.add_bond(2, 3, 1)
361
+ q.add_bond(3, 4, 1)
362
+ q.add_bond(4, 5, 2)
363
+ q.add_bond(5, 6, 1)
364
+ q.add_bond(6, 7, 1)
365
+ q.add_bond(7, 8, 2)
366
+ q.add_bond(1, 7, 1)
367
+
368
+ p1.add_atom("N")
369
+ p1.add_atom("C")
370
+ p1.add_atom("C")
371
+ p1.add_atom("C")
372
+ p1.add_atom("O", _map=9)
373
+ p1.add_bond(1, 2, 1)
374
+ p1.add_bond(2, 3, 1)
375
+ p1.add_bond(3, 4, 1)
376
+ p1.add_bond(4, 9, 2)
377
+
378
+ p2.add_atom("N", _map=5)
379
+ p2.add_atom("C")
380
+ p2.add_atom("C")
381
+ p2.add_atom("O")
382
+ p2.add_atom("O", _map=10)
383
+ p2.add_bond(5, 6, 1)
384
+ p2.add_bond(6, 7, 1)
385
+ p2.add_bond(7, 8, 2)
386
+ p2.add_bond(7, 10, 1)
387
+
388
+ # heterocyclization dicarboxylic acids
389
+ #
390
+ q, p1, p2 = prepare()
391
+ q.add_atom("C", rings_sizes=(5, 6))
392
+ q.add_atom("O")
393
+ q.add_atom(ListElement(["O", "N"]))
394
+ q.add_atom("C", rings_sizes=(5, 6))
395
+ q.add_atom("O")
396
+ q.add_bond(1, 2, 2)
397
+ q.add_bond(1, 3, 1)
398
+ q.add_bond(3, 4, 1)
399
+ q.add_bond(4, 5, 2)
400
+
401
+ p1.add_atom("C")
402
+ p1.add_atom("O")
403
+ p1.add_atom("O", _map=6)
404
+ p1.add_bond(1, 2, 2)
405
+ p1.add_bond(1, 6, 1)
406
+
407
+ p2.add_atom("C", _map=4)
408
+ p2.add_atom("O")
409
+ p2.add_atom("O", _map=7)
410
+ p2.add_bond(4, 5, 2)
411
+ p2.add_bond(4, 7, 1)
412
+
413
+ __all__ = ["rules"]
synplan/chem/reaction_rules/manual/transformations.py ADDED
@@ -0,0 +1,532 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Module containing hardcoded transformation reaction rules."""
2
+
3
+ from CGRtools import QueryContainer, ReactionContainer
4
+ from CGRtools.periodictable import ListElement
5
+
6
+ rules = []
7
+
8
+
9
def prepare():
    """Create a new empty transformation rule and return its two halves.

    Instantiates two :class:`QueryContainer` objects — one for the reactant
    pattern and one for the product pattern — registers a
    :class:`ReactionContainer` built from them in the module-level ``rules``
    list, and returns both containers so the caller can populate them.

    Note: the original docstring claimed *three* query containers were
    created; this rule module creates exactly two per rule (the sibling
    ``decompositions`` module is the three-container variant).

    :return: Tuple ``(q_, p_)`` of empty query containers; the corresponding
        reaction rule is already appended to ``rules``.
    """
    q_ = QueryContainer()
    p_ = QueryContainer()
    rules.append(ReactionContainer((q_,), (p_,)))
    return q_, p_
16
+
17
+
18
+ # aryl nitro reduction
19
+ # [C;Za;W1]-[N;D1]>>[O-]-[N+](-[C])=[O]
20
+ q, p = prepare()
21
+ q.add_atom("N", neighbors=1)
22
+ q.add_atom("C", hybridization=4, heteroatoms=1)
23
+ q.add_bond(1, 2, 1)
24
+
25
+ p.add_atom("N", charge=1)
26
+ p.add_atom("C")
27
+ p.add_atom("O", charge=-1)
28
+ p.add_atom("O")
29
+ p.add_bond(1, 2, 1)
30
+ p.add_bond(1, 3, 1)
31
+ p.add_bond(1, 4, 2)
32
+
33
+ # aryl nitration
34
+ # [O-]-[N+](=[O])-[C;Za;W12]>>[C]
35
+ q, p = prepare()
36
+ q.add_atom("N", charge=1)
37
+ q.add_atom("C", hybridization=4, heteroatoms=(1, 2))
38
+ q.add_atom("O", charge=-1)
39
+ q.add_atom("O")
40
+ q.add_bond(1, 2, 1)
41
+ q.add_bond(1, 3, 1)
42
+ q.add_bond(1, 4, 2)
43
+
44
+ p.add_atom("C", _map=2)
45
+
46
+ # Beckmann rearrangement (oxime -> amide)
47
+ # [C]-[N;D2]-[C]=[O]>>[O]-[N]=[C]-[C]
48
+ q, p = prepare()
49
+ q.add_atom("C")
50
+ q.add_atom("N", neighbors=2)
51
+ q.add_atom("O")
52
+ q.add_atom("C")
53
+ q.add_bond(1, 2, 1)
54
+ q.add_bond(1, 3, 2)
55
+ q.add_bond(2, 4, 1)
56
+
57
+ p.add_atom("C")
58
+ p.add_atom("N")
59
+ p.add_atom("O")
60
+ p.add_atom("C")
61
+ p.add_bond(1, 2, 2)
62
+ p.add_bond(2, 3, 1)
63
+ p.add_bond(1, 4, 1)
64
+
65
+ # aldehydes or ketones into oxime/imine reaction
66
+ # [C;Zd;W1]=[N]>>[C]=[O]
67
+ q, p = prepare()
68
+ q.add_atom("C", hybridization=2, heteroatoms=1)
69
+ q.add_atom("N")
70
+ q.add_bond(1, 2, 2)
71
+
72
+ p.add_atom("C")
73
+ p.add_atom("O", _map=3)
74
+ p.add_bond(1, 3, 2)
75
+
76
+ # addition of halogen atom into phenol ring (ortho)
77
+ # [C](-[Cl,F,Br,I;D1]):[C]-[O,N;Zs]>>[C](-[A]):[C]
78
+ q, p = prepare()
79
+ q.add_atom(ListElement(["O", "N"]), hybridization=1)
80
+ q.add_atom("C")
81
+ q.add_atom("C")
82
+ q.add_atom(ListElement(["Cl", "F", "Br", "I"]), neighbors=1)
83
+ q.add_bond(1, 2, 1)
84
+ q.add_bond(2, 3, 4)
85
+ q.add_bond(3, 4, 1)
86
+
87
+ p.add_atom("A")
88
+ p.add_atom("C")
89
+ p.add_atom("C")
90
+ p.add_bond(1, 2, 1)
91
+ p.add_bond(2, 3, 4)
92
+
93
+ # addition of halogen atom into phenol ring (para)
94
+ # [C](:[C]:[C]:[C]-[O,N;Zs])-[Cl,F,Br,I;D1]>>[A]-[C]:[C]:[C]:[C]
95
+ q, p = prepare()
96
+ q.add_atom(ListElement(["O", "N"]), hybridization=1)
97
+ q.add_atom("C")
98
+ q.add_atom("C")
99
+ q.add_atom("C")
100
+ q.add_atom("C")
101
+ q.add_atom(ListElement(["Cl", "F", "Br", "I"]), neighbors=1)
102
+ q.add_bond(1, 2, 1)
103
+ q.add_bond(2, 3, 4)
104
+ q.add_bond(3, 4, 4)
105
+ q.add_bond(4, 5, 4)
106
+ q.add_bond(5, 6, 1)
107
+
108
+ p.add_atom("A")
109
+ p.add_atom("C")
110
+ p.add_atom("C")
111
+ p.add_atom("C")
112
+ p.add_atom("C")
113
+ p.add_bond(1, 2, 1)
114
+ p.add_bond(2, 3, 4)
115
+ p.add_bond(3, 4, 4)
116
+ p.add_bond(4, 5, 4)
117
+
118
+ # hard reduction of Ar-ketones
119
+ # [C;Za]-[C;D2;Zs;W0]>>[C]-[C]=[O]
120
+ q, p = prepare()
121
+ q.add_atom("C", hybridization=4)
122
+ q.add_atom("C", hybridization=1, neighbors=2, heteroatoms=0)
123
+ q.add_bond(1, 2, 1)
124
+
125
+ p.add_atom("C")
126
+ p.add_atom("C")
127
+ p.add_atom("O")
128
+ p.add_bond(1, 2, 1)
129
+ p.add_bond(2, 3, 2)
130
+
131
+ # reduction of alpha-hydroxy pyridine
132
+ # [C;W1]:[N;H0;r6]>>[C](:[N])-[O]
133
+ q, p = prepare()
134
+ q.add_atom("C", heteroatoms=1)
135
+ q.add_atom("N", rings_sizes=6, hydrogens=0)
136
+ q.add_bond(1, 2, 4)
137
+
138
+ p.add_atom("C")
139
+ p.add_atom("N")
140
+ p.add_atom("O")
141
+ p.add_bond(1, 2, 4)
142
+ p.add_bond(1, 3, 1)
143
+
144
+ # Reduction of alkene
145
+ # [C]-[C;D23;Zs;W0]-[C;D123;Zs;W0]>>[C](-[C])=[C]
146
+ q, p = prepare()
147
+ q.add_atom("C")
148
+ q.add_atom("C", heteroatoms=0, neighbors=(2, 3), hybridization=1)
149
+ q.add_atom("C", heteroatoms=0, neighbors=(1, 2, 3), hybridization=1)
150
+ q.add_bond(1, 2, 1)
151
+ q.add_bond(2, 3, 1)
152
+
153
+ p.add_atom("C")
154
+ p.add_atom("C")
155
+ p.add_atom("C")
156
+ p.add_bond(1, 2, 1)
157
+ p.add_bond(2, 3, 2)
158
+
159
+ # Kolbe-Schmitt reaction
160
+ # [C](:[C]-[O;D1])-[C](=[O])-[O;D1]>>[C](-[O]):[C]
161
+ q, p = prepare()
162
+ q.add_atom("O", neighbors=1)
163
+ q.add_atom("C")
164
+ q.add_atom("C")
165
+ q.add_atom("C")
166
+ q.add_atom("O", neighbors=1)
167
+ q.add_atom("O")
168
+ q.add_bond(1, 2, 1)
169
+ q.add_bond(2, 3, 4)
170
+ q.add_bond(3, 4, 1)
171
+ q.add_bond(4, 5, 1)
172
+ q.add_bond(4, 6, 2)
173
+
174
+ p.add_atom("O")
175
+ p.add_atom("C")
176
+ p.add_atom("C")
177
+ p.add_bond(1, 2, 1)
178
+ p.add_bond(2, 3, 4)
179
+
180
+ # reduction of carboxylic acid
181
+ # [O;D1]-[C;D2]-[C]>>[C]-[C](-[O])=[O]
182
+ q, p = prepare()
183
+ q.add_atom("C")
184
+ q.add_atom("C", neighbors=2)
185
+ q.add_atom("O", neighbors=1)
186
+ q.add_bond(1, 2, 1)
187
+ q.add_bond(2, 3, 1)
188
+
189
+ p.add_atom("C")
190
+ p.add_atom("C")
191
+ p.add_atom("O")
192
+ p.add_atom("O")
193
+ p.add_bond(1, 2, 1)
194
+ p.add_bond(2, 3, 1)
195
+ p.add_bond(2, 4, 2)
196
+
197
+ # halogenation of alcohols
198
+ # [C;Zs]-[Cl,Br;D1]>>[C]-[O]
199
+ q, p = prepare()
200
+ q.add_atom("C", hybridization=1, heteroatoms=1)
201
+ q.add_atom(ListElement(["Cl", "Br"]), neighbors=1)
202
+ q.add_bond(1, 2, 1)
203
+
204
+ p.add_atom("C")
205
+ p.add_atom("O", _map=3)
206
+ p.add_bond(1, 3, 1)
207
+
208
+ # Kolbe nitrilation
209
+ # [N]#[C]-[C;Zs;W0]>>[Br]-[C]
210
+ q, p = prepare()
211
+ q.add_atom("C", heteroatoms=0, hybridization=1)
212
+ q.add_atom("C")
213
+ q.add_atom("N")
214
+ q.add_bond(1, 2, 1)
215
+ q.add_bond(2, 3, 3)
216
+
217
+ p.add_atom("C")
218
+ p.add_atom("Br", _map=4)
219
+ p.add_bond(1, 4, 1)
220
+
221
+ # Nitrile hydrolysis
222
+ # [O;D1]-[C]=[O]>>[N]#[C]
223
+ q, p = prepare()
224
+ q.add_atom("C")
225
+ q.add_atom("O", neighbors=1)
226
+ q.add_atom("O")
227
+ q.add_bond(1, 2, 1)
228
+ q.add_bond(1, 3, 2)
229
+
230
+ p.add_atom("C")
231
+ p.add_atom("N", _map=4)
232
+ p.add_bond(1, 4, 3)
233
+
234
+ # sulfamidation
235
+ # [c]-[S](=[O])(=[O])-[N]>>[c]
236
+ q, p = prepare()
237
+ q.add_atom("C", hybridization=4)
238
+ q.add_atom("S")
239
+ q.add_atom("O")
240
+ q.add_atom("O")
241
+ q.add_atom("N", neighbors=1)
242
+ q.add_bond(1, 2, 1)
243
+ q.add_bond(2, 3, 2)
244
+ q.add_bond(2, 4, 2)
245
+ q.add_bond(2, 5, 1)
246
+
247
+ p.add_atom("C")
248
+
249
+ # Ring expansion rearrangement
250
+ #
251
+ q, p = prepare()
252
+ q.add_atom("C")
253
+ q.add_atom("N")
254
+ q.add_atom("C", rings_sizes=6)
255
+ q.add_atom("C")
256
+ q.add_atom("O")
257
+ q.add_atom("C")
258
+ q.add_atom("C")
259
+ q.add_bond(1, 2, 1)
260
+ q.add_bond(2, 3, 1)
261
+ q.add_bond(3, 4, 1)
262
+ q.add_bond(4, 5, 2)
263
+ q.add_bond(3, 6, 1)
264
+ q.add_bond(4, 7, 1)
265
+
266
+ p.add_atom("C")
267
+ p.add_atom("N")
268
+ p.add_atom("C")
269
+ p.add_atom("C")
270
+ p.add_atom("O")
271
+ p.add_atom("C")
272
+ p.add_atom("C")
273
+ p.add_bond(1, 2, 1)
274
+ p.add_bond(2, 3, 2)
275
+ p.add_bond(3, 4, 1)
276
+ p.add_bond(4, 5, 1)
277
+ p.add_bond(4, 6, 1)
278
+ p.add_bond(4, 7, 1)
279
+
280
+ # hydrolysis of bromide alkyl
281
+ #
282
+ q, p = prepare()
283
+ q.add_atom("C", hybridization=1)
284
+ q.add_atom("O", neighbors=1)
285
+ q.add_bond(1, 2, 1)
286
+
287
+ p.add_atom("C")
288
+ p.add_atom("Br")
289
+ p.add_bond(1, 2, 1)
290
+
291
+ # Condensation of ketones/aldehydes and amines into imines
292
+ #
293
+ q, p = prepare()
294
+ q.add_atom("N", neighbors=(1, 2))
295
+ q.add_atom("C", neighbors=(2, 3), heteroatoms=1)
296
+ q.add_bond(1, 2, 2)
297
+
298
+ p.add_atom("C", _map=2)
299
+ p.add_atom("O")
300
+ p.add_bond(2, 3, 2)
301
+
302
+ # Halogenation of alkanes
303
+ #
304
+ q, p = prepare()
305
+ q.add_atom("C", hybridization=1)
306
+ q.add_atom(ListElement(["F", "Cl", "Br"]))
307
+ q.add_bond(1, 2, 1)
308
+
309
+ p.add_atom("C")
310
+
311
+ # heterocyclization
312
+ #
313
+ q, p = prepare()
314
+ q.add_atom("N", heteroatoms=0, hybridization=1, neighbors=(2, 3))
315
+ q.add_atom("C", heteroatoms=2)
316
+ q.add_atom("N", heteroatoms=0, neighbors=2)
317
+ q.add_bond(1, 2, 1)
318
+ q.add_bond(2, 3, 2)
319
+
320
+ p.add_atom("N")
321
+ p.add_atom("C")
322
+ p.add_atom("N")
323
+ p.add_atom("O")
324
+ p.add_bond(1, 2, 1)
325
+ p.add_bond(2, 4, 2)
326
+
327
+ # Reduction of nitrile
328
+ #
329
+ q, p = prepare()
330
+ q.add_atom("N", neighbors=1)
331
+ q.add_atom("C")
332
+ q.add_atom("C", hybridization=1)
333
+ q.add_bond(1, 2, 1)
334
+ q.add_bond(2, 3, 1)
335
+
336
+ p.add_atom("N")
337
+ p.add_atom("C")
338
+ p.add_atom("C")
339
+ p.add_bond(1, 2, 3)
340
+ p.add_bond(2, 3, 1)
341
+
342
+ # SPECIAL CASE
343
+ # Reduction of nitrile into methylamine
344
+ #
345
+ q, p = prepare()
346
+ q.add_atom("C", neighbors=1)
347
+ q.add_atom("N", neighbors=2)
348
+ q.add_atom("C")
349
+ q.add_atom("C", hybridization=1)
350
+ q.add_bond(1, 2, 1)
351
+ q.add_bond(2, 3, 1)
352
+ q.add_bond(3, 4, 1)
353
+
354
+ p.add_atom("N", _map=2)
355
+ p.add_atom("C")
356
+ p.add_atom("C")
357
+ p.add_bond(2, 3, 3)
358
+ p.add_bond(3, 4, 1)
359
+
360
+ # methylation of amides
361
+ #
362
+ q, p = prepare()
363
+ q.add_atom("O")
364
+ q.add_atom("C")
365
+ q.add_atom("N")
366
+ q.add_atom("C", neighbors=1)
367
+ q.add_bond(1, 2, 2)
368
+ q.add_bond(2, 3, 1)
369
+ q.add_bond(3, 4, 1)
370
+
371
+ p.add_atom("O")
372
+ p.add_atom("C")
373
+ p.add_atom("N")
374
+ p.add_bond(1, 2, 2)
375
+ p.add_bond(2, 3, 1)
376
+
377
+ # hydrocyanation of alkenes
378
+ #
379
+ q, p = prepare()
380
+ q.add_atom("C", hybridization=1)
381
+ q.add_atom("C")
382
+ q.add_atom("C")
383
+ q.add_atom("N")
384
+ q.add_bond(1, 2, 1)
385
+ q.add_bond(2, 3, 1)
386
+ q.add_bond(3, 4, 3)
387
+
388
+ p.add_atom("C")
389
+ p.add_atom("C")
390
+ p.add_bond(1, 2, 2)
391
+
392
+ # decarboxylation (alpha atom of nitrile)
393
+ #
394
+ q, p = prepare()
395
+ q.add_atom("N")
396
+ q.add_atom("C")
397
+ q.add_atom("C", neighbors=2)
398
+ q.add_bond(1, 2, 3)
399
+ q.add_bond(2, 3, 1)
400
+
401
+ p.add_atom("N")
402
+ p.add_atom("C")
403
+ p.add_atom("C")
404
+ p.add_atom("C")
405
+ p.add_atom("O")
406
+ p.add_atom("O")
407
+ p.add_bond(1, 2, 3)
408
+ p.add_bond(2, 3, 1)
409
+ p.add_bond(3, 4, 1)
410
+ p.add_bond(4, 5, 2)
411
+ p.add_bond(4, 6, 1)
412
+
413
+ # Bischler-Napieralski reaction
414
+ #
415
+ q, p = prepare()
416
+ q.add_atom("C", rings_sizes=(6,))
417
+ q.add_atom("C", rings_sizes=(6,))
418
+ q.add_atom("N", rings_sizes=(6,), neighbors=2)
419
+ q.add_atom("C")
420
+ q.add_atom("C")
421
+ q.add_atom("C")
422
+ q.add_atom("O")
423
+ q.add_atom("O")
424
+ q.add_atom("C")
425
+ q.add_atom("O", neighbors=1)
426
+ q.add_bond(1, 2, 4)
427
+ q.add_bond(2, 3, 1)
428
+ q.add_bond(3, 4, 1)
429
+ q.add_bond(4, 5, 2)
430
+ q.add_bond(5, 6, 1)
431
+ q.add_bond(6, 7, 2)
432
+ q.add_bond(6, 8, 1)
433
+ q.add_bond(5, 9, 4)
434
+ q.add_bond(9, 10, 1)
435
+ q.add_bond(1, 9, 1)
436
+
437
+ p.add_atom("C")
438
+ p.add_atom("C")
439
+ p.add_atom("N")
440
+ p.add_atom("C")
441
+ p.add_atom("C")
442
+ p.add_atom("C")
443
+ p.add_atom("O")
444
+ p.add_atom("O")
445
+ p.add_atom("C")
446
+ p.add_atom("O")
447
+ p.add_atom("O")
448
+ p.add_bond(1, 2, 4)
449
+ p.add_bond(2, 3, 1)
450
+ p.add_bond(3, 4, 1)
451
+ p.add_bond(4, 5, 2)
452
+ p.add_bond(5, 6, 1)
453
+ p.add_bond(6, 7, 2)
454
+ p.add_bond(6, 8, 1)
455
+ p.add_bond(5, 9, 1)
456
+ p.add_bond(9, 10, 2)
457
+ p.add_bond(9, 11, 1)
458
+
459
+ # heterocyclization in Prins reaction
460
+ #
461
+ q, p = prepare()
462
+ q.add_atom("C")
463
+ q.add_atom("O")
464
+ q.add_atom("C")
465
+ q.add_atom(ListElement(["N", "O"]), neighbors=2)
466
+ q.add_atom("C")
467
+ q.add_atom("C")
468
+ q.add_bond(1, 2, 1)
469
+ q.add_bond(2, 3, 1)
470
+ q.add_bond(3, 4, 1)
471
+ q.add_bond(4, 5, 1)
472
+ q.add_bond(5, 6, 1)
473
+ q.add_bond(1, 6, 1)
474
+
475
+ p.add_atom("C")
476
+ p.add_atom("C", _map=5)
477
+ p.add_bond(1, 5, 2)
478
+
479
+ # recyclization of tetrahydropyran through an opening the ring and dehydration
480
+ #
481
+ q, p = prepare()
482
+ q.add_atom("C")
483
+ q.add_atom("C")
484
+ q.add_atom("C")
485
+ q.add_atom(ListElement(["N", "O"]))
486
+ q.add_atom("C")
487
+ q.add_atom("C")
488
+ q.add_bond(1, 2, 1)
489
+ q.add_bond(2, 3, 1)
490
+ q.add_bond(3, 4, 1)
491
+ q.add_bond(4, 5, 1)
492
+ q.add_bond(5, 6, 1)
493
+ q.add_bond(1, 6, 2)
494
+
495
+ p.add_atom("C")
496
+ p.add_atom("C")
497
+ p.add_atom("C")
498
+ p.add_atom("A")
499
+ p.add_atom("C")
500
+ p.add_atom("C")
501
+ p.add_atom("O")
502
+ p.add_bond(1, 2, 1)
503
+ p.add_bond(1, 7, 1)
504
+ p.add_bond(3, 7, 1)
505
+ p.add_bond(3, 4, 1)
506
+ p.add_bond(4, 5, 1)
507
+ p.add_bond(5, 6, 1)
508
+ p.add_bond(1, 6, 1)
509
+
510
+ # alkenes + h2o/hHal
511
+ #
512
+ q, p = prepare()
513
+ q.add_atom("C", hybridization=1)
514
+ q.add_atom("C", hybridization=1)
515
+ q.add_atom(ListElement(["O", "F", "Cl", "Br", "I"]), neighbors=1)
516
+ q.add_bond(1, 2, 1)
517
+ q.add_bond(2, 3, 1)
518
+
519
+ p.add_atom("C")
520
+ p.add_atom("C")
521
+ p.add_bond(1, 2, 2)
522
+
523
+ # methylation of dimethylamines
524
+ #
525
+ q, p = prepare()
526
+ q.add_atom("C", neighbors=1)
527
+ q.add_atom("N", neighbors=3)
528
+ q.add_bond(1, 2, 1)
529
+
530
+ p.add_atom("N", _map=2)
531
+
532
+ __all__ = ["rules"]
synplan/chem/utils.py ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Module containing additional functions needed in different reaction data processing
2
+ protocols."""
3
+
4
+ import logging
5
+ from typing import Iterable
6
+
7
+ from CGRtools.containers import (
8
+ CGRContainer,
9
+ MoleculeContainer,
10
+ QueryContainer,
11
+ ReactionContainer,
12
+ )
13
+ from CGRtools.exceptions import InvalidAromaticRing
14
+ from tqdm import tqdm
15
+
16
+ from synplan.chem import smiles_parser
17
+ from synplan.utils.files import MoleculeReader, MoleculeWriter
18
+
19
+ from chython import MoleculeContainer as MoleculeContainerChython
20
+
21
+
22
def mol_from_smiles(
    smiles: str,
    standardize: bool = True,
    clean_stereo: bool = True,
    clean2d: bool = True,
) -> MoleculeContainer:
    """Parse a SMILES string into a `MoleculeContainer`, optionally applying
    standardization, stereo removal, and 2D-coordinate cleanup.

    All cleanup steps are applied to a working copy; if CGRtools raises
    ``InvalidAromaticRing`` at any step, a warning is logged and the
    unmodified parsed molecule is returned instead.

    :param smiles: The SMILES string representing the molecule.
    :param standardize: Whether to canonicalize the molecule (default True).
    :param clean_stereo: Whether to remove stereo marks (default True).
    :param clean2d: Whether to recompute 2D coordinates (default True).
    :return: The processed molecule object.
    :raises ValueError: If the SMILES string could not be processed by CGRtools.
    """
    parsed = smiles_parser(smiles)

    if not isinstance(parsed, MoleculeContainer):
        raise ValueError("SMILES string was not processed by CGRtools")

    candidate = parsed.copy()
    try:
        if standardize:
            candidate.canonicalize()
        if clean_stereo:
            candidate.clean_stereo()
        if clean2d:
            candidate.clean2d()
    except InvalidAromaticRing:
        logging.warning(
            "CGRtools was not able to standardize molecule due to invalid aromatic ring"
        )
        return parsed
    return candidate
57
+
58
+
59
def query_to_mol(query: QueryContainer) -> MoleculeContainer:
    """Convert a QueryContainer into an equivalent MoleculeContainer.

    Atom numbers, element symbols, charges and radical flags are carried
    over, along with integer bond orders; query-only constraints are dropped.

    :param query: A QueryContainer object representing the query structure.
    :return: A MoleculeContainer object that replicates the structure of the query.
    """
    molecule = MoleculeContainer()
    for number, atom in query.atoms():
        molecule.add_atom(
            atom.atomic_symbol,
            number,
            charge=atom.charge,
            is_radical=atom.is_radical,
        )
    for left, right, bond in query.bonds():
        molecule.add_bond(left, right, int(bond))
    return molecule
73
+
74
+
75
def reaction_query_to_reaction(reaction_rule: ReactionContainer) -> ReactionContainer:
    """Convert a query-based reaction rule into a molecule-based reaction.

    Every reactant, product and reagent query is converted with
    ``query_to_mol``; the rule's metadata and name are preserved.

    :param reaction_rule: A ReactionContainer whose components are
        QueryContainer objects.
    :return: A ReactionContainer whose components are MoleculeContainer objects.
    """
    converted = ReactionContainer(
        [query_to_mol(component) for component in reaction_rule.reactants],
        [query_to_mol(component) for component in reaction_rule.products],
        [query_to_mol(component) for component in reaction_rule.reagents],
        reaction_rule.meta,
    )
    converted.name = reaction_rule.name
    return converted
92
+
93
+
94
def unite_molecules(molecules: Iterable[MoleculeContainer]) -> MoleculeContainer:
    """Merge several molecules into one (possibly disconnected) container.

    Starts from an empty MoleculeContainer and folds each input molecule in
    with ``union``, so the result holds every component of every input.

    :param molecules: Molecules to be united.
    :return: A single MoleculeContainer representing the union of all inputs.
    """
    combined = MoleculeContainer()
    for component in molecules:
        combined = combined.union(component)
    return combined
108
+
109
+
110
def safe_canonicalization(molecule: MoleculeContainer) -> MoleculeContainer:
    """Attempts to canonicalize a molecule, handling any exceptions. If the
    canonicalization process fails due to an InvalidAromaticRing exception, it safely
    returns the original molecule.

    :param molecule: The given molecule to be canonicalized.
    :return: The canonicalized molecule if successful, otherwise the original molecule.
    """
    # NOTE(review): this reorders the private atom mapping of the *input*
    # molecule in place (sorted by atom number) before copying, so the
    # argument is mutated even when canonicalization fails — confirm callers
    # are fine with that side effect.
    molecule._atoms = dict(sorted(molecule._atoms.items()))

    # Work on a copy so a partially-canonicalized structure is never returned.
    molecule_copy = molecule.copy()
    try:
        molecule_copy.canonicalize()
        molecule_copy.clean_stereo()
        return molecule_copy
    except InvalidAromaticRing:
        # Canonicalization failed: fall back to the (atom-sorted) original.
        return molecule
127
+
128
+
129
def standardize_building_blocks(input_file: str, output_file: str) -> str:
    """Standardizes custom building blocks.

    Reads molecules from ``input_file``, canonicalizes each one with
    ``safe_canonicalization`` (molecules that fail are skipped), and writes
    the survivors to ``output_file``.

    :param input_file: The path to the file that stores the original building blocks.
    :param output_file: The path to the file that will store the standardized building
        blocks.
    :return: The path to the file with standardized building blocks.
    :raises ValueError: If both arguments resolve to the same file.
    """
    from pathlib import Path  # local import: this module does not import pathlib

    # Compare resolved paths instead of raw strings so that equivalent
    # spellings (e.g. "./bb.smi" vs "bb.smi") are also rejected — reading and
    # writing the same file would destroy the input.
    if Path(input_file).resolve() == Path(output_file).resolve():
        raise ValueError("input_file name and output_file name cannot be the same.")

    with MoleculeReader(input_file) as inp_file, MoleculeWriter(
        output_file
    ) as out_file:
        for mol in tqdm(
            inp_file,
            desc="Number of building blocks processed: ",
            bar_format="{desc}{n} [{elapsed}]",
        ):
            try:
                mol = safe_canonicalization(mol)
            except Exception as e:
                # Best-effort pipeline: skip molecules that cannot be
                # standardized instead of aborting the whole file.
                logging.debug(e)
                continue
            out_file.write(mol)

    return output_file
156
+
157
+
158
def cgr_from_reaction_rule(reaction_rule: ReactionContainer) -> CGRContainer:
    """Compose a condensed graph of reaction (CGR) from the given rule.

    The query-based rule is first converted into a molecule-based reaction,
    then composed into a CGR with the ``~`` operator.

    :param reaction_rule: The reaction rule to be converted.
    :return: The resulting CGR.
    """
    return ~reaction_query_to_reaction(reaction_rule)
169
+
170
+
171
def hash_from_reaction_rule(reaction_rule: "ReactionContainer") -> int:
    """Generates an order-independent hash for the given reaction rule.

    Component hashes are sorted within each role (reactants, reagents,
    products), so the result does not depend on the order in which the
    components are stored.

    Fix: the return annotation was ``-> hash``, which annotates with the
    builtin *function* rather than a type; the correct annotation is ``int``.

    :param reaction_rule: The reaction rule to be hashed.
    :return: The resulting integer hash.
    """
    reactants_hash = tuple(sorted(hash(r) for r in reaction_rule.reactants))
    reagents_hash = tuple(sorted(hash(r) for r in reaction_rule.reagents))
    products_hash = tuple(sorted(hash(r) for r in reaction_rule.products))

    return hash((reactants_hash, reagents_hash, products_hash))
183
+
184
+
185
def reverse_reaction(
    reaction: ReactionContainer,
) -> ReactionContainer:
    """Swap the reactant and product sides of the given reaction.

    :param reaction: The reaction to be reversed.
    :return: A new reaction whose reactants are the original products and
        vice versa; reagents, metadata and name are carried over unchanged.
    """
    flipped = ReactionContainer(
        reaction.products,
        reaction.reactants,
        reaction.reagents,
        reaction.meta,
    )
    flipped.name = reaction.name

    return flipped
199
+
200
+
201
def cgrtools_to_chython_molecule(molecule: MoleculeContainer) -> MoleculeContainerChython:
    """Convert a CGRtools molecule into a chython ``MoleculeContainer``.

    Only element symbols, atom numbers and integer bond orders are copied;
    charges, radical flags and stereo information are NOT transferred.

    :param molecule: The CGRtools molecule to convert.
    :return: The equivalent chython molecule graph.
    """
    molecule_chython = MoleculeContainerChython()
    for n, atom in molecule.atoms():
        molecule_chython.add_atom(atom.atomic_symbol, n)

    for n, m, bond in molecule.bonds():
        molecule_chython.add_bond(n, m, int(bond))

    return molecule_chython
210
+
211
+
212
def chython_query_to_cgrtools(query) -> QueryContainer:
    """Convert a chython query graph into a CGRtools ``QueryContainer``.

    For atoms, only the element symbol, charge, neighbor count and
    hybridization are carried over (other query constraints such as ring
    sizes or hydrogen counts are NOT transferred); bonds keep their integer
    orders.

    :param query: A chython query object exposing ``atoms()`` and ``bonds()``.
    :return: The equivalent CGRtools query container.
    """
    cgrtools_query = QueryContainer()
    for n, atom in query.atoms():
        cgrtools_query.add_atom(
            atom=atom.atomic_symbol,
            charge=atom.charge,
            neighbors=atom.neighbors,
            hybridization=atom.hybridization,
            _map=n,
        )
    for n, m, bond in query.bonds():
        cgrtools_query.add_bond(n, m, int(bond))

    return cgrtools_query
synplan/interfaces/__init__.py ADDED
File without changes
synplan/interfaces/cli.py ADDED
@@ -0,0 +1,506 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Module containing commands line scripts for training and planning steps."""
2
+
3
+ import os
4
+ import warnings
5
+ from pathlib import Path
6
+
7
+ import click
8
+ import yaml
9
+
10
+ from synplan.chem.data.filtering import ReactionFilterConfig, filter_reactions_from_file
11
+ from synplan.chem.data.standardizing import (
12
+ ReactionStandardizationConfig,
13
+ standardize_reactions_from_file,
14
+ )
15
+ from synplan.chem.reaction_rules.extraction import extract_rules_from_reactions
16
+ from synplan.chem.reaction_routes.clustering import run_cluster_cli
17
+ from synplan.chem.utils import standardize_building_blocks
18
+ from synplan.mcts.search import run_search
19
+ from synplan.ml.training.supervised import create_policy_dataset, run_policy_training
20
+ from synplan.ml.training.reinforcement import run_updating
21
+ from synplan.utils.config import (
22
+ PolicyNetworkConfig,
23
+ RuleExtractionConfig,
24
+ TreeConfig,
25
+ TuningConfig,
26
+ ValueNetworkConfig,
27
+ )
28
+ from synplan.utils.loading import download_all_data
29
+ from synplan.utils.visualisation import (
30
+ routes_clustering_report,
31
+ routes_subclustering_report,
32
+ )
33
+
34
+ warnings.filterwarnings("ignore")
35
+
36
+
37
# Root command group: all subcommands below register themselves on it via
# the @synplan.command decorators.
@click.group(name="synplan")
def synplan():
    """SynPlanner command line interface."""
40
+
41
+
42
@synplan.command(name="download_all_data")
@click.option(
    "--save_to",
    "save_to",
    default=".",
    show_default=True,
    help="Path to the folder where downloaded data will be stored.",
)
def download_all_data_cli(save_to: str = ".") -> None:
    """Downloads all data for training, planning and benchmarking SynPlanner.

    :param save_to: Folder where the downloaded data will be stored
        (defaults to the current directory).
    """
    # Fix: without default="." Click passes None when --save_to is omitted,
    # making the Python-level default above unreachable.
    download_all_data(save_to=save_to)
51
+
52
+
53
# Thin CLI wrapper: reads building blocks from --input, writes the
# standardized set to --output. All chemistry lives in
# synplan.chem.utils.standardize_building_blocks.
@synplan.command(name="building_blocks_standardizing")
@click.option(
    "--input",
    "input_file",
    required=True,
    type=click.Path(exists=True),
    help="Path to the file with building blocks to be standardized.",
)
@click.option(
    "--output",
    "output_file",
    required=True,
    type=click.Path(),
    help="Path to the file where standardized building blocks will be stored.",
)
def building_blocks_standardizing_cli(input_file: str, output_file: str) -> None:
    """Standardizes building blocks."""
    standardize_building_blocks(input_file=input_file, output_file=output_file)
71
+
72
+
73
@synplan.command(name="reaction_standardizing")
@click.option(
    "--config",
    "config_path",
    required=True,
    type=click.Path(exists=True),
    help="Path to the configuration file for reactions standardizing.",
)
@click.option(
    "--input",
    "input_file",
    required=True,
    type=click.Path(exists=True),
    help="Path to the file with reactions to be standardized.",
)
@click.option(
    "--output",
    "output_file",
    # Robustness fix: previously this option had neither `required` nor a
    # default, so omitting it passed None downstream and failed late inside
    # standardize_reactions_from_file. Fail fast at argument parsing instead.
    required=True,
    type=click.Path(),
    help="Path to the file where standardized reactions will be stored.",
)
@click.option(
    "--num_cpus", default=4, type=int, help="The number of CPUs to use for processing."
)
def reaction_standardizing_cli(
    config_path: str, input_file: str, output_file: str, num_cpus: int
) -> None:
    """Standardizes reactions and remove duplicates."""
    # The YAML config selects which standardization steps to run.
    stand_config = ReactionStandardizationConfig.from_yaml(config_path)
    standardize_reactions_from_file(
        config=stand_config,
        input_reaction_data_path=input_file,
        standardized_reaction_data_path=output_file,
        num_cpus=num_cpus,
        batch_size=100,  # fixed batch size; matches the other data-prep commands
    )
109
+
110
+
111
@synplan.command(name="reaction_filtering")
@click.option(
    "--config",
    "config_path",
    required=True,
    type=click.Path(exists=True),
    help="Path to the configuration file for reactions filtering.",
)
@click.option(
    "--input",
    "input_file",
    required=True,
    type=click.Path(exists=True),
    help="Path to the file with reactions to be filtered.",
)
@click.option(
    "--output",
    "output_file",
    # NOTE(review): Path("./") is a directory, not a file path — confirm the
    # intended default output location for filtered reactions.
    default=Path("./"),
    type=click.Path(),
    help="Path to the file where successfully filtered reactions will be stored.",
)
@click.option(
    "--num_cpus", default=4, type=int, help="The number of CPUs to use for processing."
)
def reaction_filtering_cli(
    config_path: str, input_file: str, output_file: str, num_cpus: int
):
    """Filters erroneous reactions."""
    # Consistency fix: call from_yaml on the class, as every other command in
    # this module does, instead of building a throwaway instance first.
    reaction_check_config = ReactionFilterConfig.from_yaml(config_path)
    filter_reactions_from_file(
        config=reaction_check_config,
        input_reaction_data_path=input_file,
        filtered_reaction_data_path=output_file,
        num_cpus=num_cpus,
        batch_size=100,
    )
148
+
149
+
150
# Extracts retrosynthetic reaction rules (templates) from a curated reaction
# file; the YAML config controls the extraction parameters.
@synplan.command(name="rule_extracting")
@click.option(
    "--config",
    "config_path",
    required=True,
    type=click.Path(exists=True),
    help="Path to the configuration file for reaction rules extracting.",
)
@click.option(
    "--input",
    "input_file",
    required=True,
    type=click.Path(exists=True),
    help="Path to the file with reactions for reaction rules extraction.",
)
@click.option(
    "--output",
    "output_file",
    required=True,
    type=click.Path(),
    help="Path to the file where extracted reaction rules will be stored.",
)
@click.option(
    "--num_cpus", default=4, type=int, help="The number of CPUs to use for processing."
)
def rule_extracting_cli(
    config_path: str, input_file: str, output_file: str, num_cpus: int
):
    """Reaction rules extraction."""
    reaction_rule_config = RuleExtractionConfig.from_yaml(config_path)
    extract_rules_from_reactions(
        config=reaction_rule_config,
        reaction_data_path=input_file,
        reaction_rules_path=output_file,
        num_cpus=num_cpus,
        batch_size=100,  # same fixed batch size as the other data-prep commands
    )
187
+
188
+
189
# Trains the "ranking" policy network: builds a preprocessed dataset from
# reactions + extracted rules, then runs supervised training.
@synplan.command(name="ranking_policy_training")
@click.option(
    "--config",
    "config_path",
    required=True,
    type=click.Path(exists=True),
    help="Path to the configuration file for ranking policy training.",
)
@click.option(
    "--reaction_data",
    required=True,
    type=click.Path(exists=True),
    help="Path to the file with reactions for ranking policy training.",
)
@click.option(
    "--reaction_rules",
    required=True,
    type=click.Path(exists=True),
    help="Path to the file with extracted reaction rules.",
)
@click.option(
    "--results_dir",
    default=Path("."),
    type=click.Path(),
    help="Path to the directory where the trained policy network will be stored.",
)
@click.option(
    "--num_cpus",
    default=4,
    type=int,
    help="The number of CPUs to use for training set preparation.",
)
def ranking_policy_training_cli(
    config_path: str,
    reaction_data: str,
    reaction_rules: str,
    results_dir: str,
    num_cpus: int,
) -> None:
    """Ranking policy network training."""
    policy_config = PolicyNetworkConfig.from_yaml(config_path)
    # Force the ranking flavour regardless of what the YAML says.
    policy_config.policy_type = "ranking"
    # Intermediate preprocessed dataset, written next to the results.
    # NOTE(review): the filtering command writes "policy_dataset.ckpt" for the
    # analogous artifact — confirm which extension is intended.
    policy_dataset_file = os.path.join(results_dir, "policy_dataset.dt")

    datamodule = create_policy_dataset(
        reaction_rules_path=reaction_rules,
        molecules_or_reactions_path=reaction_data,
        output_path=policy_dataset_file,
        dataset_type="ranking",
        batch_size=policy_config.batch_size,
        num_cpus=num_cpus,
    )

    run_policy_training(datamodule, config=policy_config, results_path=results_dir)
243
+
244
+
245
# Trains the "filtering" policy network: builds a preprocessed dataset from
# molecules + extracted rules, then runs supervised training. Mirrors
# ranking_policy_training_cli except for the dataset type and input kind.
@synplan.command(name="filtering_policy_training")
@click.option(
    "--config",
    "config_path",
    required=True,
    type=click.Path(exists=True),
    help="Path to the configuration file for filtering policy training.",
)
@click.option(
    "--molecule_data",
    required=True,
    type=click.Path(exists=True),
    help="Path to the file with molecules for filtering policy training.",
)
@click.option(
    "--reaction_rules",
    required=True,
    type=click.Path(exists=True),
    help="Path to the file with extracted reaction rules.",
)
@click.option(
    "--results_dir",
    default=Path("."),
    type=click.Path(),
    help="Path to the directory where the trained policy network will be stored.",
)
@click.option(
    "--num_cpus",
    default=8,
    type=int,
    help="The number of CPUs to use for training set preparation.",
)
def filtering_policy_training_cli(
    config_path: str,
    molecule_data: str,
    reaction_rules: str,
    results_dir: str,
    num_cpus: int,
):
    """Filtering policy network training."""

    policy_config = PolicyNetworkConfig.from_yaml(config_path)
    # Force the filtering flavour regardless of what the YAML says.
    policy_config.policy_type = "filtering"
    # NOTE(review): ".ckpt" for a dataset file (the ranking command uses
    # ".dt") — confirm which extension is intended.
    policy_dataset_file = os.path.join(results_dir, "policy_dataset.ckpt")

    datamodule = create_policy_dataset(
        reaction_rules_path=reaction_rules,
        molecules_or_reactions_path=molecule_data,
        output_path=policy_dataset_file,
        dataset_type="filtering",
        batch_size=policy_config.batch_size,
        num_cpus=num_cpus,
    )

    run_policy_training(datamodule, config=policy_config, results_path=results_dir)
300
+
301
+
302
@synplan.command(name="value_network_tuning")
@click.option(
    "--config",
    "config_path",
    required=True,
    type=click.Path(exists=True),
    help="Path to the configuration file for value network training.",
)
@click.option(
    "--targets",
    required=True,
    type=click.Path(exists=True),
    help="Path to the file with target molecules for planning simulations.",
)
@click.option(
    "--reaction_rules",
    required=True,
    type=click.Path(exists=True),
    help="Path to the file with extracted reaction rules. Needed for planning simulations.",
)
@click.option(
    "--building_blocks",
    required=True,
    type=click.Path(exists=True),
    help="Path to the file with building blocks. Needed for planning simulations.",
)
@click.option(
    "--policy_network",
    required=True,
    type=click.Path(exists=True),
    help="Path to the file with trained policy network. Needed for planning simulations.",
)
@click.option(
    "--value_network",
    default=None,
    type=click.Path(exists=True),
    help="Path to the file with trained value network. Needed in case of additional value network fine-tuning",
)
@click.option(
    "--results_dir",
    default=".",
    type=click.Path(exists=False),
    help="Path to the directory where the trained value network will be stored.",
)
def value_network_tuning_cli(
    config_path: str,
    targets: str,
    reaction_rules: str,
    building_blocks: str,
    policy_network: str,
    value_network: str,
    results_dir: str,
):
    """Value network tuning."""

    # The YAML holds one section per component (expansion, value net, tree, tuning).
    with open(config_path, "r", encoding="utf-8") as file:
        config = yaml.safe_load(file)

    policy_config = PolicyNetworkConfig.from_dict(config["node_expansion"])
    policy_config.weights_path = policy_network

    value_config = ValueNetworkConfig.from_dict(config["value_network"])
    if value_network is None:
        # No checkpoint supplied: train from scratch into the results folder.
        value_config.weights_path = os.path.join(
            results_dir, "weights", "value_network.ckpt"
        )
    else:
        # Bug fix: honor the CLI-supplied checkpoint for fine-tuning.
        # Previously --value_network was parsed but never used, so the option
        # had no effect unless the YAML already carried the same path.
        value_config.weights_path = value_network

    tree_config = TreeConfig.from_dict(config["tree"])
    tuning_config = TuningConfig.from_dict(config["tuning"])

    run_updating(
        targets_path=targets,
        tree_config=tree_config,
        policy_config=policy_config,
        value_config=value_config,
        reinforce_config=tuning_config,
        reaction_rules_path=reaction_rules,
        building_blocks_path=building_blocks,
        results_root=results_dir,
    )
382
+
383
+
384
@synplan.command(name="planning")
@click.option(
    "--config",
    "config_path",
    required=True,
    type=click.Path(exists=True),
    help="Path to the configuration file for retrosynthetic planning.",
)
@click.option(
    "--targets",
    required=True,
    type=click.Path(exists=True),
    help="Path to the file with target molecules for retrosynthetic planning.",
)
@click.option(
    "--reaction_rules",
    required=True,
    type=click.Path(exists=True),
    help="Path to the file with extracted reaction rules.",
)
@click.option(
    "--building_blocks",
    required=True,
    type=click.Path(exists=True),
    help="Path to the file with building blocks.",
)
@click.option(
    "--policy_network",
    required=True,
    type=click.Path(exists=True),
    help="Path to the file with trained policy network.",
)
@click.option(
    "--value_network",
    default=None,
    type=click.Path(exists=True),
    help="Path to the file with trained value network.",
)
@click.option(
    "--results_dir",
    default=".",
    type=click.Path(exists=False),
    help="Path to the file where retrosynthetic planning results will be stored.",
)
def planning_cli(
    config_path: str,
    targets: str,
    reaction_rules: str,
    building_blocks: str,
    policy_network: str,
    value_network: str,
    results_dir: str,
):
    """Retrosynthetic planning."""

    with open(config_path, "r", encoding="utf-8") as file:
        config = yaml.safe_load(file)

    # Tree-search and node-evaluation settings are merged into one dict for run_search.
    search_config = {**config["tree"], **config["node_evaluation"]}
    # Idiom fix: inline the key instead of splatting a throwaway one-item dict
    # ({**d, **{"k": v}} -> {**d, "k": v}).
    policy_config = PolicyNetworkConfig.from_dict(
        {**config["node_expansion"], "weights_path": policy_network}
    )

    run_search(
        targets_path=targets,
        search_config=search_config,
        policy_config=policy_config,
        reaction_rules_path=reaction_rules,
        building_blocks_path=building_blocks,
        value_network_path=value_network,
        results_root=results_dir,
    )
456
+
457
+
458
@synplan.command(name="clustering")
@click.option(
    "--targets",
    required=True,
    type=click.Path(exists=True),
    help="Path to the file with target molecules for retrosynthetic planning.",
)
@click.option(
    "--routes_file",
    default=".",
    type=click.Path(exists=False),
    help="Path to the file where the planning results are stored.",
)
@click.option(
    "--cluster_results_dir",
    default=".",
    type=click.Path(exists=False),
    help="Path to the file where clustering results will be stored.",
)
@click.option(
    "--perform_subcluster",
    # Bug fix: this boolean switch was declared as click.Path, so any string
    # value (a path) toggled subclustering. Declare it as a proper flag.
    is_flag=True,
    default=False,
    help="Perform subclustering.",
)
@click.option(
    "--subcluster_results_dir",
    default=".",
    type=click.Path(exists=False),
    help="Path to the file where subclustering results will be stored.",
)
def cluster_route_from_file_cli(
    targets: str,
    routes_file: str,
    cluster_results_dir: str,
    perform_subcluster: bool,
    subcluster_results_dir: str,
):
    """Clustering the routes from planning"""
    # NOTE(review): `targets` is accepted but not forwarded to run_cluster_cli —
    # kept for CLI compatibility; confirm whether it should be used here.
    run_cluster_cli(
        routes_file=routes_file,
        cluster_results_dir=cluster_results_dir,
        perform_subcluster=perform_subcluster,
        subcluster_results_dir=subcluster_results_dir if perform_subcluster else None,
    )
503
+
504
+
505
# Entry point when the module is executed directly (python -m / script).
if __name__ == "__main__":
    synplan()
synplan/interfaces/gui.py ADDED
@@ -0,0 +1,1323 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import pickle
3
+ import re
4
+ import uuid
5
+ import io
6
+ import zipfile
7
+
8
+ import pandas as pd
9
+ import streamlit as st
10
+ from CGRtools.files import SMILESRead
11
+ from streamlit_ketcher import st_ketcher
12
+ from huggingface_hub import hf_hub_download
13
+ from huggingface_hub.utils import disable_progress_bars
14
+
15
+
16
+ from synplan.mcts.expansion import PolicyNetworkFunction
17
+ from synplan.mcts.search import extract_tree_stats
18
+ from synplan.mcts.tree import Tree
19
+ from synplan.chem.utils import mol_from_smiles
20
+ from synplan.chem.reaction_routes.route_cgr import *
21
+ from synplan.chem.reaction_routes.clustering import *
22
+
23
+ from synplan.utils.visualisation import (
24
+ routes_clustering_report,
25
+ routes_subclustering_report,
26
+ generate_results_html,
27
+ html_top_routes_cluster,
28
+ get_route_svg,
29
+ get_route_svg_from_json
30
+ )
31
+ from synplan.utils.config import TreeConfig, PolicyNetworkConfig
32
+ from synplan.utils.loading import load_reaction_rules, load_building_blocks
33
+
34
+
35
+ import psutil
36
+ import gc
37
+
38
+
39
# Hide huggingface_hub download progress bars; they clutter the Streamlit log.
disable_progress_bars("huggingface_hub")

# Lenient SMILES parser: ignore=True skips records CGRtools cannot parse.
smiles_parser = SMILESRead.create_parser(ignore=True)
# Molecule preloaded into the editor on first page load.
DEFAULT_MOL = "c1cc(ccc1Cl)C(CCO)NC(C2(CCN(CC2)c3c4cc[nH]c4ncn3)N)=O"
43
+
44
+
45
+ # --- Helper Functions ---
46
def download_button(
    object_to_download, download_filename, button_text, pickle_it=False
):
    """Generate a styled HTML anchor that downloads *object_to_download*.

    Params:
    ------
    object_to_download: The object to be downloaded. Strings and bytes are
        embedded as-is; pandas DataFrames are serialized to CSV first.
    download_filename (str): filename and extension of the file, e.g.
        mydata.csv, some_txt_output.txt.
    button_text (str): Text to display on the download button
        (e.g. 'click here to download file').
    pickle_it (bool): If True, pickle the object before embedding.

    Returns:
    -------
    (str): the anchor tag (with inline CSS) to download object_to_download,
        or None if pickling was requested and failed.

    Examples:
    --------
    download_button(your_df, 'YOUR_DF.csv', 'Click to download data!')
    download_button(your_str, 'YOUR_STRING.txt', 'Click to download text!')
    """
    if pickle_it:
        try:
            object_to_download = pickle.dumps(object_to_download)
        except pickle.PicklingError as e:
            st.write(e)
            return None

    else:
        if isinstance(object_to_download, bytes):
            pass

        elif isinstance(object_to_download, pd.DataFrame):
            object_to_download = object_to_download.to_csv(index=False).encode("utf-8")

    # Embed the payload as base64 so it fits inside a data: URI.
    try:
        b64 = base64.b64encode(object_to_download.encode()).decode()
    except AttributeError:
        # Already bytes (no .encode method).
        b64 = base64.b64encode(object_to_download).decode()

    # Unique element id so per-button CSS rules never collide.
    button_uuid = str(uuid.uuid4()).replace("-", "")
    # Bug fix: use a raw string for the regex — "\d" is an invalid escape in a
    # plain string literal (SyntaxWarning on modern CPython).
    button_id = re.sub(r"\d+", "", button_uuid)

    custom_css = f"""
        <style>
            #{button_id} {{
                background-color: rgb(255, 255, 255);
                color: rgb(38, 39, 48);
                text-decoration: none;
                border-radius: 4px;
                border-width: 1px;
                border-style: solid;
                border-color: rgb(230, 234, 241);
                border-image: initial;
            }}
            #{button_id}:hover {{
                border-color: rgb(246, 51, 102);
                color: rgb(246, 51, 102);
            }}
            #{button_id}:active {{
                box-shadow: none;
                background-color: rgb(246, 51, 102);
                color: white;
            }}
        </style> """

    dl_link = (
        custom_css
        + f'<a download="{download_filename}" id="{button_id}" href="data:file/txt;base64,{b64}">{button_text}</a><br></br>'
    )
    return dl_link
118
+
119
+
120
@st.cache_resource
def load_planning_resources_cached():  # Renamed to avoid conflict if main calls it directly
    """Fetch (and cache for the session) the three planning artifacts.

    Downloads the building-blocks file, the ranking policy weights and the
    reaction rules from the SynPlanner Hub repo into the working directory
    and returns their local paths in that order.
    """
    repo = "Laboratoire-De-Chemoinformatique/SynPlanner"
    # (filename, subfolder) for each artifact, in the order callers unpack.
    artifact_specs = (
        ("building_blocks_em_sa_ln.smi", "building_blocks"),
        ("ranking_policy_network.ckpt", "uspto/weights"),
        ("uspto_reaction_rules.pickle", "uspto"),
    )
    local_paths = tuple(
        hf_hub_download(
            repo_id=repo,
            filename=fname,
            subfolder=subfolder,
            local_dir=".",
        )
        for fname, subfolder in artifact_specs
    )
    return local_paths
141
+
142
+
143
+ # --- GUI Sections ---
144
+
145
+
146
def initialize_app():
    """1. Initialization: Setting up the main window, layout, and initial widgets."""
    st.set_page_config(page_title="SynPlanner GUI", page_icon="🧪", layout="wide")

    # Seed every session-state slot once per session; later reruns keep the
    # existing values untouched.
    session_defaults = {
        # Planning state
        "planning_done": False,
        "tree": None,
        "res": None,
        "target_smiles": "",  # may be overwritten by the ketcher editor
        # Clustering state
        "clustering_done": False,
        "clusters": None,
        "reactions_dict": None,
        "num_clusters_setting": 10,  # store the setting actually used
        "route_cgrs_dict": None,
        "sb_cgrs_dict": None,
        "route_json": None,
        # Subclustering state
        "subclustering_done": False,
        "subclusters": None,
        # Download state (less critical now with direct download links)
        "clusters_downloaded": False,
        # Ketcher editor persistence
        "ketcher": DEFAULT_MOL,
    }
    for state_key, default_value in session_defaults.items():
        if state_key not in st.session_state:
            st.session_state[state_key] = default_value

    intro_text = """
    This is a demo of the graphical user interface of
    [SynPlanner](https://github.com/Laboratoire-de-Chemoinformatique/SynPlanner/).
    SynPlanner is a comprehensive tool for reaction data curation, rule extraction, model training and retrosynthetic planning.

    More information on SynPlanner is available in the [official docs](https://synplanner.readthedocs.io/en/latest/index.html).
    """
    st.title("`SynPlanner GUI`")
    st.write(intro_text)
200
+
201
+
202
def setup_sidebar():
    """2. Sidebar: Handling the widgets and logic within the sidebar area."""
    # st.sidebar.image("img/logo.png") # Assuming img/logo.png is available
    # (title, markdown body) pairs rendered top to bottom.
    sidebar_sections = (
        ("Docs", "https://synplanner.readthedocs.io/en/latest/"),
        (
            "Tutorials",
            "https://github.com/Laboratoire-de-Chemoinformatique/SynPlanner/tree/main/tutorials",
        ),
        (
            "Paper",
            "https://chemrxiv.org/engage/chemrxiv/article-details/66add90bc9c6a5c07ae65796",
        ),
        (
            "Issues",
            "[Report a bug 🐞](https://github.com/Laboratoire-de-Chemoinformatique/SynPlanner/issues/new?assignees=&labels=bug&projects=&template=bug_report.md&title=%5BBUG%5D)",
        ),
    )
    for section_title, section_body in sidebar_sections:
        st.sidebar.title(section_title)
        st.sidebar.markdown(section_body)
222
+
223
+
224
def handle_molecule_input():
    """3. Molecule Input: Managing the input area for molecule data with two-way synchronization.

    Keeps the SMILES text box and the ketcher drawing widget in sync via the
    shared session key ``shared_smiles``. Returns the SMILES string that the
    planning step should use.
    """
    st.header("Molecule input")
    st.markdown(
        """
    You can provide a molecular structure by either providing:
    * SMILES string + Enter
    * Draw it + Apply
    """
    )

    # Single source of truth for the current structure, shared by both widgets.
    if "shared_smiles" not in st.session_state:
        st.session_state.shared_smiles = st.session_state.get("ketcher", DEFAULT_MOL)

    # Incremented to force a fresh ketcher widget key (remounts the editor)
    # whenever the text input changes the structure.
    if "ketcher_render_count" not in st.session_state:
        st.session_state.ketcher_render_count = 0

    def text_input_changed_callback():
        # Runs before the rerun triggered by the text input's on_change.
        new_text_value = (
            st.session_state.smiles_text_input_key_for_sync
        )  # Key of the text_input
        if new_text_value != st.session_state.shared_smiles:
            st.session_state.shared_smiles = new_text_value
            st.session_state.ketcher = new_text_value
            # Bump the counter so the ketcher widget is recreated with the new SMILES.
            st.session_state.ketcher_render_count += 1

    # SMILES Text Input
    st.text_input(
        "SMILES:",
        value=st.session_state.shared_smiles,
        key="smiles_text_input_key_for_sync",  # Unique key for this widget
        on_change=text_input_changed_callback,
        help="Enter SMILES string and press Enter. The drawing will update, and vice-versa.",
    )

    # The render count is baked into the key: changing it remounts the editor.
    ketcher_key = f"ketcher_widget_for_sync_{st.session_state.ketcher_render_count}"
    smile_code_output_from_ketcher = st_ketcher(
        st.session_state.shared_smiles, key=ketcher_key
    )

    # Editor -> text direction: if the user clicked Apply in ketcher, adopt the
    # drawn structure and rerun so the text input reflects it.
    if smile_code_output_from_ketcher != st.session_state.shared_smiles:
        st.session_state.shared_smiles = smile_code_output_from_ketcher
        st.session_state.ketcher = smile_code_output_from_ketcher
        st.rerun()

    current_smiles_for_planning = st.session_state.shared_smiles

    # Warn if the molecule changed after a successful planning run: the results
    # currently on screen belong to the previous structure.
    last_planned_smiles = st.session_state.get("target_smiles")
    if (
        last_planned_smiles
        and current_smiles_for_planning != last_planned_smiles
        and st.session_state.get("planning_done", False)
    ):
        st.warning(
            "Molecule structure has changed since the last successful planning run. "
            "Results shown below (if any) are for the previous molecule. "
            "Please re-run planning for the current structure."
        )

    # Ensure st.session_state.ketcher is consistent for other parts of the app
    if st.session_state.get("ketcher") != current_smiles_for_planning:
        st.session_state.ketcher = current_smiles_for_planning

    return current_smiles_for_planning
288
+
289
+
290
def setup_planning_options():
    """4. Planning: Encapsulating the logic related to the "planning" functionality.

    Renders the MCTS configuration widgets, and — on button press — loads the
    cached planning resources, builds the search tree, runs the iterations
    with a progress bar and stores the tree + stats in session state.
    """
    st.header("Launch calculation")
    st.markdown(
        """If you modified the structure, please ensure you clicked on `Apply` (bottom right of the molecular editor)."""
    )

    st.markdown(
        f"The molecule SMILES is actually: ``{st.session_state.get('ketcher', DEFAULT_MOL)}``"
    )

    st.subheader("Planning options")
    st.markdown(
        """
    The description of each option can be found in the
    [Retrosynthetic Planning Tutorial](https://synplanner.readthedocs.io/en/latest/tutorial_files/retrosynthetic_planning.html#Configuring-search-tree).
    """
    )

    # Two-column layout: search-strategy knobs on the left, size limits right.
    col_options_1, col_options_2 = st.columns(2, gap="medium")
    with col_options_1:
        search_strategy_input = st.selectbox(
            label="Search strategy",
            options=(
                "Expansion first",
                "Evaluation first",
            ),
            index=0,
            key="search_strategy_input",
        )
        ucb_type = st.selectbox(
            label="UCB type",
            options=("uct", "puct", "value"),
            index=0,
            key="ucb_type_input",
        )  # Fixed label
        c_ucb = st.number_input(
            "C coefficient of UCB",
            value=0.1,
            placeholder="Type a number...",
            key="c_ucb_input",
        )

    with col_options_2:
        max_iterations = st.slider(
            "Total number of MCTS iterations",
            min_value=50,
            max_value=1000,
            value=300,
            key="max_iterations_slider",
        )
        max_depth = st.slider(
            "Maximal number of reaction steps",
            min_value=3,
            max_value=9,
            value=6,
            key="max_depth_slider",
        )
        min_mol_size = st.slider(
            "Minimum size of a molecule to be precursor",
            min_value=0,
            max_value=7,
            value=0,
            key="min_mol_size_slider",
            help="Number of non-hydrogen atoms in molecule",
        )

    # Map the human-readable selectbox labels to TreeConfig identifiers.
    search_strategy_translator = {
        "Expansion first": "expansion_first",
        "Evaluation first": "evaluation_first",
    }
    search_strategy = search_strategy_translator[search_strategy_input]

    planning_params = {
        "search_strategy": search_strategy,
        "ucb_type": ucb_type,
        "c_ucb": c_ucb,
        "max_iterations": max_iterations,
        "max_depth": max_depth,
        "min_mol_size": min_mol_size,
    }

    if st.button("Start retrosynthetic planning", key="submit_planning_button"):
        # Reset downstream states if replanning
        st.session_state.planning_done = False
        st.session_state.clustering_done = False
        st.session_state.subclustering_done = False
        st.session_state.tree = None
        st.session_state.res = None
        st.session_state.clusters = None
        st.session_state.reactions_dict = None
        st.session_state.subclusters = None
        st.session_state.route_cgrs_dict = None
        st.session_state.sb_cgrs_dict = None
        st.session_state.route_json = None
        active_smile_code = st.session_state.get(
            "ketcher", DEFAULT_MOL
        )  # Get current SMILES
        st.session_state.target_smiles = (
            active_smile_code  # Store the SMILES used for this run
        )

        try:
            target_molecule = mol_from_smiles(active_smile_code, clean_stereo=True)
            if target_molecule is None:
                st.error(f"Could not parse the input SMILES: {active_smile_code}")
            else:
                # Paths come from the session-level cache (downloads on first call).
                (
                    building_blocks_path,
                    ranking_policy_weights_path,
                    reaction_rules_path,
                ) = load_planning_resources_cached()
                with st.spinner("Running retrosynthetic planning..."):
                    with st.status("Loading resources...", expanded=False) as status:
                        st.write("Loading building blocks...")
                        building_blocks = load_building_blocks(
                            building_blocks_path, standardize=False
                        )
                        st.write("Loading reaction rules...")
                        reaction_rules = load_reaction_rules(reaction_rules_path)
                        st.write("Loading policy network...")
                        policy_config = PolicyNetworkConfig(
                            weights_path=ranking_policy_weights_path
                        )
                        policy_function = PolicyNetworkFunction(
                            policy_config=policy_config
                        )
                        status.update(label="Resources loaded!", state="complete")

                    tree_config = TreeConfig(
                        search_strategy=planning_params["search_strategy"],
                        evaluation_type="rollout",  # This was hardcoded, keeping it.
                        max_iterations=planning_params["max_iterations"],
                        max_depth=planning_params["max_depth"],
                        min_mol_size=planning_params["min_mol_size"],
                        init_node_value=0.5,  # This was hardcoded
                        ucb_type=planning_params["ucb_type"],
                        c_ucb=planning_params["c_ucb"],
                        silent=True,  # This was hardcoded
                    )

                    tree = Tree(
                        target=target_molecule,
                        config=tree_config,
                        reaction_rules=reaction_rules,
                        building_blocks=building_blocks,
                        expansion_function=policy_function,
                        evaluation_function=None,  # This was hardcoded
                    )

                    # Iterating the Tree object runs one MCTS iteration per step.
                    mcts_progress_text = "Running MCTS iterations..."
                    mcts_bar = st.progress(0, text=mcts_progress_text)
                    for step, (solved, route_id) in enumerate(tree):
                        progress_value = min(
                            1.0, (step + 1) / planning_params["max_iterations"]
                        )
                        mcts_bar.progress(
                            progress_value,
                            text=f"{mcts_progress_text} ({step+1}/{planning_params['max_iterations']})",
                        )

                    res = extract_tree_stats(tree, target_molecule)

                    # Persist results for the display/clustering sections, then
                    # rerun so the results view picks them up.
                    st.session_state["tree"] = tree
                    st.session_state["res"] = res
                    st.session_state.planning_done = True
                    st.rerun()

        except Exception as e:
            # Broad catch is deliberate: this is the GUI's top-level boundary
            # for the whole planning pipeline.
            st.error(f"An error occurred during planning: {e}")
            st.session_state.planning_done = False
461
+
462
+
463
def display_planning_results():
    """5. Planning Results Display.

    Renders the outcome of the last MCTS planning run stored in
    ``st.session_state`` (``res`` statistics dict and ``tree`` search object):
    up to three example routes when solved, otherwise run statistics.
    """
    if not st.session_state.get("planning_done", False):
        return

    stats = st.session_state.res
    search_tree = st.session_state.tree

    # Both objects are required; bail out and reset the flag if either is gone.
    if stats is None or search_tree is None:
        st.error(
            "Planning results are missing from session state. Please re-run planning."
        )
        st.session_state.planning_done = False
        return

    st.header("Planning Results")

    if stats.get("solved", False):
        if hasattr(search_tree, "winning_nodes") and search_tree.winning_nodes:
            winning_nodes = sorted(set(search_tree.winning_nodes))
        else:
            winning_nodes = []
        st.subheader(f"Number of unique routes found: {len(winning_nodes)}")

        st.subheader("Examples of found retrosynthetic routes")
        if not winning_nodes:
            st.warning(
                "Planning solved, but no winning nodes found in the tree object."
            )
            return

        shown_count = 0
        seen_route_ids = set()
        for route_id in winning_nodes:
            # Show at most three distinct routes.
            if shown_count >= 3:
                break
            if route_id in seen_route_ids:
                continue
            try:
                seen_route_ids.add(route_id)
                step_count = len(search_tree.synthesis_route(route_id))
                score = round(search_tree.route_score(route_id), 3)
                route_svg = get_route_svg(search_tree, route_id)
                # svg = get_route_svg_from_json(st.session_state.route_json, route_id)
                if route_svg:
                    st.image(
                        route_svg,
                        caption=f"Route {route_id}; {step_count} steps; Route score: {score}",
                    )
                    shown_count += 1
                else:
                    st.warning(f"Could not generate SVG for route {route_id}.")
            except Exception as e:
                st.error(f"Error displaying route {route_id}: {e}")
        return

    # Unsolved run: explain the failure and show what statistics we have.
    st.warning(
        "No reaction path found for the target molecule with the current settings."
    )
    st.write(
        "Consider adjusting planning options (e.g., increase iterations, adjust depth, check molecule validity)."
    )
    stat_col, _ = st.columns(2)
    with stat_col:
        st.subheader("Run Statistics (No Solution)")
        try:
            if "target_smiles" not in stats and "target_smiles" in st.session_state:
                stats["target_smiles"] = st.session_state.target_smiles
            cols_to_show = [
                key
                for key in ("target_smiles", "num_nodes", "num_iter", "search_time")
                if key in stats
            ]
            if cols_to_show:
                st.dataframe(pd.DataFrame(stats, index=[0])[cols_to_show])
            else:
                st.write("No statistics to display for the unsuccessful run.")
        except Exception as e:
            st.error(f"Error displaying statistics: {e}")
            st.write(stats)
551
+
552
+
553
def download_planning_results():
    """6. Planning Results Download.

    Offers download links for the HTML results report and a CSV of run
    statistics once planning has completed with a solution.

    BUGFIX: the raw target SMILES was embedded directly in the file names.
    SMILES strings may contain characters that are illegal or dangerous in
    file names (``/`` and ``\\`` in stereo bonds, ``#`` in triple bonds, ...),
    which broke the generated downloads; sanitize the SMILES first.
    """
    if (
        st.session_state.get("planning_done", False)
        and st.session_state.res
        and st.session_state.res.get("solved", False)
    ):
        import re  # local import: the file-level import block is outside this section

        res = st.session_state.res
        tree = st.session_state.tree

        # Replace every character outside [A-Za-z0-9_.-] so the SMILES can be
        # safely embedded in a file name.
        safe_smiles = re.sub(r"[^\w.-]", "_", st.session_state.target_smiles)

        try:
            html_body = generate_results_html(tree, html_path=None, extended=True)
            dl_html = download_button(
                html_body,
                f"results_synplanner_{safe_smiles}.html",
                "Download results (HTML)",
            )
            if dl_html:
                st.markdown(dl_html, unsafe_allow_html=True)

            try:
                res_df = pd.DataFrame(res, index=[0])
                dl_csv = download_button(
                    res_df,
                    f"stats_synplanner_{safe_smiles}.csv",
                    "Download statistics (CSV)",
                )
                if dl_csv:
                    st.markdown(dl_csv, unsafe_allow_html=True)
            except Exception as e:
                st.error(f"Could not prepare statistics CSV for download: {e}")

        except Exception as e:
            st.error(f"Error generating download links for planning results: {e}")
595
+
596
+
597
def setup_clustering():
    """7. Clustering: logic for the "Run Clustering" button.

    Once planning has solved the target, composes RouteCGRs and SB-CGRs from
    the search tree, clusters the found routes, stores all intermediates in
    ``st.session_state`` and reruns the app so the results sections render.

    BUGFIX: ``st.rerun()`` works by raising a Streamlit control-flow
    exception (``RerunException``, a subclass of ``Exception``). The original
    code called it inside the ``try`` block, so the broad ``except Exception``
    swallowed the rerun, displayed a spurious "An error occurred during
    clustering" message and reset ``clustering_done``. The rerun is now
    triggered only after leaving the ``try``/``except``.
    """
    if (
        st.session_state.get("planning_done", False)
        and st.session_state.res
        and st.session_state.res.get("solved", False)
    ):
        st.divider()
        st.header("Clustering the retrosynthetic routes")

        if st.button("Run Clustering", key="submit_clustering_button"):
            # Drop any stale clustering / subclustering state before recomputing.
            st.session_state.clustering_done = False
            st.session_state.subclustering_done = False
            st.session_state.clusters = None
            st.session_state.reactions_dict = None
            st.session_state.subclusters = None
            st.session_state.route_cgrs_dict = None
            st.session_state.sb_cgrs_dict = None
            st.session_state.route_json = None

            rerun_needed = False
            with st.spinner("Performing clustering..."):
                try:
                    current_tree = st.session_state.tree
                    if not current_tree:
                        st.error("Tree object not found. Please re-run planning.")
                        return

                    st.write("Calculating RoutesCGRs...")
                    route_cgrs_dict = compose_all_route_cgrs(current_tree)
                    st.write("Processing SB-CGRs...")
                    sb_cgrs_dict = compose_all_sb_cgrs(route_cgrs_dict)

                    results = cluster_routes(sb_cgrs_dict, use_strat=False)
                    # Sort clusters by numeric key for a stable display order.
                    results = dict(sorted(results.items(), key=lambda x: float(x[0])))

                    st.session_state.clusters = results
                    st.session_state.route_cgrs_dict = route_cgrs_dict
                    st.session_state.sb_cgrs_dict = sb_cgrs_dict
                    st.write("Extracting reactions...")
                    st.session_state.reactions_dict = extract_reactions(current_tree)
                    st.session_state.route_json = make_json(
                        st.session_state.reactions_dict
                    )

                    if (
                        st.session_state.clusters is not None
                        and st.session_state.reactions_dict is not None
                    ):  # Check for None explicitly
                        st.session_state.clustering_done = True
                        st.success(
                            f"Clustering complete. Found {len(st.session_state.clusters)} clusters."
                        )
                    else:
                        st.error("Clustering failed or returned empty results.")
                        st.session_state.clustering_done = False

                    del results  # route_cgrs_dict, sb_cgrs_dict are stored
                    gc.collect()
                    rerun_needed = True
                except Exception as e:
                    st.error(f"An error occurred during clustering: {e}")
                    st.session_state.clustering_done = False

            # Rerun outside the try so the control-flow exception propagates.
            if rerun_needed:
                st.rerun()
660
+
661
+
662
def _render_cluster_entry(tree, cluster_num, group_data):
    """Render one cluster entry: SB-CGR thumbnail plus its best route drawing.

    :param tree: The MCTS search tree (provides route steps and scores).
    :param cluster_num: The cluster's key (used only in captions/messages).
    :param group_data: Cluster dict with ``route_ids``, ``group_size`` and
        optionally ``sb_cgr``.
    """
    if (
        not group_data
        or "route_ids" not in group_data
        or not group_data["route_ids"]
    ):
        st.warning(f"Cluster {cluster_num} has no data or route_ids.")
        return
    st.markdown(
        f"**Cluster {cluster_num}** (Size: {group_data.get('group_size', 'N/A')})"
    )
    # The first route id is treated as the representative (best) route.
    route_id = group_data["route_ids"][0]
    try:
        num_steps = len(tree.synthesis_route(route_id))
        route_score = round(tree.route_score(route_id), 3)
        # svg = get_route_svg(tree, route_id)
        svg = get_route_svg_from_json(st.session_state.route_json, route_id)
        sb_cgr = group_data.get("sb_cgr")  # Safely get sb_cgr
        sb_cgr_svg = None
        if sb_cgr:
            sb_cgr.clean2d()
            sb_cgr_svg = cgr_display(sb_cgr)

        if svg and sb_cgr_svg:
            col1, col2 = st.columns([0.2, 0.8])
            with col1:
                st.image(sb_cgr_svg, caption="SB-CGR")
            with col2:
                st.image(
                    svg,
                    caption=f"Route {route_id}; {num_steps} steps; Route score: {route_score}",
                )
        elif svg:  # Only route SVG available
            st.image(
                svg,
                caption=f"Route {route_id}; {num_steps} steps; Route score: {route_score}",
            )
            st.warning(f"SB-CGR could not be displayed for cluster {cluster_num}.")
        else:
            st.warning(
                f"Could not generate SVG for route {route_id} or its SB-CGR."
            )
    except Exception as e:
        st.error(
            f"Error displaying route {route_id} for cluster {cluster_num}: {e}"
        )


def display_clustering_results():
    """8. Clustering Results Display.

    Shows the best route of each cluster: the first ten inline and the rest
    inside an expander. The per-cluster rendering (previously duplicated
    verbatim for both groups) is factored into ``_render_cluster_entry``.
    """
    if not st.session_state.get("clustering_done", False):
        return

    clusters = st.session_state.clusters
    tree = st.session_state.tree
    MAX_DISPLAY_CLUSTERS_DATA = 10  # clusters shown before the expander

    if clusters is None or tree is None:
        st.error(
            "Clustering results (clusters or tree) are missing. Please re-run clustering."
        )
        st.session_state.clustering_done = False
        return

    st.subheader(f"Best routes from {len(clusters)} Found Clusters")
    clusters_items = list(clusters.items())
    first_items = clusters_items[:MAX_DISPLAY_CLUSTERS_DATA]
    remaining_items = clusters_items[MAX_DISPLAY_CLUSTERS_DATA:]

    for cluster_num, group_data in first_items:
        _render_cluster_entry(tree, cluster_num, group_data)

    if remaining_items:
        with st.expander(f"... and {len(remaining_items)} more clusters"):
            for cluster_num, group_data in remaining_items:
                _render_cluster_entry(tree, cluster_num, group_data)
785
+
786
+
787
def _cluster_report_html(tree, clusters, cluster_idx, sb_cgrs):
    """Build the standalone HTML report for one cluster (by its dict key)."""
    return routes_clustering_report(tree, clusters, str(cluster_idx), sb_cgrs, aam=False)


def download_clustering_results():
    """10. Clustering Results Download.

    Offers per-cluster HTML reports (first ten inline, the rest in an
    expander) plus a ZIP containing all of them. Report generation
    (previously triplicated) is factored into ``_cluster_report_html``.
    """
    if not st.session_state.get("clustering_done", False):
        return

    tree_for_html = st.session_state.get("tree")
    clusters_for_html = st.session_state.get("clusters")
    # The report builder consumes the SB-CGR dictionary, not reactions_dict.
    sb_cgrs_for_html = st.session_state.get("sb_cgrs_dict")

    if not tree_for_html:
        st.warning("MCTS Tree data not found. Cannot generate cluster reports.")
        return
    if not clusters_for_html:
        st.warning("Cluster data not found. Cannot generate cluster reports.")
        return

    st.subheader("Cluster Reports")
    st.write("Generate downloadable HTML reports for each cluster:")

    MAX_DOWNLOAD_LINKS_DISPLAYED = 10
    clusters_items = list(clusters_for_html.items())

    # Inline download buttons for the first few clusters.
    for cluster_idx, _ in clusters_items[:MAX_DOWNLOAD_LINKS_DISPLAYED]:
        try:
            html_content = _cluster_report_html(
                tree_for_html, clusters_for_html, cluster_idx, sb_cgrs_for_html
            )
            st.download_button(
                label=f"Download report for cluster {cluster_idx}",
                data=html_content,
                file_name=f"cluster_{cluster_idx}_{st.session_state.target_smiles}.html",
                mime="text/html",
                key=f"download_cluster_{cluster_idx}",
            )
        except Exception as e:
            st.error(f"Error generating report for cluster {cluster_idx}: {e}")

    # Remaining clusters go behind an expander to keep the page compact.
    remaining_items = clusters_items[MAX_DOWNLOAD_LINKS_DISPLAYED:]
    if remaining_items:
        with st.expander(f"Show remaining {len(remaining_items)} cluster reports"):
            for group_index, _ in remaining_items:
                try:
                    html_content = _cluster_report_html(
                        tree_for_html, clusters_for_html, group_index, sb_cgrs_for_html
                    )
                    st.download_button(
                        label=f"Download report for cluster {group_index}",
                        data=html_content,
                        file_name=f"cluster_{group_index}_{st.session_state.target_smiles}.html",
                        mime="text/html",
                        key=f"download_cluster_expanded_{group_index}",
                    )
                except Exception as e:
                    st.error(
                        f"Error generating report for cluster {group_index} (expanded): {e}"
                    )

    # Bundle every report into a single in-memory ZIP.
    try:
        buffer = io.BytesIO()
        with zipfile.ZipFile(buffer, mode="w", compression=zipfile.ZIP_DEFLATED) as zf:
            for idx, _ in clusters_items:
                html_content_zip = _cluster_report_html(
                    tree_for_html, clusters_for_html, idx, sb_cgrs_for_html
                )
                zf.writestr(
                    f"cluster_{idx}_{st.session_state.target_smiles}.html",
                    html_content_zip,
                )
        buffer.seek(0)

        st.download_button(
            label="📦 Download all cluster reports as ZIP",
            data=buffer,
            file_name=f"all_cluster_reports_{st.session_state.target_smiles}.zip",
            mime="application/zip",
            key="download_all_clusters_zip",
        )
    except Exception as e:
        st.error(f"Error generating ZIP file for cluster reports: {e}")
889
+
890
+
891
def setup_subclustering():
    """11. Subclustering: logic for the "Run Subclustering Analysis" button.

    Requires a completed clustering run; subclusters every cluster using the
    stored SB-CGR and RouteCGR dictionaries, then reruns the app.

    BUGFIX: ``st.rerun()`` raises a Streamlit control-flow exception that is
    a subclass of ``Exception``; calling it inside the ``try`` block let the
    broad ``except Exception`` swallow the rerun, show a spurious error and
    reset ``subclustering_done``. The rerun now happens after the ``try``.
    """
    if not st.session_state.get("clustering_done", False):
        return  # Subclustering depends on clustering being done

    st.divider()
    st.header("Sub-Clustering within a selected Cluster")

    if st.button("Run Subclustering Analysis", key="submit_subclustering_button"):
        st.session_state.subclustering_done = False
        st.session_state.subclusters = None
        rerun_needed = False
        with st.spinner("Performing subclustering analysis..."):
            try:
                clusters_for_sub = st.session_state.get("clusters")
                sb_cgrs_dict_for_sub = st.session_state.get("sb_cgrs_dict")
                route_cgrs_dict_for_sub = st.session_state.get("route_cgrs_dict")

                if (
                    clusters_for_sub
                    and sb_cgrs_dict_for_sub
                    and route_cgrs_dict_for_sub
                ):  # Ensure all are present
                    all_subgroups = subcluster_all_clusters(
                        clusters_for_sub,
                        sb_cgrs_dict_for_sub,
                        route_cgrs_dict_for_sub,
                    )
                    st.session_state.subclusters = all_subgroups
                    st.session_state.subclustering_done = True
                    st.success("Subclustering analysis complete.")
                    gc.collect()
                    rerun_needed = True
                else:
                    # Report exactly which prerequisite data is missing.
                    missing = []
                    if not clusters_for_sub:
                        missing.append("clusters")
                    if not sb_cgrs_dict_for_sub:
                        missing.append("SB-CGRs dictionary")
                    if not route_cgrs_dict_for_sub:
                        missing.append("RouteCGRs dictionary")
                    st.error(
                        f"Cannot run subclustering. Missing data: {', '.join(missing)}. Please ensure clustering ran successfully."
                    )
                    st.session_state.subclustering_done = False

            except Exception as e:
                st.error(f"An error occurred during subclustering: {e}")
                st.session_state.subclustering_done = False

        # Rerun outside the try so the control-flow exception propagates.
        if rerun_needed:
            st.rerun()
939
+
940
+
941
def display_subclustering_results():
    """12. Subclustering Results Display.

    Two-column layout: the left column holds cluster/subcluster selectboxes
    (whose values are persisted under the keys ``subcluster_num_select_key``
    and ``subcluster_index_select_key`` — ``download_subclustering_results``
    reads them later); the right column shows the selected subcluster's
    pseudo reaction and its routes.
    """
    if st.session_state.get("subclustering_done", False):
        sub = st.session_state.get("subclusters")
        tree = st.session_state.get("tree")
        # clusters_for_sub_display = st.session_state.get('clusters') # Not directly used in display logic from original code snippet

        if not sub or not tree:
            st.error(
                "Subclustering results (subclusters or tree) are missing. Please re-run subclustering."
            )
            st.session_state.subclustering_done = False
            return

        sub_input_col, sub_display_col = st.columns([0.25, 0.75])

        with sub_input_col:
            st.subheader("Select Cluster and Subcluster")
            available_cluster_nums = list(sub.keys())
            if not available_cluster_nums:
                st.warning("No clusters available in subclustering results.")
                return  # Exit if no clusters to select

            user_input_cluster_num_display = st.selectbox(
                "Select Cluster #:",
                options=sorted(available_cluster_nums),
                key="subcluster_num_select_key",
            )

            # Default subcluster index used until the second selectbox renders.
            selected_subcluster_idx = 0

            if user_input_cluster_num_display in sub:
                sub_step_cluster = sub[user_input_cluster_num_display]
                allowed_subclusters_indices = sorted(list(sub_step_cluster.keys()))

                if not allowed_subclusters_indices:
                    st.warning(
                        f"No reaction steps (subclusters) found for Cluster {user_input_cluster_num_display}."
                    )
                else:
                    selected_subcluster_idx = st.selectbox(
                        "Select Subcluster Index:",
                        options=allowed_subclusters_indices,
                        key="subcluster_index_select_key",
                    )
                    if selected_subcluster_idx in sub[user_input_cluster_num_display]:
                        current_subcluster_data = sub[user_input_cluster_num_display][
                            selected_subcluster_idx
                        ]
                        # Show the parent cluster's SB-CGR as a visual anchor.
                        if "sb_cgr" in current_subcluster_data:
                            cluster_sb_cgr_display = current_subcluster_data["sb_cgr"]
                            cluster_sb_cgr_display.clean2d()
                            st.image(
                                cluster_sb_cgr_display.depict(),
                                caption=f"SB-CGR of parent Cluster {user_input_cluster_num_display}",
                            )
                        else:
                            st.warning("SB-CGR for this subcluster not found.")
            else:
                st.warning(
                    f"Selected cluster {user_input_cluster_num_display} not found in subclustering results."
                )
                return

        with sub_display_col:
            st.subheader("Subcluster Details")
            if (
                user_input_cluster_num_display in sub
                and selected_subcluster_idx in sub[user_input_cluster_num_display]
            ):

                subcluster_content = sub[user_input_cluster_num_display][
                    selected_subcluster_idx
                ]

                # subcluster_to_display = post_process_subgroup(subcluster_content) #Under development
                subcluster_to_display = subcluster_content
                if (
                    not subcluster_to_display
                    or "routes_data" not in subcluster_to_display
                    or not subcluster_to_display["routes_data"]
                ):
                    st.info("No routes or data found for this subcluster selection.")
                else:
                    # Show at most this many routes inline; the rest go
                    # into an expander below.
                    MAX_ROUTES_PER_SUBCLUSTER = 5
                    all_route_ids_in_subcluster = list(
                        subcluster_to_display["routes_data"].keys()
                    )
                    routes_to_display_direct = all_route_ids_in_subcluster[
                        :MAX_ROUTES_PER_SUBCLUSTER
                    ]
                    remaining_routes_sub = all_route_ids_in_subcluster[
                        MAX_ROUTES_PER_SUBCLUSTER:
                    ]

                    st.markdown(
                        f"--- \n**Subcluster {user_input_cluster_num_display}.{selected_subcluster_idx}** (Size: {len(all_route_ids_in_subcluster)})"
                    )

                    # Depict the generalized ("Markush-like") pseudo reaction
                    # summarizing this subcluster, if one was computed.
                    if "synthon_reaction" in subcluster_to_display:
                        synthon_reaction = subcluster_to_display["synthon_reaction"]
                        try:
                            synthon_reaction.clean2d()
                            st.image(
                                depict_custom_reaction(synthon_reaction),
                                caption=f"Markush-like pseudo reaction of subcluster",
                            )  # Assuming depict_custom_reaction
                        except Exception as e_depict:
                            st.warning(f"Could not depict synthon reaction: {e_depict}")
                    else:
                        st.info("No synthon reaction data for this subcluster.")
                    # Scrollable area so long route lists don't dominate the page.
                    with st.container(height=500):
                        for route_id in routes_to_display_direct:
                            try:
                                route_score_sub = round(tree.route_score(route_id), 3)
                                # svg_sub = get_route_svg(tree, route_id)
                                svg_sub = get_route_svg_from_json(st.session_state.route_json, route_id)
                                if svg_sub:
                                    st.image(
                                        svg_sub,
                                        caption=f"Route {route_id}; Score: {route_score_sub}",
                                    )
                                else:
                                    st.warning(
                                        f"Could not generate SVG for route {route_id}."
                                    )
                            except Exception as e:
                                st.error(
                                    f"Error displaying route {route_id} in subcluster: {e}"
                                )

                    if remaining_routes_sub:
                        with st.expander(
                            f"... and {len(remaining_routes_sub)} more routes in this subcluster"
                        ):
                            for route_id in remaining_routes_sub:
                                try:
                                    route_score_sub = round(
                                        tree.route_score(route_id), 3
                                    )
                                    # svg_sub = get_route_svg(tree, route_id)
                                    svg_sub = get_route_svg_from_json(st.session_state.route_json, route_id)
                                    if svg_sub:
                                        st.image(
                                            svg_sub,
                                            caption=f"Route {route_id}; Score: {route_score_sub}",
                                        )
                                    else:
                                        st.warning(
                                            f"Could not generate SVG for route {route_id}."
                                        )
                                except Exception as e:
                                    st.error(
                                        f"Error displaying route {route_id} in subcluster (expanded): {e}"
                                    )
            else:
                st.info("Select a valid cluster and subcluster index to see details.")
1098
+
1099
+
1100
def download_subclustering_results():
    """13. Subclustering Results Download.

    Renders a download button with the HTML report for the cluster/subcluster
    currently chosen in ``display_subclustering_results``; the selection is
    read back from the selectbox widget keys in ``st.session_state``.
    """
    if (
        st.session_state.get("subclustering_done", False)
        and "subcluster_num_select_key" in st.session_state
        and "subcluster_index_select_key" in st.session_state
    ):

        sub = st.session_state.get("subclusters")
        tree = st.session_state.get("tree")
        sb_cgrs_for_report = st.session_state.get(
            "sb_cgrs_dict"
        )  # Used by routes_subclustering_report

        # Widget keys hold the user's current cluster/subcluster selection.
        user_input_cluster_num_display = st.session_state.subcluster_num_select_key
        selected_subcluster_idx = st.session_state.subcluster_index_select_key

        if not tree or not sub or not sb_cgrs_for_report:
            st.warning(
                "Missing data for subclustering report generation (tree, subclusters, or SB-CGRs)."
            )
            return

        if (
            user_input_cluster_num_display in sub
            and selected_subcluster_idx in sub[user_input_cluster_num_display]
        ):

            subcluster_data_for_report = sub[user_input_cluster_num_display][
                selected_subcluster_idx
            ]
            # Apply the same post-processing as in display
            processed_subcluster_data = post_process_subgroup(
                subcluster_data_for_report
            )
            # Group routes by identical leaving-group assignments so the
            # report can present them together (empty dict when absent).
            if "routes_data" in subcluster_data_for_report and isinstance(
                subcluster_data_for_report["routes_data"], dict
            ):
                processed_subcluster_data["group_lgs"] = group_by_identical_values(
                    subcluster_data_for_report["routes_data"]
                )
            else:
                processed_subcluster_data["group_lgs"] = {}

            try:
                subcluster_html_content = routes_subclustering_report(
                    tree,
                    processed_subcluster_data,  # Pass the specific post-processed subcluster data
                    user_input_cluster_num_display,
                    selected_subcluster_idx,
                    sb_cgrs_for_report,  # Pass the whole sb_cgrs dict
                    if_lg_group=True,  # This parameter was in the original call
                )
                st.download_button(
                    label=f"Download report for subcluster {user_input_cluster_num_display}.{selected_subcluster_idx}",
                    data=subcluster_html_content,
                    file_name=f"subcluster_{user_input_cluster_num_display}.{selected_subcluster_idx}_{st.session_state.target_smiles}.html",
                    mime="text/html",
                    key=f"download_subcluster_{user_input_cluster_num_display}_{selected_subcluster_idx}",
                )
            except Exception as e:
                st.error(
                    f"Error generating download report for subcluster {user_input_cluster_num_display}.{selected_subcluster_idx}: {e}"
                )
        # else:
        # This case is handled by the display logic mostly, download button just won't appear or will be for previous valid selection.
1166
+
1167
+
1168
def implement_restart():
    """14. Restart: reset the application state.

    Clears every result, flag and widget key from ``st.session_state``,
    restores the default molecule in the sketcher, and reruns the app.
    """
    st.divider()
    st.header("Restart Application State")
    if not st.button("Clear All Results & Restart", key="restart_button"):
        return

    # Session-state entries (results, flags, and widget keys) to wipe out.
    stale_keys = (
        "planning_done",
        "tree",
        "res",
        "target_smiles",
        "clustering_done",
        "clusters",
        "reactions_dict",
        "num_clusters_setting",
        "route_cgrs_dict",
        "sb_cgrs_dict",
        "route_json",
        "subclustering_done",
        "subclusters",  # "sub" was renamed
        "clusters_downloaded",
        # Widget keys — clearing them resets the corresponding inputs.
        "ketcher_widget",
        "smiles_text_input_key",
        "subcluster_num_select_key",
        "subcluster_index_select_key",
    )
    for stale_key in stale_keys:
        st.session_state.pop(stale_key, None)

    # Restore the sketcher to the default molecule and drop the stale target.
    st.session_state.ketcher = DEFAULT_MOL
    st.session_state.target_smiles = ""

    st.rerun()
1206
+
1207
+
1208
+ # --- Main Application Flow ---
1209
def main():
    """Main application flow.

    Renders the page top-to-bottom in dependency order: app setup and sidebar,
    molecule input, planning (button + results + downloads), clustering
    (button + results + downloads), subclustering, and finally the restart
    control. Each stage only renders once the flags set by the previous stage
    (``planning_done``, ``clustering_done``, ``subclustering_done``) are true.
    """
    initialize_app()
    setup_sidebar()
    current_smile_code = handle_molecule_input()
    # Update session_state.ketcher if current_smile_code has changed from ketcher output
    if st.session_state.get("ketcher") != current_smile_code:
        st.session_state.ketcher = current_smile_code
        # No rerun here, let the flow continue. handle_molecule_input already warns.

    setup_planning_options()  # This function now also handles the button press and logic for planning

    # Display planning results and download options together
    if st.session_state.get("planning_done", False):
        display_planning_results()  # Displays stats and routes
        if st.session_state.res and st.session_state.res.get("solved", False):
            stat_col, download_col = st.columns(
                2, gap="medium"
            )  # Placeholder for download column
            with stat_col:
                st.subheader("Statistics")
                try:
                    res = st.session_state.res
                    # Backfill the target SMILES so it appears in the table.
                    if (
                        "target_smiles" not in res
                        and "target_smiles" in st.session_state
                    ):
                        res["target_smiles"] = st.session_state.target_smiles
                    cols_to_show = [
                        col
                        for col in [
                            "target_smiles",
                            "num_routes",
                            "num_nodes",
                            "num_iter",
                            "search_time",
                        ]
                        if col in res
                    ]
                    if cols_to_show:  # Ensure there are columns to show
                        df = pd.DataFrame(res, index=[0])[cols_to_show]
                        st.dataframe(df)
                    else:
                        st.write("No statistics to display from planning results.")
                except Exception as e:
                    st.error(f"Error displaying statistics: {e}")
                    st.write(res)  # Show raw dict if DataFrame fails
            with download_col:
                st.subheader("Planning Downloads")  # Adding a subheader for clarity
                download_planning_results()

    # Clustering section (setup button, display, download)
    if (
        st.session_state.get("planning_done", False)
        and st.session_state.res
        and st.session_state.res.get("solved", False)
    ):
        setup_clustering()  # Contains the "Run Clustering" button and logic
        if st.session_state.get("clustering_done", False):
            display_clustering_results()  # Displays cluster routes and stats
            cluster_stat_col, cluster_download_col = st.columns(2, gap="medium")

            with cluster_stat_col:
                clusters = st.session_state.clusters
                cluster_sizes = [
                    cluster.get("group_size", 0)
                    for cluster in clusters.values()
                    if cluster
                ]  # Safe get
                st.subheader("Cluster Statistics")
                if cluster_sizes:
                    cluster_df = pd.DataFrame(
                        {
                            "Cluster": [
                                k for k, v in clusters.items() if v
                            ],  # Filter out empty clusters
                            "Number of Routes": [
                                v["group_size"] for v in clusters.values() if v
                            ],
                        }
                    )
                    if not cluster_df.empty:
                        cluster_df.index += 1
                        st.dataframe(cluster_df)
                        # One-click export of the representative route per cluster.
                        best_route_html = html_top_routes_cluster(
                            clusters,
                            st.session_state.tree,
                            st.session_state.target_smiles,
                        )
                        st.download_button(
                            label=f"Download best route from each cluster",
                            data=best_route_html,
                            file_name=f"cluster_best_{st.session_state.target_smiles}.html",
                            mime="text/html",
                            key=f"download_cluster_best",
                        )
                    else:
                        st.write("No valid cluster data to display statistics for.")
                    # download_top_routes_cluster()
                else:
                    st.write("No cluster data to display statistics for.")
            with cluster_download_col:
                download_clustering_results()

    # Subclustering section (setup button, display, download)
    if st.session_state.get("clustering_done", False):  # Depends on clustering
        setup_subclustering()  # Contains "Run Subclustering" button
        if st.session_state.get("subclustering_done", False):
            display_subclustering_results()  # Displays subcluster details and routes
            download_subclustering_results()  # This needs to be called after selections are made in display.

    implement_restart()
1320
+
1321
+
1322
# Script entry point: launch the Streamlit GUI when executed directly.
if __name__ == "__main__":
    main()
synplan/mcts/__init__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
# Subpackage init: re-exports the tree-search public API (Tree, Node).
from CGRtools.containers import MoleculeContainer
from .node import *
from .tree import *


# Globally disable atom-to-atom mapping numbers in molecule depictions
# rendered by CGRtools (affects all downstream HTML/GUI output).
MoleculeContainer.depict_settings(aam=False)

__all__ = ["Tree", "Node"]
synplan/mcts/evaluation.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Module containing a class that represents a value function for prediction of
2
+ synthesisablity of new nodes in the tree search."""
3
+
4
+ from typing import List
5
+
6
+ import torch
7
+
8
+ from synplan.chem.precursor import Precursor, compose_precursors
9
+ from synplan.ml.networks.value import ValueNetwork
10
+ from synplan.ml.training import mol_to_pyg
11
+
12
+
13
class ValueNetworkFunction:
    """Value function implemented as a value neural network for node evaluation
    (synthesisability prediction) in tree search."""

    def __init__(self, weights_path: str) -> None:
        """The value function predicts the probability to synthesize the target molecule
        with available building blocks starting from a given precursor.

        :param weights_path: The value network weights file path.
        """
        # Load on CPU and switch to eval mode (disables dropout/batchnorm updates).
        value_net = ValueNetwork.load_from_checkpoint(
            weights_path, map_location=torch.device("cpu")
        )
        self.value_network = value_net.eval()

    def predict_value(self, precursors: List[Precursor]) -> float:
        """Predicts a value based on the given precursors from the node. For prediction,
        precursors must be composed into a single molecule (product).

        :param precursors: The list of precursors.
        :return: The predicted float value ("synthesisability") of the node, or the
            sentinel -1e6 when the composed molecule cannot be converted to a graph
            (effectively pruning the node).
        """
        molecule = compose_precursors(precursors=precursors, exclude_small=True)
        pyg_graph = mol_to_pyg(molecule)
        if pyg_graph:
            with torch.no_grad():  # inference only, no gradient bookkeeping
                value_pred = self.value_network.forward(pyg_graph)[0].item()
        else:
            # Graph conversion failed: strongly negative value prunes this node.
            value_pred = -1e6

        return value_pred
synplan/mcts/expansion.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Module containing a class that represents a policy function for node expansion in the
2
+ tree search."""
3
+
4
+ from typing import Iterator, List, Tuple, Union
5
+
6
+ import torch
7
+ import torch_geometric
8
+ from CGRtools.reactor.reactor import Reactor
9
+
10
+ from synplan.chem.precursor import Precursor
11
+ from synplan.ml.networks.policy import PolicyNetwork
12
+ from synplan.ml.training import mol_to_pyg
13
+ from synplan.utils.config import PolicyNetworkConfig
14
+
15
+
16
class PolicyNetworkFunction:
    """Policy function implemented as a policy neural network for node expansion in tree
    search."""

    def __init__(
        self, policy_config: PolicyNetworkConfig, compile: bool = False
    ) -> None:
        """Initializes the expansion function (ranking or filter policy network).

        :param policy_config: An expansion policy configuration.
        :param compile: If True, compiles the network with torch_geometric to speed
            up inference. NOTE: the parameter name shadows the ``compile`` builtin;
            kept as-is for backward compatibility with existing callers.
        """
        self.config = policy_config

        # Load on CPU with inference-friendly settings (batch of 1, no dropout).
        policy_net = PolicyNetwork.load_from_checkpoint(
            self.config.weights_path,
            map_location=torch.device("cpu"),
            batch_size=1,
            dropout=0,
        )

        policy_net = policy_net.eval()
        if compile:
            self.policy_net = torch_geometric.compile(policy_net, dynamic=True)
        else:
            self.policy_net = policy_net

    def predict_reaction_rules(
        self, precursor: Precursor, reaction_rules: List[Reactor]
    ) -> Iterator[Union[Iterator, Iterator[Tuple[float, Reactor, int]]]]:
        """The policy function predicts the list of reaction rules for a given precursor.

        :param precursor: The current precursor for which the reaction rules are predicted.
        :param reaction_rules: The list of reaction rules from which applicable reaction
            rules are predicted and selected.
        :raises ValueError: If the network's output dimensionality does not match the
            number of reaction rules, or the loaded network has an unknown policy type.
        :return: Yielding the predicted probability for the reaction rule, reaction rule
            and reaction rule id.
        """
        # Sanity check: the last layer must have exactly one output per rule.
        out_dim = list(self.policy_net.modules())[-1].out_features
        if out_dim != len(reaction_rules):
            raise ValueError(
                f"The policy network output dimensionality is {out_dim}, but the number of reaction rules is {len(reaction_rules)}. "
                "Probably you use a different version of the policy network. Be sure to retrain the policy network "
                "with the current set of reaction rules"
            )

        pyg_graph = mol_to_pyg(precursor.molecule, canonicalize=False)
        if pyg_graph:
            with torch.no_grad():
                if self.policy_net.policy_type == "filtering":
                    probs, priority = self.policy_net.forward(pyg_graph)
                elif self.policy_net.policy_type == "ranking":
                    probs = self.policy_net.forward(pyg_graph)
                else:
                    # Previously an unknown type crashed later with an unbound
                    # local; fail early with a clear message instead.
                    raise ValueError(
                        f"Unknown policy type: {self.policy_net.policy_type}"
                    )
            del pyg_graph
        else:
            # Graph conversion failed: generator yields nothing.
            return []

        probs = probs[0].double()
        if self.policy_net.policy_type == "filtering":
            # Blend rule probabilities with the predicted priority scores.
            priority = priority[0].double()
            priority_coef = self.config.priority_rules_fraction
            probs = (1 - priority_coef) * probs + priority_coef * priority

        # Keep only the top-k rules by (blended) score.
        sorted_probs, sorted_rules = torch.sort(probs, descending=True)
        sorted_probs, sorted_rules = (
            sorted_probs[: self.config.top_rules],
            sorted_rules[: self.config.top_rules],
        )

        if self.policy_net.policy_type == "filtering":
            # Renormalize the truncated scores into a probability distribution.
            sorted_probs = torch.softmax(sorted_probs, -1)

        sorted_probs, sorted_rules = sorted_probs.tolist(), sorted_rules.tolist()

        for prob, rule_id in zip(sorted_probs, sorted_rules):
            # search may fail if rule_prob_threshold is too low (recommended value is 0.0)
            if prob > self.config.rule_prob_threshold:
                yield prob, reaction_rules[rule_id], rule_id
synplan/mcts/node.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Module containing a class Node in the tree search."""
2
+
3
+
4
class Node:
    """Node class represents a node in the tree search."""

    def __init__(
        self, precursors_to_expand: tuple = None, new_precursors: tuple = None
    ) -> None:
        """The function initializes the new Node object.

        :param precursors_to_expand: The tuple of precursors to be expanded. The first
            precursor in the tuple is the current precursor which will be expanded (for
            which new precursors will be generated by applying the predicted reaction
            rules). When the first precursor has been successfully expanded, the second
            precursor becomes the current precursor to be expanded. ``None`` is treated
            as an empty tuple.
        :param new_precursors: The tuple of new precursors generated by applying the
            reaction rule. ``None`` is treated as an empty tuple.
        """
        # Normalize None to an empty tuple: the original crashed on len(None).
        self.precursors_to_expand = (
            precursors_to_expand if precursors_to_expand is not None else tuple()
        )
        self.new_precursors = (
            new_precursors if new_precursors is not None else tuple()
        )

        if len(self.precursors_to_expand) == 0:
            self.curr_precursor = tuple()
            # Always define next_precursor so that attribute access never raises
            # AttributeError on solved (empty) nodes.
            self.next_precursor = tuple()
        else:
            self.curr_precursor = self.precursors_to_expand[0]
            self.next_precursor = self.precursors_to_expand[1:]

    def __len__(self) -> int:
        """Returns the number of precursors in the node left to expand."""
        return len(self.precursors_to_expand)

    def __repr__(self) -> str:
        """Returns a human-readable representation of the new precursors and of the
        precursors left to expand."""
        return (
            f"New precursors: {self.new_precursors}\n"
            f"Precursors to expand: {self.precursors_to_expand}\n"
        )

    def is_solved(self) -> bool:
        """If True, it is a terminal node.

        There are no precursors for expansion.
        """
        return len(self.precursors_to_expand) == 0
synplan/mcts/search.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Module containing functions for running tree search for the set of target
2
+ molecules."""
3
+
4
+ import csv
5
+ import json
6
+ import logging
7
+ import os.path
8
+ from pathlib import Path
9
+ from typing import Union
10
+
11
+ from CGRtools.containers import MoleculeContainer
12
+ from tqdm import tqdm
13
+
14
+ from synplan.chem.reaction_routes.route_cgr import extract_reactions
15
+ from synplan.chem.reaction_routes.io import write_routes_csv, write_routes_json
16
+ from synplan.chem.utils import mol_from_smiles
17
+ from synplan.mcts.evaluation import ValueNetworkFunction
18
+ from synplan.mcts.expansion import PolicyNetworkFunction
19
+ from synplan.mcts.tree import Tree, TreeConfig
20
+ from synplan.utils.config import PolicyNetworkConfig
21
+ from synplan.utils.loading import load_building_blocks, load_reaction_rules
22
+ from synplan.utils.visualisation import extract_routes, generate_results_html
23
+
24
+
25
def extract_tree_stats(
    tree: Tree, target: Union[str, MoleculeContainer], init_smiles: str = None
):
    """Collects summary statistics of a built search tree.

    :param tree: The built search tree.
    :param target: The target molecule associated with the tree.
    :param init_smiles: initial SMILES of the molecule, optional.
    :return: A dictionary with the calculated statistics.
    """
    # Serialize the tree topology and per-node metadata in Newick form.
    newick_tree, newick_meta = tree.newickify(visits_threshold=0)
    meta_records = (
        f"{node_id},{meta[0]},{meta[1]},{meta[2]}"
        for node_id, meta in newick_meta.items()
    )
    newick_meta_line = ";".join(meta_records)

    smiles = str(target) if init_smiles is None else init_smiles
    route_count = len(tree.winning_nodes)

    return {
        "target_smiles": smiles,
        "num_routes": route_count,
        "num_nodes": len(tree),
        "num_iter": tree.curr_iteration,
        "tree_depth": max(tree.nodes_depth.values()),
        "search_time": round(tree.curr_time, 1),
        "newick_tree": newick_tree,
        "newick_meta": newick_meta_line,
        "solved": route_count > 0,
    }
52
+
53
+
54
def run_search(
    targets_path: str,
    search_config: dict,
    policy_config: PolicyNetworkConfig,
    reaction_rules_path: str,
    building_blocks_path: str,
    value_network_path: str = None,
    results_root: str = "search_results",
) -> None:
    """Performs a tree search on a set of target molecules using specified configuration
    and reaction rules, logging the results and statistics.

    :param targets_path: The path to the file containing the target molecules (in SDF or
        SMILES format).
    :param search_config: The config object containing the configuration for the tree
        search.
    :param policy_config: The config object containing the configuration for the policy.
    :param reaction_rules_path: The path to the file containing reaction rules.
    :param building_blocks_path: The path to the file containing building blocks.
    :param value_network_path: The path to the file containing value weights (optional).
    :param results_root: The name of the folder where the results of the tree search
        will be saved.
    :return: None.
    """

    # results folder
    results_root = Path(results_root)
    if not results_root.exists():
        results_root.mkdir()

    # output files
    stats_file = results_root.joinpath("tree_search_stats.csv")
    routes_file = results_root.joinpath("extracted_routes.json")
    routes_folder = results_root.joinpath("extracted_routes_html")
    routes_folder.mkdir(exist_ok=True)

    # stats header; extract_tree_stats never fills "error", so DictWriter
    # writes its restval (empty string) in that column
    stats_header = [
        "target_smiles",
        "num_routes",
        "num_nodes",
        "num_iter",
        "tree_depth",
        "search_time",
        "newick_tree",
        "newick_meta",
        "solved",
        "error",
    ]

    # config; the value network is only used for the "gcn" evaluation type
    policy_function = PolicyNetworkFunction(policy_config=policy_config)
    if search_config["evaluation_type"] == "gcn" and value_network_path:
        value_function = ValueNetworkFunction(weights_path=value_network_path)
    else:
        value_function = None

    reaction_rules = load_reaction_rules(reaction_rules_path)
    building_blocks = load_building_blocks(building_blocks_path, standardize=True)

    # run search
    n_solved = 0
    extracted_routes = []

    tree_config = TreeConfig.from_dict(search_config)
    tree_config.silent = True  # suppress per-tree progress bars in batch mode
    with (
        open(targets_path, "r", encoding="utf-8") as targets,
        open(stats_file, "w", encoding="utf-8", newline="\n") as csvfile,
    ):

        statswriter = csv.DictWriter(csvfile, delimiter=",", fieldnames=stats_header)
        statswriter.writeheader()

        for ti, target_smi in tqdm(
            enumerate(targets),
            leave=True,
            desc="Number of target molecules processed: ",
            bar_format="{desc}{n} [{elapsed}]",
        ):
            target_smi = target_smi.strip()
            target_mol = mol_from_smiles(target_smi)
            try:
                # run search
                tree = Tree(
                    target=target_mol,
                    config=tree_config,
                    reaction_rules=reaction_rules,
                    building_blocks=building_blocks,
                    expansion_function=policy_function,
                    evaluation_function=value_function,
                )

                # exhausting the iterator performs the full MCTS search
                _ = list(tree)

            except Exception as e:
                # record a stub route so output indices stay aligned with targets
                extracted_routes.append(
                    [
                        {
                            "type": "mol",
                            "smiles": target_smi,
                            "in_stock": False,
                            "children": [],
                        }
                    ]
                )
                logging.warning(
                    f"Retrosynthetic_planning {target_smi} failed with the following error: {e}"
                )

                continue

            # is solved
            # NOTE(review): targets that finish without error but are unsolved
            # append nothing to extracted_routes, so JSON entries may not stay
            # aligned with target indices — verify intended behaviour.
            n_solved += bool(tree.winning_nodes)
            if bool(tree.winning_nodes):

                # extract routes
                extracted_routes.append(extract_routes(tree))

                # save routes
                generate_results_html(
                    tree,
                    os.path.join(routes_folder, f"retroroutes_target_{ti}.html"),
                    extended=True,
                )

            # save stats
            statswriter.writerow(extract_tree_stats(tree, target_smi))
            csvfile.flush()

            # save json routes (rewritten every iteration for incremental persistence)
            with open(routes_file, "w", encoding="utf-8") as f:
                json.dump(extracted_routes, f)

            # Save mapped reactions (CSV)
            routes_dict = extract_reactions(tree)
            write_routes_csv(
                routes_dict, os.path.join(routes_folder, f"mapped_routes_{ti}.csv")
            )

            # save mapped reactions (JSON)
            write_routes_json(
                routes_dict, os.path.join(routes_folder, f"mapped_routes_{ti}.json")
            )

    print(f"Number of solved target molecules: {n_solved}")
synplan/mcts/tree.py ADDED
@@ -0,0 +1,635 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Module containing a class Tree that used for tree search of retrosynthetic routes."""
2
+
3
+ import logging
4
+ import warnings
5
+ from collections import defaultdict, deque
6
+ from math import sqrt
7
+ from random import choice, uniform
8
+ from time import time
9
+ from typing import Dict, List, Set, Tuple
10
+
11
+ from CGRtools.reactor import Reactor
12
+ from CGRtools.containers import MoleculeContainer
13
+ from tqdm.auto import tqdm
14
+
15
+ from synplan.chem.precursor import Precursor
16
+ from synplan.chem.reaction import Reaction, apply_reaction_rule
17
+ from synplan.mcts.evaluation import ValueNetworkFunction
18
+ from synplan.mcts.expansion import PolicyNetworkFunction
19
+ from synplan.mcts.node import Node
20
+ from synplan.utils.config import TreeConfig
21
+
22
+
23
+ class Tree:
24
+ """Tree class with attributes and methods for Monte-Carlo tree search."""
25
+
26
    def __init__(
        self,
        target: MoleculeContainer,
        config: TreeConfig,
        reaction_rules: List[Reactor],
        building_blocks: Set[str],
        expansion_function: PolicyNetworkFunction,
        evaluation_function: ValueNetworkFunction = None,
    ):
        """Initializes a tree object with optional parameters for tree search for target
        molecule.

        :param target: A target molecule for retrosynthetic routes search.
        :param config: A tree configuration.
        :param reaction_rules: A loaded reaction rules.
        :param building_blocks: A loaded building blocks.
        :param expansion_function: A loaded policy function.
        :param evaluation_function: A loaded value function. If None, the rollout is
            used as a default for node evaluation.
        """

        # config parameters
        self.config = config

        assert isinstance(
            target, MoleculeContainer
        ), "Target should be given as MoleculeContainer"
        # NOTE(review): this rejects targets with 3 or fewer atoms although the
        # message says "less than 3" — confirm the intended threshold.
        assert len(target) > 3, "Target molecule has less than 3 atoms"

        target_molecule = Precursor(target)
        target_molecule.prev_precursors.append(Precursor(target))
        target_node = Node(
            precursors_to_expand=(target_molecule,), new_precursors=(target_molecule,)
        )

        # tree structure init (node id 1 is the root; parent id 0 means "no parent")
        self.nodes: Dict[int, Node] = {1: target_node}
        self.parents: Dict[int, int] = {1: 0}
        self.children: Dict[int, Set[int]] = {1: set()}
        self.winning_nodes: List[int] = []  # ids of nodes completing a route
        self.visited_nodes: Set[int] = set()
        self.expanded_nodes: Set[int] = set()
        self.nodes_visit: Dict[int, int] = {1: 0}
        self.nodes_depth: Dict[int, int] = {1: 0}
        self.nodes_prob: Dict[int, float] = {1: 0.0}
        self.nodes_rules: Dict[int, float] = {}
        self.nodes_init_value: Dict[int, float] = {1: 0.0}
        self.nodes_total_value: Dict[int, float] = {1: 0.0}

        # tree building limits
        self.curr_iteration: int = 0
        self.curr_tree_size: int = 2
        self.start_time: float = 0
        self.curr_time: float = 0

        # building blocks and reaction reaction_rules
        self.reaction_rules = reaction_rules
        self.building_blocks = building_blocks

        # policy and value functions; the combination of evaluation_type and
        # evaluation_function is validated for consistency
        self.policy_network = expansion_function
        if self.config.evaluation_type == "gcn":
            if evaluation_function is None:
                raise ValueError(
                    "Value function not specified while evaluation type is 'gcn'"
                )
        if (
            evaluation_function is not None
            and self.config.evaluation_type == "rollout"
        ):
            raise ValueError(
                "Value function is not None while evaluation type is 'rollout'. What should be evaluation type ?"
            )
        self.value_network = evaluation_function

        # utils
        self._tqdm = True  # needed to disable tqdm with multiprocessing module

        # NOTE(review): this mutates the caller-provided building_blocks set.
        target_smiles = str(self.nodes[1].curr_precursor.molecule)
        if target_smiles in self.building_blocks:
            self.building_blocks.remove(target_smiles)
            print(
                "Target was found in building blocks and removed from building blocks."
            )
110
+
111
+ def __len__(self) -> int:
112
+ """Returns the current size (the number of nodes) in the tree."""
113
+
114
+ return self.curr_tree_size - 1
115
+
116
    def __iter__(self) -> "Tree":
        """The function is defining an iterator for a Tree object.

        Also needed for the bar progress display.
        """

        # Reset the wall clock used by __next__ for the time-limit check.
        self.start_time = time()
        if self._tqdm:
            # NOTE(review): _tqdm starts as the boolean True and is replaced here
            # by an actual tqdm progress-bar instance on first iteration.
            self._tqdm = tqdm(
                total=self.config.max_iterations, disable=self.config.silent
            )
        return self
128
+
129
+ def __repr__(self) -> str:
130
+ """Returns a string representation of the tree (target SMILES, tree size, and
131
+ the number of found routes)."""
132
+ return self.report()
133
+
134
    def __next__(self) -> [bool, List[int]]:
        """The __next__ method is used to do one iteration of the tree building.

        :return: Returns True if the route was found and the node id of the last node in
            the route. Otherwise, returns False and the id of the last visited node.
        """

        # termination criteria
        if self.curr_iteration >= self.config.max_iterations:
            raise StopIteration("Iterations limit exceeded.")
        if self.curr_tree_size >= self.config.max_tree_size:
            raise StopIteration("Max tree size exceeded or all possible routes found.")
        if self.curr_time >= self.config.max_time:
            raise StopIteration("Time limit exceeded.")

        # start new iteration
        self.curr_iteration += 1
        self.curr_time = time() - self.start_time

        if self._tqdm:
            self._tqdm.update()

        curr_depth, node_id = 0, 1  # start from the root node_id

        # selection/expansion loop: descend until an unvisited node is reached
        explore_route = True
        while explore_route:
            self.visited_nodes.add(node_id)

            if self.nodes_visit[node_id]:  # already visited
                if not self.children[node_id]:  # dead node
                    self._update_visits(node_id)
                    explore_route = False
                else:
                    node_id = self._select_node(node_id)  # select the child node
                    curr_depth += 1
            else:
                if self.nodes[node_id].is_solved():  # found route
                    self._update_visits(
                        node_id
                    )  # this prevents expanding of bb node_id
                    self.winning_nodes.append(node_id)
                    return True, [node_id]

                if (
                    curr_depth < self.config.max_depth
                ):  # expand node if depth limit is not reached
                    self._expand_node(node_id)
                    if not self.children[node_id]:  # node was not expanded
                        value_to_backprop = -1.0
                    else:
                        self.expanded_nodes.add(node_id)

                        if self.config.search_strategy == "evaluation_first":
                            # recalculate node value based on children synthesisability and backpropagation
                            child_values = [
                                self.nodes_init_value[child_id]
                                for child_id in self.children[node_id]
                            ]

                            if self.config.evaluation_agg == "max":
                                value_to_backprop = max(child_values)

                            elif self.config.evaluation_agg == "average":
                                value_to_backprop = sum(child_values) / len(
                                    self.children[node_id]
                                )

                        elif self.config.search_strategy == "expansion_first":
                            value_to_backprop = self._get_node_value(node_id)

                    # backpropagation
                    self._backpropagate(node_id, value_to_backprop)
                    self._update_visits(node_id)
                    explore_route = False

                    if self.children[node_id]:
                        # found after expansion: collect all children that
                        # already complete a route
                        found_after_expansion = set()
                        for child_id in iter(self.children[node_id]):
                            if self.nodes[child_id].is_solved():
                                found_after_expansion.add(child_id)
                                self.winning_nodes.append(child_id)

                        if found_after_expansion:
                            return True, list(found_after_expansion)

                else:
                    # depth limit reached: backpropagate the node's current value
                    self._backpropagate(node_id, self.nodes_total_value[node_id])
                    self._update_visits(node_id)
                    explore_route = False

        return False, [node_id]
225
+
226
+ def _ucb(self, node_id: int) -> float:
227
+ """Calculates the Upper Confidence Bound (UCB) statistics for a given node.
228
+
229
+ :param node_id: The id of the node.
230
+ :return: The calculated UCB.
231
+ """
232
+
233
+ prob = self.nodes_prob[node_id] # predicted by policy network score
234
+ visit = self.nodes_visit[node_id]
235
+
236
+ if self.config.ucb_type == "puct":
237
+ u = (
238
+ self.config.c_ucb * prob * sqrt(self.nodes_visit[self.parents[node_id]])
239
+ ) / (visit + 1)
240
+ ucb_value = self.nodes_total_value[node_id] + u
241
+
242
+ if self.config.ucb_type == "uct":
243
+ u = (
244
+ self.config.c_ucb
245
+ * sqrt(self.nodes_visit[self.parents[node_id]])
246
+ / (visit + 1)
247
+ )
248
+ ucb_value = self.nodes_total_value[node_id] + u
249
+
250
+ if self.config.ucb_type == "value":
251
+ ucb_value = self.nodes_init_value[node_id] / (visit + 1)
252
+
253
+ return ucb_value
254
+
255
+ def _select_node(self, node_id: int) -> int:
256
+ """Selects a node based on its UCB value and returns the id of the node with the
257
+ highest UCB.
258
+
259
+ :param node_id: The id of the node.
260
+ :return: The id of the node with the highest UCB.
261
+ """
262
+
263
+ if self.config.epsilon > 0:
264
+ n = uniform(0, 1)
265
+ if n < self.config.epsilon:
266
+ return choice(list(self.children[node_id]))
267
+
268
+ best_score, best_children = None, []
269
+ for child_id in self.children[node_id]:
270
+ score = self._ucb(child_id)
271
+ if best_score is None or score > best_score:
272
+ best_score, best_children = score, [child_id]
273
+ elif score == best_score:
274
+ best_children.append(child_id)
275
+
276
+ # is needed for tree search reproducibility, when all child nodes has the same score
277
+ return best_children[0]
278
+
279
    def _expand_node(self, node_id: int) -> None:
        """Expands the node by generating new precursor with policy (expansion) function.

        :param node_id: The id the node to be expanded.
        :return: None.
        """
        curr_node = self.nodes[node_id]
        prev_precursor = curr_node.curr_precursor.prev_precursors

        tmp_precursor = set()  # products already seen during this expansion
        expanded = False
        for prob, rule, rule_id in self.policy_network.predict_reaction_rules(
            curr_node.curr_precursor, self.reaction_rules
        ):
            for products in apply_reaction_rule(
                curr_node.curr_precursor.molecule, rule
            ):
                # check repeated products
                if not products or not set(products) - tmp_precursor:
                    continue
                tmp_precursor.update(products)

                # tag each product with the rule that generated it
                for molecule in products:
                    molecule.meta["reactor_id"] = rule_id

                new_precursor = tuple(Precursor(mol) for mol in products)
                # scale the policy probability by the number of non-trivial products
                scaled_prob = prob * len(
                    list(filter(lambda x: len(x) > self.config.min_mol_size, products))
                )

                # skip transformations that regenerate an ancestor precursor (loops)
                if set(prev_precursor).isdisjoint(new_precursor):
                    # only precursors that are not yet building blocks remain to expand
                    precursors_to_expand = (
                        *curr_node.next_precursor,
                        *(
                            x
                            for x in new_precursor
                            if not x.is_building_block(
                                self.building_blocks, self.config.min_mol_size
                            )
                        ),
                    )

                    child_node = Node(
                        precursors_to_expand=precursors_to_expand,
                        new_precursors=new_precursor,
                    )

                    # NOTE(review): the loop variable shadows the tuple of the same
                    # name; after this loop ``new_precursor`` refers to the last
                    # element. Harmless here, but fragile.
                    for new_precursor in new_precursor:
                        new_precursor.prev_precursors = [new_precursor, *prev_precursor]

                    self._add_node(node_id, child_node, scaled_prob, rule_id)

                    expanded = True
        if not expanded and node_id == 1:
            # the root itself could not be expanded: no search is possible
            raise StopIteration("\nThe target molecule was not expanded.")
334
+
335
    def _add_node(
        self,
        node_id: int,
        new_node: Node,
        policy_prob: float = None,
        rule_id: int = None,
    ) -> None:
        """Adds a new node to the tree with probability of reaction rules predicted by
        policy function and applied to the parent node of the new node.

        :param node_id: The id of the parent node.
        :param new_node: The new node to be added.
        :param policy_prob: The probability of reaction rules predicted by policy
            function for the parent node.
        :param rule_id: The id of the reaction rule that produced the new node.
        :return: None.
        """

        new_node_id = self.curr_tree_size

        # register the node and wire it into the parent/child maps
        self.nodes[new_node_id] = new_node
        self.parents[new_node_id] = node_id
        self.children[node_id].add(new_node_id)
        self.children[new_node_id] = set()
        self.nodes_visit[new_node_id] = 0
        self.nodes_prob[new_node_id] = policy_prob
        self.nodes_rules[new_node_id] = rule_id
        self.nodes_depth[new_node_id] = self.nodes_depth[node_id] + 1
        self.curr_tree_size += 1

        # NOTE(review): if search_strategy is neither "evaluation_first" nor
        # "expansion_first", node_value is unbound and a NameError is raised.
        if self.config.search_strategy == "evaluation_first":
            node_value = self._get_node_value(new_node_id)
        elif self.config.search_strategy == "expansion_first":
            node_value = self.config.init_node_value

        self.nodes_init_value[new_node_id] = node_value
        self.nodes_total_value[new_node_id] = node_value
371
+
372
+ def _get_node_value(self, node_id: int) -> float:
373
+ """Calculates the value for the given node (for example with rollout or value
374
+ network).
375
+
376
+ :param node_id: The id of the node to be evaluated.
377
+ :return: The estimated value of the node.
378
+ """
379
+
380
+ node = self.nodes[node_id]
381
+
382
+ if self.config.evaluation_type == "random":
383
+ node_value = uniform(0, 1)
384
+
385
+ elif self.config.evaluation_type == "rollout":
386
+ node_value = min(
387
+ (
388
+ self._rollout_node(
389
+ precursor, current_depth=self.nodes_depth[node_id]
390
+ )
391
+ for precursor in node.precursors_to_expand
392
+ ),
393
+ default=1.0,
394
+ )
395
+
396
+ elif self.config.evaluation_type == "gcn":
397
+ node_value = self.value_network.predict_value(node.new_precursors)
398
+
399
+ return node_value
400
+
401
+ def _update_visits(self, node_id: int) -> None:
402
+ """Updates the number of visits from the current node to the root node.
403
+
404
+ :param node_id: The id of the current node.
405
+ :return: None.
406
+ """
407
+
408
+ while node_id:
409
+ self.nodes_visit[node_id] += 1
410
+ node_id = self.parents[node_id]
411
+
412
+ def _backpropagate(self, node_id: int, value: float) -> None:
413
+ """Backpropagates the value through the tree from the current.
414
+
415
+ :param node_id: The id of the node from which to backpropagate the value.
416
+ :param value: The value to backpropagate.
417
+ :return: None.
418
+ """
419
+ while node_id:
420
+ if self.config.backprop_type == "muzero":
421
+ self.nodes_total_value[node_id] = (
422
+ self.nodes_total_value[node_id] * self.nodes_visit[node_id] + value
423
+ ) / (self.nodes_visit[node_id] + 1)
424
+ elif self.config.backprop_type == "cumulative":
425
+ self.nodes_total_value[node_id] += value
426
+ node_id = self.parents[node_id]
427
+
428
    def _rollout_node(self, precursor: Precursor, current_depth: int = None) -> float:
        """Performs a rollout simulation from a given node in the tree. Given the
        current precursor, find the first successful reaction and return the new precursor.

        If the precursor is a building_block, return 1.0, else check the
        first successful reaction.

        If the reaction is not successful, return -1.0.

        If the reaction is successful, but the generated precursor are not
        the building_blocks and the precursor cannot be generated without
        exceeding current_depth threshold, return -0.5.

        If the reaction is successful, but the precursor are not the
        building_blocks and the precursor cannot be generated, return
        -1.0.

        :param precursor: The precursor to be evaluated.
        :param current_depth: The current depth of the tree.
        :return: The reward (value) assigned to the precursor.
        """

        # remaining depth budget for this simulation
        max_depth = self.config.max_depth - current_depth

        # precursor checking
        if precursor.is_building_block(self.building_blocks, self.config.min_mol_size):
            return 1.0

        if max_depth == 0:
            print("max depth reached in the beginning")

        # precursor simulating
        occurred_precursor = set()
        precursor_to_expand = deque([precursor])
        history = defaultdict(dict)
        rollout_depth = 0
        while precursor_to_expand:
            # Iterate through reactors and pick first successful reaction.
            # Check products of the reaction if you can find them in in-building_blocks data
            # If not, then add missed products to precursor_to_expand and try to decompose them
            if len(history) >= max_depth:
                reward = -0.5
                return reward

            current_precursor = precursor_to_expand.popleft()
            history[rollout_depth]["target"] = current_precursor
            occurred_precursor.add(current_precursor)

            # Pick the first successful reaction while iterating through reactors
            reaction_rule_applied = False
            for prob, rule, rule_id in self.policy_network.predict_reaction_rules(
                current_precursor, self.reaction_rules
            ):
                for products in apply_reaction_rule(current_precursor.molecule, rule):
                    if products:
                        reaction_rule_applied = True
                        break

                if reaction_rule_applied:
                    history[rollout_depth]["rule_index"] = rule_id
                    break

            if not reaction_rule_applied:
                reward = -1.0
                return reward

            # NOTE(review): ``products`` is the tuple left over from the loop
            # above; reaching this line implies a rule was applied successfully.
            products = tuple(Precursor(product) for product in products)
            history[rollout_depth]["products"] = products

            # check loops
            if any(x in occurred_precursor for x in products) and products:
                # sometimes manual can create a loop, when
                # print('occurred_precursor')
                reward = -1.0
                return reward

            if occurred_precursor.isdisjoint(products):
                # added number of atoms check: only non-building-block products
                # need further decomposition
                precursor_to_expand.extend(
                    [
                        x
                        for x in products
                        if not x.is_building_block(
                            self.building_blocks, self.config.min_mol_size
                        )
                    ]
                )
            rollout_depth += 1

        # queue exhausted: everything decomposed into building blocks
        reward = 1.0
        return reward
519
+
520
def report(self) -> str:
    """Returns the string representation of the tree."""

    target = str(self.nodes[1].precursors_to_expand[0])
    lines = [
        f"Tree for: {target}",
        f"Time: {round(self.curr_time, 1)} seconds",
        f"Number of nodes: {len(self)}",
        f"Number of iterations: {self.curr_iteration}",
        f"Number of visited nodes: {len(self.visited_nodes)}",
        f"Number of found routes: {len(self.winning_nodes)}",
    ]
    return "\n".join(lines)
+
532
def route_score(self, node_id: int) -> float:
    """Calculates the score of the route from the given node to the root node.
    The score is the sum of node values along the route divided by the squared
    route length, so shorter routes are favored.

    :param node_id: The id of the current given node.
    :return: The route score.
    """

    total_value = 0.0
    steps = 0
    current = node_id
    while current:
        steps += 1
        total_value += self.nodes_total_value[current]
        current = self.parents[current]

    return total_value / (steps**2)
+
549
def route_to_node(self, node_id: int) -> List[Node,]:
    """Returns the route (list of nodes) from the root node down to the given
    node.

    :param node_id: The id of the current node.
    :return: The list of nodes, ordered root-first.
    """

    id_chain = []
    current = node_id
    while current:
        id_chain.append(current)
        current = self.parents[current]
    id_chain.reverse()
    return [self.nodes[i] for i in id_chain]
+
563
def synthesis_route(self, node_id: int) -> Tuple[Reaction,]:
    """Given a node_id, return a tuple of reactions that represent the
    retrosynthetic route ending at the current node.

    :param node_id: The id of the current node.
    :return: The tuple of extracted reactions representing the synthesis route.
    """

    route_nodes = self.route_to_node(node_id)

    # Each parent/child pair along the route corresponds to one reaction:
    # the child's new precursors react to give the parent's current precursor.
    reactions = []
    for parent_node, child_node in zip(route_nodes, route_nodes[1:]):
        reactants = [p.molecule for p in child_node.new_precursors]
        products = [parent_node.curr_precursor.molecule]
        reactions.append(Reaction(reactants, products))

    for reaction in reactions:
        reaction.clean2d()
    return tuple(reversed(reactions))
+
585
def newickify(self, visits_threshold: int = 0, root_node_id: int = 1):
    """
    Adopted from https://stackoverflow.com/questions/50003007/how-to-convert-python-dictionary-to-newick-form-format.

    :param visits_threshold: The minimum number of visits for the given node.
    :param root_node_id: The id of the root node.

    :return: The newick string and meta dict.
    """
    rendered_nodes = set()

    def render(node_id: int) -> str:
        """Recursively renders the subtree rooted at node_id in Newick format.

        :param node_id: The id of the node to render.
        :return: A string representation of a node in a Newick format.
        """
        assert node_id not in rendered_nodes, "Error: The tree may not be circular!"
        visits = self.nodes_visit[node_id]

        rendered_nodes.add(node_id)
        if self.children[node_id]:
            # keep only sufficiently-visited children
            kept_children = [
                child
                for child in list(self.children[node_id])
                if self.nodes_visit[child] >= visits_threshold
            ]
            joined = ",".join(render(child) for child in kept_children)
            if joined:
                return f"({joined}){node_id}:{visits}"
            # all children filtered out: render as a leaf

        return f"{node_id}:{visits}"

    newick_string = render(root_node_id) + ";"

    # per-node metadata: (rounded value, rounded init value, visit count)
    meta = {
        node_id: (
            round(self.nodes_total_value[node_id], 3),
            round(self.nodes_init_value[node_id]),
            self.nodes_visit[node_id],
        )
        for node_id in rendered_nodes
    }

    return newick_string, meta
synplan/ml/__init__.py ADDED
File without changes
synplan/ml/networks/__init__.py ADDED
File without changes
synplan/ml/networks/modules.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Module containing basic pytorch architectures of policy and value neural networks."""
2
+
3
+ from abc import ABC, abstractmethod
4
+ from typing import Dict, List, Tuple, Union
5
+
6
+ import torch
7
+ from adabelief_pytorch import AdaBelief
8
+ from pytorch_lightning import LightningModule
9
+ from torch import Tensor
10
+ from torch.nn import GELU, Dropout, Linear, Module, ModuleDict, ModuleList
11
+ from torch.nn.functional import relu
12
+ from torch.optim.lr_scheduler import ReduceLROnPlateau
13
+ from torch_geometric.data.batch import Batch
14
+ from torch_geometric.nn.conv import GCNConv
15
+ from torch_geometric.nn.pool import global_add_pool
16
+
17
+
18
class GraphEmbedding(Module):
    """Needed to convert molecule atom vectors to the single vector using graph
    convolution."""

    def __init__(
        self, vector_dim: int = 512, dropout: float = 0.4, num_conv_layers: int = 5
    ):
        """Initializes a graph convolutional module. Needed to convert molecule atom
        vectors to the single vector using graph convolution.

        :param vector_dim: The dimensionality of the hidden layers and output layer of
            graph convolution module.
        :param dropout: The fraction of units randomly zeroed during training to
            prevent overfitting.
        :param num_conv_layers: The number of convolutional layers in a graph
            convolutional module.
        """

        super().__init__()
        # 11 is the size of the per-atom input feature vector
        self.expansion = Linear(11, vector_dim)
        self.dropout = Dropout(dropout)
        conv_stack = [
            GCNConv(vector_dim, vector_dim, improved=True)
            for _ in range(num_conv_layers)
        ]
        self.gcn_convs = ModuleList(conv_stack)

    def forward(self, graph: Batch, batch_size: int) -> Tensor:
        """Takes a graph as input and performs graph convolution on it.

        :param graph: The batch of molecular graphs, where each atom is represented by
            the atom/bond vector.
        :param batch_size: The size of the batch.
        :return: Graph embedding.
        """
        atom_features = graph.x.float()
        edge_index = graph.edge_index.long()
        # log(1 + x) transform of the raw atom features before expansion
        atom_features = self.expansion(torch.log(atom_features + 1))
        # residual GCN stack: each layer's output is added to its input
        for conv in self.gcn_convs:
            atom_features = atom_features + self.dropout(
                relu(conv(atom_features, edge_index))
            )

        return global_add_pool(atom_features, graph.batch, size=batch_size)
66
+
67
+
68
class GraphEmbeddingConcat(GraphEmbedding, Module):
    """Variant of GraphEmbedding that concatenates the outputs of all
    convolutional layers (jumping-knowledge style) instead of summing residual
    updates into a single feature vector."""

    def __init__(
        self, vector_dim: int = 512, dropout: float = 0.4, num_conv_layers: int = 8
    ):
        """Initializes the concatenating graph convolutional module.

        :param vector_dim: The target dimensionality of the concatenated output;
            each layer works with vector_dim // num_conv_layers features.
            NOTE(review): if vector_dim is not divisible by num_conv_layers,
            the actual output dimension is gcn_dim * num_conv_layers, which is
            smaller than vector_dim — confirm callers account for this.
        :param dropout: The dropout fraction applied after each layer.
        :param num_conv_layers: The number of convolutional layers.
        """
        # NOTE(review): super().__init__() builds the default GraphEmbedding
        # layers first and the attributes below then replace them; the default
        # layers are constructed only to be discarded.
        super().__init__()

        # per-layer feature width so that the concatenation approximates vector_dim
        gcn_dim = vector_dim // num_conv_layers

        self.expansion = Linear(11, gcn_dim)
        self.dropout = Dropout(dropout)
        self.gcn_convs = ModuleList(
            [
                ModuleDict(
                    {
                        "gcn": GCNConv(gcn_dim, gcn_dim, improved=True),
                        "activation": GELU(),
                    }
                )
                for _ in range(num_conv_layers)
            ]
        )

    def forward(self, graph: Batch, batch_size: int) -> Tensor:
        """Takes a graph as input and performs graph convolution on it.

        :param graph: The batch of molecular graphs, where each atom is represented by
            the atom/bond vector.
        :param batch_size: The size of the batch.
        :return: Graph embedding (concatenation of all layer outputs).
        """

        atoms, connections = graph.x.float(), graph.edge_index.long()
        # log(1 + x) transform of the raw atom features before expansion
        atoms = torch.log(atoms + 1)
        atoms = self.expansion(atoms)

        # run the stack sequentially, collecting each layer's output
        collected_atoms = []
        for gcn_convs in self.gcn_convs:
            atoms = gcn_convs["gcn"](atoms, connections)
            atoms = gcn_convs["activation"](atoms)
            atoms = self.dropout(atoms)
            collected_atoms.append(atoms)

        # concatenate per-layer features along the feature dimension
        atoms = torch.cat(collected_atoms, dim=-1)

        return global_add_pool(atoms, graph.batch, size=batch_size)
+
116
+
117
class MCTSNetwork(LightningModule, ABC):
    """Basic class for policy and value networks.

    Provides the shared graph embedder, Lightning train/validation/test steps
    and optimizer configuration; subclasses implement ``forward`` and
    ``_get_loss``.
    """

    def __init__(
        self,
        vector_dim: int,
        batch_size: int,
        dropout: float = 0.4,
        num_conv_layers: int = 5,
        learning_rate: float = 0.001,
        gcn_concat: bool = False,
    ):
        """The basic class for MCTS graph convolutional neural networks (policy and
        value network).

        :param vector_dim: The dimensionality of the hidden layers and output layer of
            graph convolution module.
        :param batch_size: The batch size used for logging the metrics.
        :param dropout: Dropout is a regularization technique used in neural networks to
            prevent overfitting.
        :param num_conv_layers: The number of convolutional layers in a graph
            convolutional module.
        :param learning_rate: The learning rate determines how quickly the model learns
            from the training data.
        :param gcn_concat: If True, use the embedder that concatenates the outputs
            of all convolutional layers (GraphEmbeddingConcat) instead of the
            additive residual embedder (GraphEmbedding).
        """
        super().__init__()
        if gcn_concat:
            self.embedder = GraphEmbeddingConcat(vector_dim, dropout, num_conv_layers)
        else:
            self.embedder = GraphEmbedding(vector_dim, dropout, num_conv_layers)
        self.batch_size = batch_size
        self.lr = learning_rate

    @abstractmethod
    def forward(self, batch: Batch) -> Tensor:
        """The forward function takes a batch of input data and performs forward
        propagation through the neural network.

        :param batch: The batch of molecular graphs processed together in a single
            forward pass through the neural network.
        """

    @abstractmethod
    def _get_loss(self, batch: Batch) -> Dict[str, Tensor]:
        """Calculate the loss and metrics for a given batch of data.

        FIX: the return annotation was ``Tensor`` but every implementation
        returns a metrics dictionary, and ``training_step`` below iterates
        ``metrics.items()`` and extracts ``metrics["loss"]``.

        :param batch: The batch of input data that is used to compute the loss.
        :return: A dictionary of metrics that must contain at least a "loss" entry.
        """

    def training_step(self, batch: Batch, batch_idx: int) -> Tensor:
        """Calculates the loss for a given training batch and logs the loss value.

        :param batch: The batch of data that is used for training.
        :param batch_idx: The index of the batch.
        :return: The value of the training loss.
        """
        metrics = self._get_loss(batch)
        for name, value in metrics.items():
            self.log(
                "train_" + name,
                value,
                prog_bar=True,
                on_step=True,
                on_epoch=True,
                batch_size=self.batch_size,
            )
        return metrics["loss"]

    def validation_step(self, batch: Batch, batch_idx: int) -> None:
        """Calculates the loss for a given validation batch and logs the loss value.

        :param batch: The batch of data that is used for validation.
        :param batch_idx: The index of the batch.
        """
        metrics = self._get_loss(batch)
        for name, value in metrics.items():
            self.log("val_" + name, value, on_epoch=True, batch_size=self.batch_size)

    def test_step(self, batch: Batch, batch_idx: int) -> None:
        """Calculates the loss for a given test batch and logs the loss value.

        :param batch: The batch of data that is used for testing.
        :param batch_idx: The index of the batch.
        """
        metrics = self._get_loss(batch)
        for name, value in metrics.items():
            self.log("test_" + name, value, on_epoch=True, batch_size=self.batch_size)

    def configure_optimizers(
        self,
    ) -> Tuple[List[AdaBelief], List[Dict[str, Union[bool, str, ReduceLROnPlateau]]]]:
        """Returns an optimizer and a learning rate scheduler for training a model using
        the AdaBelief optimizer and ReduceLROnPlateau scheduler.

        :return: The optimizer and a scheduler.
        """

        optimizer = AdaBelief(
            self.parameters(),
            lr=self.lr,
            eps=1e-16,
            betas=(0.9, 0.999),
            weight_decouple=True,
            rectify=True,
            weight_decay=0.01,
            print_change_log=False,
        )

        # reduce the learning rate when the validation loss plateaus
        lr_scheduler = ReduceLROnPlateau(
            optimizer, patience=3, factor=0.8, min_lr=5e-5, verbose=True
        )
        scheduler = {
            "scheduler": lr_scheduler,
            "reduce_on_plateau": True,
            "monitor": "val_loss",
        }

        return [optimizer], [scheduler]
synplan/ml/networks/policy.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Module containing main class for policy network."""
2
+
3
+ from abc import ABC
4
+ from typing import Dict
5
+
6
+ import torch
7
+ from pytorch_lightning import LightningModule
8
+ from torch import Tensor
9
+ from torch.nn import Linear
10
+ from torch.nn.functional import binary_cross_entropy_with_logits, cross_entropy, one_hot
11
+ from torch_geometric.data.batch import Batch
12
+ from torchmetrics.functional.classification import f1_score, recall, specificity
13
+
14
+ from synplan.ml.networks.modules import MCTSNetwork
15
+
16
+
17
class PolicyNetwork(MCTSNetwork, LightningModule, ABC):
    """Policy network.

    Supports two modes: "ranking" (softmax distribution over rules) and
    "filtering" (independent sigmoids plus a separate priority head).
    """

    def __init__(
        self,
        *args,
        n_rules: int,
        vector_dim: int,
        policy_type: str = "ranking",
        **kwargs
    ):
        """Initializes a policy network with the given number of reaction rules (output
        dimension) and vector graph embedding dimension, and creates linear layers for
        predicting the regular and priority reaction rules.

        :param n_rules: The number of reaction rules in the policy network.
        :param vector_dim: The dimensionality of the input vectors.
        :param policy_type: Either "ranking" or "filtering".
        """
        super().__init__(vector_dim, *args, **kwargs)
        self.save_hyperparameters()
        self.policy_type = policy_type
        self.n_rules = n_rules
        self.y_predictor = Linear(vector_dim, n_rules)

        # the priority head only exists in filtering mode
        if self.policy_type == "filtering":
            self.priority_predictor = Linear(vector_dim, n_rules)

    def forward(self, batch: Batch) -> Tensor:
        """Takes a molecular graph, applies a graph convolution and sigmoid/softmax
        layers to predict regular and priority reaction rules.

        :param batch: The input batch of molecular graphs.
        :return: In "ranking" mode, the softmax probabilities over rules; in
            "filtering" mode, a tuple of sigmoid probabilities for regular and
            priority reaction rules.
        :raises ValueError: If the policy type is unknown.
        """
        x = self.embedder(batch, self.batch_size)
        y = self.y_predictor(x)

        if self.policy_type == "ranking":
            y = torch.softmax(y, dim=-1)
            return y

        if self.policy_type == "filtering":
            y = torch.sigmoid(y)
            priority = torch.sigmoid(self.priority_predictor(x))
            return y, priority

        # FIX: previously an unknown policy_type silently returned None.
        raise ValueError(f"Unknown policy type: {self.policy_type}")

    def _get_loss(self, batch: Batch) -> Dict[str, Tensor]:
        """Calculates the loss and various classification metrics for a given batch for
        reaction rules prediction.

        :param batch: The batch of molecular graphs.
        :return: A dictionary with loss value and balanced accuracy of reaction rules
            prediction.
        :raises ValueError: If the policy type is unknown.
        """
        true_y = batch.y_rules.long()
        x = self.embedder(batch, self.batch_size)
        pred_y = self.y_predictor(x)

        if self.policy_type == "ranking":
            true_one_hot = one_hot(true_y, num_classes=self.n_rules)
            loss = cross_entropy(pred_y, true_one_hot.float())
            # balanced accuracy = mean of recall and specificity
            ba_y = (
                recall(pred_y, true_y, task="multiclass", num_classes=self.n_rules)
                + specificity(
                    pred_y, true_y, task="multiclass", num_classes=self.n_rules
                )
            ) / 2
            f1_y = f1_score(pred_y, true_y, task="multiclass", num_classes=self.n_rules)

            metrics = {"loss": loss, "balanced_accuracy_y": ba_y, "f1_score_y": f1_y}

        elif self.policy_type == "filtering":
            loss_y = binary_cross_entropy_with_logits(pred_y, true_y.float())

            ba_y = (
                recall(pred_y, true_y, task="multilabel", num_labels=self.n_rules)
                + specificity(
                    pred_y, true_y, task="multilabel", num_labels=self.n_rules
                )
            ) / 2

            f1_y = f1_score(pred_y, true_y, task="multilabel", num_labels=self.n_rules)

            true_priority = batch.y_priority.float()
            pred_priority = self.priority_predictor(x)
            loss_priority = binary_cross_entropy_with_logits(
                pred_priority, true_priority
            )

            # total loss combines rule prediction and priority prediction
            loss = loss_y + loss_priority

            true_priority = true_priority.long()
            ba_priority = (
                recall(
                    pred_priority,
                    true_priority,
                    task="multilabel",
                    num_labels=self.n_rules,
                )
                + specificity(
                    pred_priority,
                    true_priority,
                    task="multilabel",
                    num_labels=self.n_rules,
                )
            ) / 2

            f1_priority = f1_score(
                pred_priority, true_priority, task="multilabel", num_labels=self.n_rules
            )

            metrics = {
                "loss": loss,
                "balanced_accuracy_y": ba_y,
                "f1_score_y": f1_y,
                "balanced_accuracy_priority": ba_priority,
                "f1_score_priority": f1_priority,
            }

        else:
            # FIX: previously an unknown policy_type raised UnboundLocalError
            # on the return below; fail with a clear message instead.
            raise ValueError(f"Unknown policy type: {self.policy_type}")

        return metrics
synplan/ml/networks/value.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Module containing main class for value network."""
2
+
3
+ from abc import ABC
4
+ from typing import Any, Dict
5
+
6
+ import torch
7
+ from pytorch_lightning import LightningModule
8
+ from torch import Tensor
9
+ from torch.nn import Linear
10
+ from torch.nn.functional import binary_cross_entropy_with_logits
11
+ from torch_geometric.data.batch import Batch
12
+ from torchmetrics.functional.classification import (
13
+ binary_f1_score,
14
+ binary_recall,
15
+ binary_specificity,
16
+ )
17
+
18
+ from synplan.ml.networks.modules import MCTSNetwork
19
+
20
+
21
class ValueNetwork(MCTSNetwork, LightningModule, ABC):
    """Value network."""

    def __init__(self, vector_dim: int, *args: Any, **kwargs: Any) -> None:
        """Initializes a value network, and creates a linear layer for predicting the
        synthesisability of a given precursor represented by a molecular graph.

        :param vector_dim: The dimensionality of the output linear layer.
        """
        super().__init__(vector_dim, *args, **kwargs)
        self.save_hyperparameters()
        # single-logit head: synthesisable vs not
        self.predictor = Linear(vector_dim, 1)

    def forward(self, batch) -> torch.Tensor:
        """Embeds a batch of molecular graphs and returns the predicted
        synthesisability as a sigmoid probability.

        :param batch: The batch of molecular graphs.
        :return: The predicted synthesisability (between 0 and 1).
        """
        embedding = self.embedder(batch, self.batch_size)
        return torch.sigmoid(self.predictor(embedding))

    def _get_loss(self, batch: Batch) -> Dict[str, Tensor]:
        """Calculates the loss and classification metrics for a given batch for the
        precursor synthesisability prediction.

        :param batch: The batch of molecular graphs.
        :return: The dictionary with loss value, balanced accuracy and F1-score of
            precursor synthesisability prediction.
        """

        target = torch.unsqueeze(batch.y.float(), -1)
        embedding = self.embedder(batch, self.batch_size)
        logits = self.predictor(embedding)
        # BCE on the raw logits (sigmoid applied internally)
        loss = binary_cross_entropy_with_logits(logits, target)

        target = target.long()
        # balanced accuracy = mean of recall and specificity
        balanced_accuracy = (
            binary_recall(logits, target) + binary_specificity(logits, target)
        ) / 2
        return {
            "loss": loss,
            "balanced_accuracy": balanced_accuracy,
            "f1_score": binary_f1_score(logits, target),
        }
synplan/ml/training/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
# NOTE(review): names from .supervised are both star-imported and imported
# explicitly below; the explicit imports and __all__ define the intended
# public API of this package.
from .supervised import *
from .preprocessing import ValueNetworkDataset, mol_to_pyg, MENDEL_INFO
from .supervised import create_policy_dataset, run_policy_training

# Public API of synplan.ml.training.
__all__ = [
    "ValueNetworkDataset",
    "mol_to_pyg",
    "MENDEL_INFO",
    "create_policy_dataset",
    "run_policy_training",
]
synplan/ml/training/preprocessing.py ADDED
@@ -0,0 +1,516 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Module containing functions for preparation of the training sets for policy and value
2
+ network."""
3
+
4
+ import logging
5
+ import os
6
+ import pickle
7
+ from abc import ABC
8
+ from typing import Any, Dict, List, Optional, Tuple
9
+
10
+ import ray
11
+ import torch
12
+ from CGRtools import smiles
13
+ from CGRtools.containers import MoleculeContainer
14
+ from CGRtools.exceptions import InvalidAromaticRing
15
+ from CGRtools.reactor import Reactor
16
+ from ray.util.queue import Empty, Queue
17
+ from torch import Tensor
18
+ from torch_geometric.data import InMemoryDataset
19
+ from torch_geometric.data.data import Data
20
+ from torch_geometric.data.makedirs import makedirs
21
+ from torch_geometric.transforms import ToUndirected
22
+ from tqdm import tqdm
23
+
24
+ from synplan.chem.utils import unite_molecules
25
+ from synplan.utils.files import ReactionReader
26
+ from synplan.utils.loading import load_reaction_rules
27
+
28
+
29
class ValueNetworkDataset(InMemoryDataset, ABC):
    """Value network dataset."""

    def __init__(self, extracted_precursor: Dict[str, float]) -> None:
        """Initializes a value network dataset object.

        :param extracted_precursor: The dictionary with the precursors extracted from
            the built search trees and their labels.
        """
        super().__init__(None, None, None)

        if extracted_precursor:
            self.data, self.slices = self.graphs_from_extracted_precursor(
                extracted_precursor
            )

    @staticmethod
    def mol_to_graph(molecule: MoleculeContainer, label: float) -> Optional[Data]:
        """Converts a molecule to a PyTorch geometric graph and assigns the reward
        value (label) to it.

        :param molecule: The input molecule.
        :param label: The label (solved/unsolved routes in the tree) of the molecule
            (precursor).
        :return: A PyTorch Geometric graph representation of a molecule, or None if
            the molecule is too small or could not be converted.
        """
        # molecules with fewer than 3 atoms are not converted
        if len(molecule) > 2:
            graph = mol_to_pyg(molecule)
            if graph:
                graph.y = torch.tensor([label])
                return graph

        return None

    def graphs_from_extracted_precursor(
        self, extracted_precursor: Dict[str, float]
    ) -> Tuple[Data, Dict]:
        """Converts the precursors extracted from the search trees to PyTorch
        geometric graphs.

        :param extracted_precursor: The dictionary with the precursors extracted from
            the built search trees and their labels.
        :return: The PyTorch geometric graphs and slices.
        """
        graphs = []
        for precursor_smiles, label in extracted_precursor.items():
            graph = self.mol_to_graph(smiles(precursor_smiles), label)
            if graph:
                graphs.append(graph)
        return self.collate(graphs)
+
82
+
83
class RankingPolicyDataset(InMemoryDataset):
    """Ranking policy network dataset."""

    def __init__(self, reactions_path: str, reaction_rules_path: str, output_path: str):
        """Initializes a policy network dataset.

        :param reactions_path: The path to the file containing the reaction data used
            for extraction of reaction rules.
        :param reaction_rules_path: The path to the file containing the reaction rules.
        :param output_path: The output path to the file where policy network dataset
            will be saved.
        """
        super().__init__(None, None, None)

        self.reactions_path = reactions_path
        self.reaction_rules_path = reaction_rules_path
        self.output_path = output_path

        # reuse a previously prepared dataset when available
        if output_path and os.path.exists(output_path):
            self.data, self.slices = torch.load(self.output_path)
        else:
            self.data, self.slices = self.prepare_data()

    @property
    def num_classes(self) -> int:
        # number of distinct reaction rule labels in the dataset
        return self._infer_num_classes(self._data.y_rules)

    def prepare_data(self) -> Tuple[Data, Dict[str, Tensor]]:
        """Prepares data by loading reaction rules, preprocessing the molecules,
        collating the data, and returning the data and slices.

        :return: The PyTorch geometric graphs and slices.
        """

        with open(self.reaction_rules_path, "rb") as inp:
            reaction_rules = pickle.load(inp)
        # sort rules by popularity (number of supporting reactions), descending,
        # so that rule index 0 is the most popular rule
        reaction_rules = sorted(reaction_rules, key=lambda x: len(x[1]), reverse=True)

        # map each reaction id to the index of the rule it supports
        reaction_rule_pairs = {}
        for rule_i, (_, reactions_ids) in enumerate(reaction_rules):
            for reaction_id in reactions_ids:
                reaction_rule_pairs[reaction_id] = rule_i
        reaction_rule_pairs = dict(sorted(reaction_rule_pairs.items()))

        list_of_graphs = []
        with ReactionReader(self.reactions_path) as reactions:

            for reaction_id, reaction in tqdm(
                enumerate(reactions),
                desc="Number of reactions processed: ",
                bar_format="{desc}{n} [{elapsed}]",
            ):

                rule_id = reaction_rule_pairs.get(reaction_id)
                # FIX: the previous truthiness check ("if rule_id:") dropped every
                # reaction mapped to rule index 0 — the most popular rule.
                if rule_id is not None:
                    try:  # MENDEL_INFO does not contain cadmium (Cd) properties
                        molecule = unite_molecules(reaction.products)
                        pyg_graph = mol_to_pyg(molecule)

                    except (
                        Exception
                    ) as e:  # TypeError: can't assign a NoneType to a torch.ByteTensor
                        logging.debug(e)
                        continue

                    if pyg_graph is not None:
                        pyg_graph.y_rules = torch.tensor([rule_id], dtype=torch.long)
                        list_of_graphs.append(pyg_graph)

        data, slices = self.collate(list_of_graphs)
        if self.output_path:
            makedirs(os.path.dirname(self.output_path))
            torch.save((data, slices), self.output_path)

        return data, slices
+
161
+
162
class FilteringPolicyDataset(InMemoryDataset):
    """Filtering policy network dataset."""

    def __init__(
        self,
        molecules_path: str,
        reaction_rules_path: str,
        output_path: str,
        num_cpus: int,
    ) -> None:
        """Initializes a policy network dataset object.

        :param molecules_path: The path to the file containing the molecules for
            reaction rule appliance.
        :param reaction_rules_path: The path to the file containing the reaction rules.
        :param output_path: The output path to the file where policy network dataset
            will be stored.
        :param num_cpus: The number of CPUs to be used for the dataset preparation.
        :return: None.
        """
        super().__init__(None, None, None)

        self.molecules_path = molecules_path
        self.reaction_rules_path = reaction_rules_path
        self.output_path = output_path
        self.num_cpus = num_cpus
        # per-worker queue capacity multiplier used when sizing the ray Queue
        self.batch_size = 100

        # reuse a previously prepared dataset when available
        if output_path and os.path.exists(output_path):
            self.data, self.slices = torch.load(self.output_path)
        else:
            self.data, self.slices = self.prepare_data()

    @property
    def num_classes(self) -> int:
        # one column per reaction rule in the multilabel target
        return self._data.y_rules.shape[1]

    def prepare_data(self) -> Tuple[Data, Dict]:
        """Prepares data by loading reaction rules, initializing Ray, preprocessing the
        molecules, collating the data, and returning the data and slices.

        :return: The PyTorch geometric graphs and slices.
        """

        ray.init(num_cpus=self.num_cpus, ignore_reinit_error=True)
        reaction_rules = load_reaction_rules(self.reaction_rules_path)
        # share the rules with all workers through the object store
        reaction_rules_ids = ray.put(reaction_rules)

        to_process = Queue(maxsize=self.batch_size * self.num_cpus)
        processed_data = []
        # one long-running worker per CPU, all consuming from the shared queue
        # NOTE(review): workers presumably stop when the queue stays empty for
        # some timeout (see preprocess_filtering_policy_molecules) — confirm
        # there is no shutdown race with the producer loop below.
        results_ids = [
            preprocess_filtering_policy_molecules.remote(to_process, reaction_rules_ids)
            for _ in range(self.num_cpus)
        ]

        with open(self.molecules_path, "r", encoding="utf-8") as inp_data:
            for molecule in tqdm(
                inp_data.read().splitlines(),
                desc="Number of molecules processed: ",
                bar_format="{desc}{n} [{elapsed}]",
            ):

                to_process.put(molecule)

        # block until all workers finish and flatten their per-worker results
        results = [graph for res in ray.get(results_ids) if res for graph in res]
        processed_data.extend(results)

        ray.shutdown()

        # sparse targets are densified before collation
        for pyg in processed_data:
            pyg.y_rules = pyg.y_rules.to_dense()
            pyg.y_priority = pyg.y_priority.to_dense()

        data, slices = self.collate(processed_data)
        if self.output_path:
            makedirs(os.path.dirname(self.output_path))
            torch.save((data, slices), self.output_path)

        return data, slices
+
242
+
243
def reaction_rules_appliance(
    molecule: MoleculeContainer, reaction_rules: List[Reactor]
) -> Tuple[List[int], List[int]]:
    """Applies each reaction rule from the list of reaction rules to a given molecule
    and returns the indexes of the successfully applied regular and prioritized reaction
    rules.

    :param molecule: The input molecule.
    :param reaction_rules: The list of reaction rules.
    :return: The two lists of indexes of successfully applied regular reaction rules and
        priority reaction rules.
    """

    applied_rules, priority_rules = [], []
    for i, rule in enumerate(reaction_rules):

        rule_applied = False
        rule_prioritized = False

        try:
            for reaction in rule([molecule]):
                for prod in reaction.products:
                    prod.kekule()
                    # NOTE(review): a truthy check_valence() (valence errors)
                    # only breaks out of the product loop; rule_applied is
                    # still set True below — confirm this is intended
                    if prod.check_valence():
                        break
                rule_applied = True

                # check priority rules
                if len(reaction.products) > 1:
                    # coupling retro rule: all fragments sizeable and the
                    # combined product grows by fewer than 6 atoms
                    if all(len(mol) > 6 for mol in reaction.products):
                        if (
                            sum(len(mol) for mol in reaction.products)
                            - len(reaction.reactants[0])
                            < 6
                        ):
                            rule_prioritized = True
                else:
                    # cyclization retro rule: ring count decreases from
                    # reactants to products (SSSR comparison)
                    if sum(len(mol.sssr) for mol in reaction.products) < sum(
                        len(mol.sssr) for mol in reaction.reactants
                    ):
                        rule_prioritized = True
            # record the rule index once, regardless of how many reactions it produced
            if rule_applied:
                applied_rules.append(i)
            # a rule can be both applied and prioritized
            if rule_prioritized:
                priority_rules.append(i)
        except Exception as e:
            # best-effort: a failing rule is skipped, not fatal
            logging.debug(e)
            continue

    return applied_rules, priority_rules
297
+
298
+
299
@ray.remote
def preprocess_filtering_policy_molecules(
    to_process: Queue, reaction_rules: List[Reactor]
) -> List[Optional[Data]]:
    """Preprocesses a list of molecules by applying reaction rules and converting
    molecules into PyTorch geometric graphs. Successfully applied reaction rules are
    converted to binary vectors for policy network training.

    Runs as a Ray worker: drains SMILES from the shared queue until a 30 s
    poll times out (``Empty``), which is treated as end-of-input.

    :param to_process: The queue containing SMILES of molecules to be converted to the
        training data.
    :param reaction_rules: The list of reaction rules.
    :return: The list of PyGraph objects.
    """

    pyg_graphs = []
    while True:
        try:
            molecule = smiles(to_process.get(timeout=30))
            # smiles() may return a non-molecule container; skip those
            if not isinstance(molecule, MoleculeContainer):
                continue

            # indexes of rules that apply to this molecule (and of the
            # prioritized subset)
            applied_rules, priority_rules = reaction_rules_appliance(
                molecule, reaction_rules
            )

            # sparse one-hot label vectors over all rules; densified later by
            # the dataset after collection
            y_rules = torch.sparse_coo_tensor(
                [applied_rules],
                torch.ones(len(applied_rules)),
                (len(reaction_rules),),
                dtype=torch.uint8,
            )
            y_priority = torch.sparse_coo_tensor(
                [priority_rules],
                torch.ones(len(priority_rules)),
                (len(reaction_rules),),
                dtype=torch.uint8,
            )

            # add a batch dimension: (1, num_rules)
            y_rules = torch.unsqueeze(y_rules, 0)
            y_priority = torch.unsqueeze(y_priority, 0)

            pyg_graph = mol_to_pyg(molecule)
            if not pyg_graph:
                continue
            pyg_graph.y_rules = y_rules
            pyg_graph.y_priority = y_priority
            pyg_graphs.append(pyg_graph)

        # NOTE(review): only queue Empty is caught — any other exception kills
        # the worker and silently drops its accumulated graphs; confirm
        except Empty:
            break

    return pyg_graphs
352
+
353
+
354
def atom_to_vector(atom: Any) -> Tensor:
    """Encode one atom as a length-8 feature vector.

    The features, in order:

    1. Atomic number
    2. Period
    3. Group
    4. Number of valence electrons + atom's charge
    5. Shell
    6. Total number of hydrogens
    7. Whether the atom is in a ring
    8. Number of neighbors

    :param atom: The atom object.

    :return: The uint8 feature tensor of shape (8,).
    """
    period, group, shell, electrons = MENDEL_INFO[atom.atomic_symbol]
    features = [
        atom.atomic_number,
        period,
        group,
        electrons + atom.charge,
        shell,
        atom.total_hydrogens,
        int(atom.in_ring),
        atom.neighbors,
    ]
    return torch.tensor(features, dtype=torch.uint8)
382
+
383
+
384
def bonds_to_vector(molecule: MoleculeContainer, atom_ind: int) -> Tensor:
    """Count the bonds of a single atom, grouped by bond order.

    :param molecule: The given molecule.
    :param atom_ind: The index of the atom in the molecule whose bonds are
        counted.
    :return: A uint8 tensor of size 3; element ``k`` holds the number of bonds
        of order ``k + 1`` attached to the atom with the given index.
    """

    counts = torch.zeros(3, dtype=torch.uint8)
    for bond in molecule._bonds[atom_ind].values():
        order_slot = int(bond) - 1
        counts[order_slot] += 1
    return counts
399
+
400
+
401
def mol_to_matrix(molecule: MoleculeContainer) -> Tensor:
    """Encode a molecule as a per-atom feature matrix of shape (num_atoms, 11).

    Each row corresponds to one atom: columns 0-7 hold the atom features from
    ``atom_to_vector`` and columns 8-10 hold the bond-order counts from
    ``bonds_to_vector``. (The previous docstring claimed a width of 12, but
    the matrix has always been 11 columns wide.)

    :param molecule: The molecule to be converted to a feature matrix.
    :return: The uint8 atoms feature matrix.
    """

    atoms_vectors = torch.zeros((len(molecule), 11), dtype=torch.uint8)
    # single pass over the atoms (the original looped twice);
    # CGRtools numbers atoms from 1, rows are 0-based
    for n, atom in molecule.atoms():
        atoms_vectors[n - 1][:8] = atom_to_vector(atom)
        atoms_vectors[n - 1][8:] = bonds_to_vector(molecule, n)

    return atoms_vectors
416
+
417
+
418
def mol_to_pyg(
    molecule: MoleculeContainer, canonicalize: bool = True
) -> Optional[Data]:
    """Converts a single molecule into an undirected PyTorch Geometric graph
    with per-atom feature rows.

    :param molecule: The molecule to be converted to PyTorch Geometric graph.
    :param canonicalize: If True, the input molecule is canonicalized.
    :return: The PyGraph object, or None if the molecule is a single atom or
        fails kekulization / valence checks.
    """

    if len(molecule) == 1:  # to avoid a precursor to be a single atom
        return None

    # work on a copy: canonicalize/kekule/remap mutate the molecule
    tmp_molecule = molecule.copy()
    try:
        if canonicalize:
            tmp_molecule.canonicalize()
        tmp_molecule.kekule()
        # truthy check_valence() means valence errors -> unusable molecule
        if tmp_molecule.check_valence():
            return None
    except InvalidAromaticRing:
        return None

    # remapping target for torch_geometric because
    # it is necessary that the elements in edge_index only hold nodes_idx in the range { 0, ..., num_nodes - 1}
    new_mappings = {n: i for i, (n, _) in enumerate(tmp_molecule.atoms(), 1)}
    tmp_molecule.remap(new_mappings)

    # get edge indexes from target mapping (atom numbers are 1-based, shift to 0-based)
    edge_index = []
    for atom, neighbour, bond in tmp_molecule.bonds():
        edge_index.append([atom - 1, neighbour - 1])
    edge_index = torch.tensor(edge_index, dtype=torch.long)

    # per-atom feature matrix, see mol_to_matrix
    x = mol_to_matrix(tmp_molecule)

    mol_pyg_graph = Data(x=x, edge_index=edge_index.t().contiguous())
    # CGRtools yields each bond once; mirror the edges to make the graph undirected
    mol_pyg_graph = ToUndirected()(mol_pyg_graph)

    assert mol_pyg_graph.is_undirected()

    return mol_pyg_graph
462
+
463
+
464
# Periodic-table lookup used by atom_to_vector:
# atomic symbol -> (period, group, shell, electrons).
# group is None for the lanthanides (Ce, Dy, Er, Gd, Nd, Pr, Sm, Yb),
# which have no conventional group number.
MENDEL_INFO = {
    "Ag": (5, 11, 1, 1),
    "Al": (3, 13, 2, 1),
    "Ar": (3, 18, 2, 6),
    "As": (4, 15, 2, 3),
    "B": (2, 13, 2, 1),
    "Ba": (6, 2, 1, 2),
    "Bi": (6, 15, 2, 3),
    "Br": (4, 17, 2, 5),
    "C": (2, 14, 2, 2),
    "Ca": (4, 2, 1, 2),
    "Ce": (6, None, 1, 2),
    "Cl": (3, 17, 2, 5),
    "Cr": (4, 6, 1, 1),
    "Cs": (6, 1, 1, 1),
    "Cu": (4, 11, 1, 1),
    "Dy": (6, None, 1, 2),
    "Er": (6, None, 1, 2),
    "F": (2, 17, 2, 5),
    "Fe": (4, 8, 1, 2),
    "Ga": (4, 13, 2, 1),
    "Gd": (6, None, 1, 2),
    "Ge": (4, 14, 2, 2),
    "Hg": (6, 12, 1, 2),
    "I": (5, 17, 2, 5),
    "In": (5, 13, 2, 1),
    "K": (4, 1, 1, 1),
    "La": (6, 3, 1, 2),
    "Li": (2, 1, 1, 1),
    "Mg": (3, 2, 1, 2),
    "Mn": (4, 7, 1, 2),
    "N": (2, 15, 2, 3),
    "Na": (3, 1, 1, 1),
    "Nd": (6, None, 1, 2),
    "O": (2, 16, 2, 4),
    "P": (3, 15, 2, 3),
    "Pb": (6, 14, 2, 2),
    "Pd": (5, 10, 3, 10),
    "Pr": (6, None, 1, 2),
    "Rb": (5, 1, 1, 1),
    "S": (3, 16, 2, 4),
    "Sb": (5, 15, 2, 3),
    "Se": (4, 16, 2, 4),
    "Si": (3, 14, 2, 2),
    "Sm": (6, None, 1, 2),
    "Sn": (5, 14, 2, 2),
    "Sr": (5, 2, 1, 2),
    "Te": (5, 16, 2, 4),
    "Ti": (4, 4, 1, 2),
    "Tl": (6, 13, 2, 1),
    "Yb": (6, None, 1, 2),
    "Zn": (4, 12, 1, 2),
}
synplan/ml/training/reinforcement.py ADDED
@@ -0,0 +1,379 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Module containing functions for running value network tuning with reinforcement learning
2
+ approach."""
3
+
4
+ import os
5
+ import random
6
+ from collections import defaultdict
7
+ from pathlib import Path
8
+ from random import shuffle
9
+ from typing import Dict, List
10
+
11
+ import torch
12
+ from CGRtools.containers import MoleculeContainer
13
+ from pytorch_lightning import Trainer
14
+ from torch.utils.data import random_split
15
+ from torch_geometric.data.lightning import LightningDataset
16
+
17
+ from synplan.chem.precursor import compose_precursors
18
+ from synplan.mcts.evaluation import ValueNetworkFunction
19
+ from synplan.mcts.expansion import PolicyNetworkFunction
20
+ from synplan.mcts.tree import Tree
21
+ from synplan.ml.networks.value import ValueNetwork
22
+ from synplan.ml.training.preprocessing import ValueNetworkDataset
23
+ from synplan.utils.config import (
24
+ PolicyNetworkConfig,
25
+ TuningConfig,
26
+ TreeConfig,
27
+ ValueNetworkConfig,
28
+ )
29
+ from synplan.utils.files import MoleculeReader
30
+ from synplan.utils.loading import (
31
+ load_building_blocks,
32
+ load_reaction_rules,
33
+ load_value_net,
34
+ )
35
+ from synplan.utils.logging import DisableLogger, HiddenPrints
36
+
37
+
38
def create_value_network(value_config: ValueNetworkConfig) -> ValueNetwork:
    """Builds a fresh value network and persists its initial checkpoint.

    :param value_config: The value network configuration.
    :return: The newly created value network to be trained/tuned.
    """

    network = ValueNetwork(
        vector_dim=value_config.vector_dim,
        batch_size=value_config.batch_size,
        dropout=value_config.dropout,
        num_conv_layers=value_config.num_conv_layers,
        learning_rate=value_config.learning_rate,
    )

    # save the untrained weights so subsequent tuning rounds can load them
    checkpoint_path = Path(value_config.weights_path)
    with DisableLogger(), HiddenPrints():
        trainer = Trainer()
        trainer.strategy.connect(network)
        trainer.save_checkpoint(checkpoint_path)

    return network
60
+
61
+
62
def create_targets_batch(
    targets: "List[MoleculeContainer]", batch_size: int
) -> "List[List[MoleculeContainer]]":
    """Creates the targets batches for planning simulations and value network tuning.

    :param targets: The list of target molecules.
    :param batch_size: The size of each target batch (must be positive); the
        last batch may be smaller than ``batch_size``.
    :return: The list of lists corresponding to each target batch.
    """

    num_targets = len(targets)

    if num_targets // batch_size == 0:
        print(f"1 batch were created with {num_targets} molecules")
    else:
        # ceil(num_targets / batch_size) full-or-partial batches
        num_batches = num_targets // batch_size + int(bool(num_targets % batch_size))
        print(
            f"{num_batches} batches were created with {batch_size} molecules each"
        )

    # idiomatic slicing replaces the original manual index bookkeeping;
    # a slice past the end of the list is truncated automatically
    return [targets[i : i + batch_size] for i in range(0, num_targets, batch_size)]
94
+
95
+
96
def run_tree_search(
    target: MoleculeContainer,
    tree_config: TreeConfig,
    policy_config: PolicyNetworkConfig,
    value_config: ValueNetworkConfig,
    reaction_rules_path: str,
    building_blocks_path: str,
) -> Tree:
    """Runs tree search for the given target molecule.

    Note: mutates the shared ``tree_config`` (forces ``evaluation_type="gcn"``
    and ``silent=True``) as a side effect.

    :param target: The target molecule.
    :param tree_config: The planning configuration of tree search.
    :param policy_config: The policy network configuration.
    :param value_config: The value network configuration.
    :param reaction_rules_path: The path to the file with reaction rules.
    :param building_blocks_path: The path to the file with building blocks.
    :return: The built search tree for the given molecule.
    """

    # policy and value function loading
    policy_function = PolicyNetworkFunction(policy_config=policy_config)
    value_function = ValueNetworkFunction(weights_path=value_config.weights_path)
    reaction_rules = load_reaction_rules(reaction_rules_path)
    building_blocks = load_building_blocks(building_blocks_path, standardize=True)

    # initialize tree; evaluation must go through the (tunable) value network
    tree_config.evaluation_type = "gcn"
    tree_config.silent = True
    tree = Tree(
        target=target,
        config=tree_config,
        reaction_rules=reaction_rules,
        building_blocks=building_blocks,
        expansion_function=policy_function,
        evaluation_function=value_function,
    )
    # suppress per-iteration progress bar (private attribute of Tree)
    tree._tqdm = False

    # remove the target itself from building blocks so the search cannot
    # trivially "solve" it in zero steps
    if str(target) in tree.building_blocks:
        tree.building_blocks.remove(str(target))

    # run tree search to exhaustion (Tree is an iterator of search steps)
    _ = list(tree)

    return tree
142
+
143
+
144
def extract_tree_precursor(tree_list: List[Tree]) -> Dict[str, float]:
    """Takes the built tree and extracts the precursor for value network tuning. The
    precursor from found retrosynthetic routes are labeled as a positive class and precursor
    from not solved routes are labeled as a negative class.

    :param tree_list: The list of built search trees.

    :return: The dictionary with the precursor SMILES and its class (positive - 1 or negative - 0).
    """
    extracted_precursor = defaultdict(float)
    for tree in tree_list:
        for idx, node in tree.nodes.items():
            # solved node: walk up to the root (node id 1) and mark every
            # composed precursor on the winning path as positive
            if node.is_solved():
                parent = idx
                while parent and parent != 1:
                    composed_smi = str(
                        compose_precursors(tree.nodes[parent].new_precursors)
                    )
                    extracted_precursor[composed_smi] = 1.0
                    parent = tree.parents[parent]
            else:
                # unsolved node: negative example
                # NOTE(review): the same SMILES can be written by both
                # branches; the last write wins — confirm this is intended
                composed_smi = str(compose_precursors(tree.nodes[idx].new_precursors))
                extracted_precursor[composed_smi] = 0.0

    # shuffle extracted precursor (dicts preserve insertion order)
    processed_keys = list(extracted_precursor.keys())
    shuffle(processed_keys)
    extracted_precursor = {i: extracted_precursor[i] for i in processed_keys}

    return extracted_precursor
175
+
176
+
177
def balance_extracted_precursor(extracted_precursor):
    """Balances the precursor classes by downsampling the negatives.

    Keeps every positive example (label 1) and a random subset of negative
    examples (label 0) of at most the same size, so the value network is tuned
    on a class-balanced set.

    Fixes a bug in the previous implementation, which popped items from the
    negatives list but never copied any negative into the returned dictionary,
    so the "balanced" set contained positives only.

    :param extracted_precursor: The dictionary mapping precursor SMILES to its
        label (1.0 - positive, 0.0 - negative).
    :return: The balanced dictionary of precursor SMILES and labels.
    """
    positives = {smi: lbl for smi, lbl in extracted_precursor.items() if lbl == 1}
    negatives = [smi for smi, lbl in extracted_precursor.items() if lbl == 0]

    balanced = dict(positives)
    for smi in random.sample(negatives, min(len(positives), len(negatives))):
        balanced[smi] = extracted_precursor[smi]
    return balanced
186
+
187
+
188
def create_updating_set(
    extracted_precursor: Dict[str, float], batch_size: int = 1
) -> LightningDataset:
    """Builds the value-network tuning dataset from the precursors extracted
    during the planning simulations.

    :param extracted_precursor: The dictionary with the extracted precursors
        and their labels.
    :param batch_size: The size of the batch in value network updating.
    :return: A LightningDataset object, which contains the tuning set for
        value network tuning.
    """

    # class-balance before splitting
    balanced = balance_extracted_precursor(extracted_precursor)
    dataset = ValueNetworkDataset(balanced)

    # fixed 60/40 train/validation split, seeded for reproducibility
    train_size = int(0.6 * len(dataset))
    split_sizes = [train_size, len(dataset) - train_size]
    train_set, val_set = random_split(
        dataset, split_sizes, torch.Generator().manual_seed(42)
    )

    print(f"Training set size: {len(train_set)}")
    print(f"Validation set size: {len(val_set)}")

    return LightningDataset(
        train_set, val_set, batch_size=batch_size, pin_memory=True, drop_last=True
    )
217
+
218
+
219
def tune_value_network(
    datamodule: LightningDataset, value_config: ValueNetworkConfig
) -> None:
    """Trains the value network using a given tuning data and saves the trained neural
    network, overwriting the checkpoint at ``value_config.weights_path``.

    :param datamodule: The tuning dataset (LightningDataset).
    :param value_config: The value network configuration.
    :return: None.
    """

    current_weights = value_config.weights_path
    value_network = load_value_net(ValueNetwork, current_weights)

    with DisableLogger(), HiddenPrints():
        # "auto" falls back to CPU when no GPU is present; the previous
        # hard-coded accelerator="gpu", devices=[0] crashed on CPU-only hosts
        trainer = Trainer(
            accelerator="auto",
            devices=1,
            max_epochs=value_config.num_epoch,
            enable_checkpointing=False,
            logger=False,
            gradient_clip_val=1.0,
            enable_progress_bar=False,
        )

        trainer.fit(value_network, datamodule)
        val_score = trainer.validate(value_network, datamodule.val_dataloader())[0]
        trainer.save_checkpoint(current_weights)

    print(f"Value network balanced accuracy: {val_score['val_balanced_accuracy']}")
249
+
250
+
251
def run_training(
    extracted_precursor: Dict[str, float] = None,
    value_config: ValueNetworkConfig = None,
) -> None:
    """Runs the training stage in value network tuning: builds the updating
    set from the extracted precursors and retrains the value network.

    :param extracted_precursor: The precursors extracted from the planning simulations.
    :param value_config: The value network configuration.
    :return: None.
    """

    datamodule = create_updating_set(
        extracted_precursor=extracted_precursor, batch_size=value_config.batch_size
    )
    tune_value_network(datamodule=datamodule, value_config=value_config)
269
+
270
+
271
def run_planning(
    targets_batch: List[MoleculeContainer],
    tree_config: TreeConfig,
    policy_config: PolicyNetworkConfig,
    value_config: ValueNetworkConfig,
    reaction_rules_path: str,
    building_blocks_path: str,
    targets_batch_id: int,
):
    """Performs the planning stage (tree search) for a batch of target
    molecules; the built trees are later mined for precursors to tune the
    value network in the training stage.

    :param targets_batch: The batch of target molecules.
    :param tree_config: The search tree configuration.
    :param policy_config: The policy network configuration.
    :param value_config: The value network configuration.
    :param reaction_rules_path: The path to the file with reaction rules.
    :param building_blocks_path: The path to the file with building blocks.
    :param targets_batch_id: The identifier of the batch being processed.
    :return: The list of built search trees.
    """
    from tqdm import tqdm

    print(f"\nProcess batch number {targets_batch_id}")
    tree_config.silent = False

    trees = []
    for target in tqdm(targets_batch):
        # best-effort over the batch: a failing target is reported and skipped
        try:
            trees.append(
                run_tree_search(
                    target=target,
                    tree_config=tree_config,
                    policy_config=policy_config,
                    value_config=value_config,
                    reaction_rules_path=reaction_rules_path,
                    building_blocks_path=building_blocks_path,
                )
            )
        except Exception as e:
            print(e)
            continue

    num_solved = sum(len(tree.winning_nodes) > 0 for tree in trees)
    print(f"Planning is finished with {num_solved} solved targets")

    return trees
317
+
318
+
319
def run_updating(
    targets_path: str,
    tree_config: TreeConfig,
    policy_config: PolicyNetworkConfig,
    value_config: ValueNetworkConfig,
    reinforce_config: TuningConfig,
    reaction_rules_path: str,
    building_blocks_path: str,
    results_root: str = None,
) -> None:
    """Performs updating of value network: alternates planning simulations on
    batches of targets with retraining of the value network on the precursors
    extracted from the built trees.

    :param targets_path: The path to the file with target molecules.
    :param tree_config: The search tree configuration.
    :param policy_config: The policy network configuration.
    :param value_config: The value network configuration.
    :param reinforce_config: The value network tuning configuration.
    :param reaction_rules_path: The path to the file with reaction rules.
    :param building_blocks_path: The path to the file with building blocks.
    :param results_root: The path to the directory where trained value network will be
        saved.
    :return: None.
    """

    # create results root folder
    # NOTE(review): Path(None) raises TypeError when results_root is omitted,
    # so the default of None is effectively unusable — confirm intent
    results_root = Path(results_root)
    if not results_root.exists():
        results_root.mkdir()

    # load targets list
    with MoleculeReader(targets_path) as targets:
        targets = list(targets)

    # create value neural network; its checkpoint is overwritten in place by
    # each training round (value_config is mutated here)
    value_config.weights_path = os.path.join(results_root, "value_network.ckpt")
    create_value_network(value_config)

    # create targets batch
    targets_batch_list = create_targets_batch(
        targets, batch_size=reinforce_config.batch_size
    )

    # run value network tuning
    for batch_id, targets_batch in enumerate(targets_batch_list, start=1):

        # start tree planning simulation for batch of targets
        tree_list = run_planning(
            targets_batch=targets_batch,
            tree_config=tree_config,
            policy_config=policy_config,
            value_config=value_config,
            reaction_rules_path=reaction_rules_path,
            building_blocks_path=building_blocks_path,
            targets_batch_id=batch_id,
        )

        # extract pos and neg precursor from the list of built trees
        extracted_precursor = extract_tree_precursor(tree_list)

        # train value network for extracted precursor
        run_training(extracted_precursor=extracted_precursor, value_config=value_config)
synplan/ml/training/supervised.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Module for the preparation and training of a policy network used in the expansion of
2
+ nodes in tree search.
3
+
4
+ This module includes functions for creating training datasets and running the training
5
+ process for the policy network.
6
+ """
7
+
8
+ import warnings
9
+ from pathlib import Path
10
+ from typing import Union, List
11
+
12
+ import os
13
+ import torch
14
+ from pytorch_lightning import Trainer
15
+ from pytorch_lightning.callbacks import ModelCheckpoint
16
+ from torch.utils.data import random_split
17
+ from torch_geometric.data.lightning import LightningDataset
18
+
19
+ from synplan.ml.networks.policy import PolicyNetwork
20
+ from synplan.ml.training.preprocessing import (
21
+ FilteringPolicyDataset,
22
+ RankingPolicyDataset,
23
+ )
24
+ from synplan.utils.config import PolicyNetworkConfig
25
+ from synplan.utils.logging import DisableLogger, HiddenPrints
26
+
27
+ warnings.filterwarnings("ignore")
28
+
29
+
30
def create_policy_dataset(
    reaction_rules_path: str,
    molecules_or_reactions_path: str,
    output_path: str,
    dataset_type: str = "filtering",
    batch_size: int = 100,
    num_cpus: int = 1,
    training_data_ratio: float = 0.8,
):
    """
    Create a training dataset for a policy network.

    :param reaction_rules_path: Path to the reaction rules file.
    :param molecules_or_reactions_path: Path to the molecules or reactions file used to create the training set.
    :param output_path: Path to store the processed dataset.
    :param dataset_type: Type of the dataset to be created ('ranking' or 'filtering').
    :param batch_size: The size of batch of molecules/reactions.
    :param training_data_ratio: Ratio of training data to total data.
    :param num_cpus: Number of CPUs to use for data processing.
    :raises ValueError: If ``dataset_type`` is neither 'filtering' nor 'ranking'.

    :return: A `LightningDataset` object containing training and validation datasets.

    """

    # previously an unknown dataset_type fell through and raised an unrelated
    # NameError on full_dataset; fail fast with a clear message instead
    if dataset_type not in ("filtering", "ranking"):
        raise ValueError(
            f"Unknown dataset_type {dataset_type!r}; expected 'filtering' or 'ranking'"
        )

    with DisableLogger(), HiddenPrints():
        if dataset_type == "filtering":
            full_dataset = FilteringPolicyDataset(
                reaction_rules_path=reaction_rules_path,
                molecules_path=molecules_or_reactions_path,
                output_path=output_path,
                num_cpus=num_cpus,
            )
        else:  # "ranking"
            full_dataset = RankingPolicyDataset(
                reaction_rules_path=reaction_rules_path,
                reactions_path=molecules_or_reactions_path,
                output_path=output_path,
            )

    train_size = int(training_data_ratio * len(full_dataset))
    val_size = len(full_dataset) - train_size

    # seeded split for reproducibility
    train_dataset, val_dataset = random_split(
        full_dataset, [train_size, val_size], torch.Generator().manual_seed(42)
    )
    print(
        f"Training set size: {len(train_dataset)}, validation set size: {len(val_dataset)}"
    )

    datamodule = LightningDataset(
        train_dataset,
        val_dataset,
        batch_size=batch_size,
        pin_memory=True,
        drop_last=True,
    )

    return datamodule
89
+
90
+
91
def run_policy_training(
    datamodule: LightningDataset,
    config: PolicyNetworkConfig,
    results_path: str,
    weights_file_name: str = "policy_network",
    accelerator: str = "gpu",
    devices: Union[List[int], str, int] = "auto",
    silent: bool = False,
) -> None:
    """
    Trains a policy network using a given datamodule and training configuration.

    :param datamodule: A PyTorch Lightning `DataModule` class instance. It is responsible for loading, processing, and preparing the training data for the model.
    :param config: The configuration object with the settings for the policy training process.
    :param results_path: Path to store the training results and logs.
    :param accelerator: Supports passing different accelerator types ("cpu", "gpu", "tpu", "hpu", "mps", "auto") as well as custom accelerator instances. Default: "gpu".
    :param devices: The devices to use. Can be set to a positive number (int or str), a sequence of device indices (list or str), the value -1 to indicate all available devices should be used, or "auto" for automatic selection based on the chosen accelerator. Default: "auto".
    :param silent: Run in the silent mode with no progress bars. Default: False.
    :param weights_file_name: The name of weights file to be saved. Default: "policy_network".

    :return: None.

    """
    out_dir = Path(results_path)
    out_dir.mkdir(exist_ok=True)

    network = PolicyNetwork(
        vector_dim=config.vector_dim,
        n_rules=datamodule.train_dataset.dataset.num_classes,
        batch_size=config.batch_size,
        dropout=config.dropout,
        num_conv_layers=config.num_conv_layers,
        learning_rate=config.learning_rate,
        policy_type=config.policy_type,
    )

    # keep only the checkpoint with the best validation loss
    checkpoint = ModelCheckpoint(
        dirpath=out_dir, filename=weights_file_name, monitor="val_loss", mode="min"
    )

    trainer = Trainer(
        accelerator=accelerator,
        devices=devices,
        max_epochs=config.num_epoch,
        callbacks=[checkpoint],
        logger=False,
        gradient_clip_val=1.0,
        enable_progress_bar=not silent,
    )

    if silent:
        with DisableLogger(), HiddenPrints():
            trainer.fit(network, datamodule)
    else:
        trainer.fit(network, datamodule)

    ba = round(trainer.logged_metrics["train_balanced_accuracy_y_step"].item(), 3)
    print(f"Policy network balanced accuracy: {ba}")
synplan/utils/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
from typing import Union
from os import PathLike

# Type alias for arguments accepting either a string path or a PathLike
# object (e.g. pathlib.Path).
path_type = Union[str, PathLike]
synplan/utils/config.py ADDED
@@ -0,0 +1,543 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Module containing configuration classes."""
2
+
3
+ from abc import ABC, abstractmethod
4
+ from dataclasses import dataclass, field
5
+ from pathlib import Path
6
+ from typing import Any, Dict, List, Union
7
+ from chython import smarts
8
+
9
+ import yaml
10
+ from CGRtools.containers import MoleculeContainer, QueryContainer
11
+
12
+
13
+ @dataclass
14
+ class ConfigABC(ABC):
15
+ """Abstract base class for configuration classes."""
16
+
17
+ @staticmethod
18
+ @abstractmethod
19
+ def from_dict(config_dict: Dict[str, Any]):
20
+ """Create an instance of the configuration from a dictionary."""
21
+
22
+ def to_dict(self) -> Dict[str, Any]:
23
+ """Convert the configuration into a dictionary."""
24
+ return {
25
+ k: str(v) if isinstance(v, Path) else v for k, v in self.__dict__.items()
26
+ }
27
+
28
+ @staticmethod
29
+ @abstractmethod
30
+ def from_yaml(file_path: str):
31
+ """Deserialize a YAML file into a configuration object."""
32
+
33
+ def to_yaml(self, file_path: str):
34
+ """Serializes the configuration to a YAML file.
35
+
36
+ :param file_path: The path to the output YAML file.
37
+ """
38
+ with open(file_path, "w", encoding="utf-8") as file:
39
+ yaml.dump(self.to_dict(), file)
40
+
41
+ @abstractmethod
42
+ def _validate_params(self, params: Dict[str, Any]):
43
+ """Validate configuration parameters."""
44
+
45
+ def __post_init__(self):
46
+ """Validates the configuration parameters."""
47
+ # call _validate_params method after initialization
48
+ params = self.to_dict()
49
+ self._validate_params(params)
50
+
51
+
52
@dataclass
class RuleExtractionConfig(ConfigABC):
    """Configuration class for extracting reaction rules.

    :param multicenter_rules: If True, extracts a single rule
        encompassing all centers. If False, extracts separate reaction
        rules for each reaction center in a multicenter reaction.
    :param as_query_container: If True, the extracted rules are
        generated as QueryContainer objects, analogous to SMARTS objects
        for pattern matching in chemical structures.
    :param reverse_rule: If True, reverses the direction of the reaction
        for rule extraction.
    :param reactor_validation: If True, validates each generated rule in
        a chemical reactor to ensure correct generation of products from
        reactants.
    :param include_func_groups: If True, includes specific functional
        groups in the reaction rule in addition to the reaction center
        and its environment.
    :param func_groups_list: A list of functional group SMARTS to be
        considered when include_func_groups is True (parsed into query
        objects in __post_init__).
    :param include_rings: If True, includes ring structures in the
        reaction rules.
    :param keep_leaving_groups: If True, retains leaving groups in the
        extracted reaction rule.
    :param keep_incoming_groups: If True, retains incoming groups in the
        extracted reaction rule.
    :param keep_reagents: If True, includes reagents in the extracted
        reaction rule.
    :param environment_atom_count: Defines the size of the environment
        around the reaction center to be included in the rule (0 for
        only the reaction center, 1 for the first environment, etc.).
    :param min_popularity: Minimum number of times a rule must be
        applied to be considered for further analysis.
    :param keep_metadata: If True, retains metadata associated with the
        reaction in the extracted rule.
    :param single_reactant_only: If True, includes only reaction rules
        with a single reactant molecule.
    :param atom_info_retention: Controls the amount of information about
        each atom to retain, per section ('reaction_center' and
        'environment'); missing entries are filled with defaults.
    """

    # default low-level parameters
    single_reactant_only: bool = True
    keep_metadata: bool = False
    reactor_validation: bool = True
    reverse_rule: bool = True
    as_query_container: bool = True
    include_func_groups: bool = False
    func_groups_list: List[str] = field(default_factory=list)

    # adjustable parameters
    environment_atom_count: int = 1
    min_popularity: int = 3
    include_rings: bool = True
    multicenter_rules: bool = True
    keep_leaving_groups: bool = True
    keep_incoming_groups: bool = True
    keep_reagents: bool = False
    atom_info_retention: Dict[str, Dict[str, bool]] = field(default_factory=dict)

    def __post_init__(self):
        # ConfigABC.__post_init__ already runs _validate_params; the previous
        # second, duplicate _validate_params call was removed.
        super().__post_init__()
        self._initialize_default_atom_info_retention()
        self._parse_functional_groups()

    def _initialize_default_atom_info_retention(self):
        """Fill atom_info_retention with defaults, preserving user overrides.

        Bug fix: the previous implementation updated each user-supplied
        section with itself (a no-op) and raised KeyError when a section
        was missing; defaults and user values are now merged properly.
        """
        default_atom_info = {
            "reaction_center": {
                "neighbors": True,
                "hybridization": True,
                "implicit_hydrogens": False,
                "ring_sizes": False,
            },
            "environment": {
                "neighbors": False,
                "hybridization": False,
                "implicit_hydrogens": False,
                "ring_sizes": False,
            },
        }

        user_info = self.atom_info_retention or {}
        merged = {}
        for section, defaults in default_atom_info.items():
            section_info = dict(defaults)
            section_info.update(user_info.get(section, {}))
            merged[section] = section_info
        self.atom_info_retention = merged

    def _parse_functional_groups(self):
        """Parse SMARTS strings into query objects; unparsable entries are skipped."""
        func_groups_list = []
        for group_smarts in self.func_groups_list:
            try:
                query = smarts(group_smarts)
                func_groups_list.append(query)
            except Exception as e:
                print(f"Functional group {group_smarts} was not parsed because of {e}")
        self.func_groups_list = func_groups_list

    @staticmethod
    def from_dict(config_dict: Dict[str, Any]) -> "RuleExtractionConfig":
        """Create a configuration from a plain dictionary."""
        return RuleExtractionConfig(**config_dict)

    @staticmethod
    def from_yaml(file_path: str) -> "RuleExtractionConfig":
        """Deserialize a YAML file into a configuration object."""
        with open(file_path, "r", encoding="utf-8") as file:
            config_dict = yaml.safe_load(file)
        return RuleExtractionConfig.from_dict(config_dict)

    def _validate_params(self, params: Dict[str, Any]) -> None:
        """Validate parameter types and structure; raise ValueError on failure."""

        if not isinstance(params["multicenter_rules"], bool):
            raise ValueError("multicenter_rules must be a boolean.")

        if not isinstance(params["as_query_container"], bool):
            raise ValueError("as_query_container must be a boolean.")

        if not isinstance(params["reverse_rule"], bool):
            raise ValueError("reverse_rule must be a boolean.")

        if not isinstance(params["reactor_validation"], bool):
            raise ValueError("reactor_validation must be a boolean.")

        if not isinstance(params["include_func_groups"], bool):
            raise ValueError("include_func_groups must be a boolean.")

        if params["func_groups_list"] is not None and not all(
            isinstance(group, str) for group in params["func_groups_list"]
        ):
            raise ValueError("func_groups_list must be a list of SMARTS.")

        if not isinstance(params["include_rings"], bool):
            raise ValueError("include_rings must be a boolean.")

        if not isinstance(params["keep_leaving_groups"], bool):
            raise ValueError("keep_leaving_groups must be a boolean.")

        if not isinstance(params["keep_incoming_groups"], bool):
            raise ValueError("keep_incoming_groups must be a boolean.")

        if not isinstance(params["keep_reagents"], bool):
            raise ValueError("keep_reagents must be a boolean.")

        if not isinstance(params["environment_atom_count"], int):
            raise ValueError("environment_atom_count must be an integer.")

        if not isinstance(params["min_popularity"], int):
            raise ValueError("min_popularity must be an integer.")

        if not isinstance(params["keep_metadata"], bool):
            raise ValueError("keep_metadata must be a boolean.")

        if not isinstance(params["single_reactant_only"], bool):
            raise ValueError("single_reactant_only must be a boolean.")

        # Bug fix: the default value is an empty dict, which previously
        # failed the required-keys check (the test was `is not None`).
        # Only a non-empty user-provided mapping is validated here.
        if params["atom_info_retention"]:
            if not isinstance(params["atom_info_retention"], dict):
                raise ValueError("atom_info_retention must be a dictionary.")

            required_keys = {"reaction_center", "environment"}
            if not required_keys.issubset(params["atom_info_retention"]):
                missing_keys = required_keys - set(params["atom_info_retention"].keys())
                raise ValueError(
                    f"atom_info_retention missing required keys: {missing_keys}"
                )

            for key, value in params["atom_info_retention"].items():
                if key not in required_keys:
                    raise ValueError(f"Unexpected key in atom_info_retention: {key}")

                expected_subkeys = {
                    "neighbors",
                    "hybridization",
                    "implicit_hydrogens",
                    "ring_sizes",
                }
                if not isinstance(value, dict) or not expected_subkeys.issubset(value):
                    missing_subkeys = expected_subkeys - set(value.keys())
                    raise ValueError(
                        f"Invalid structure for {key} in atom_info_retention. Missing subkeys: {missing_subkeys}"
                    )

                for subkey, subvalue in value.items():
                    if not isinstance(subvalue, bool):
                        raise ValueError(
                            f"Value for {subkey} in {key} of atom_info_retention must be boolean."
                        )
241
+
242
+
243
@dataclass
class PolicyNetworkConfig(ConfigABC):
    """Configuration class for the policy network.

    :param policy_type: Mode of operation, either 'filtering' or 'ranking'.
    :param vector_dim: Dimension of the input vectors.
    :param batch_size: Number of samples per batch.
    :param dropout: Dropout rate for regularization.
    :param learning_rate: Learning rate for the optimizer.
    :param num_conv_layers: Number of convolutional layers in the network.
    :param num_epoch: Number of training epochs.
    :param weights_path: Optional path to pretrained network weights.
    :param priority_rules_fraction: Fraction of priority rules
        (used by the filtering policy only).
    :param rule_prob_threshold: Minimum rule probability to keep
        (used by the filtering policy only).
    :param top_rules: Number of top-ranked rules to keep
        (used by the filtering policy only).
    """

    policy_type: str = "ranking"
    vector_dim: int = 256
    batch_size: int = 500
    dropout: float = 0.4
    learning_rate: float = 0.008
    num_conv_layers: int = 5
    num_epoch: int = 100
    weights_path: Union[str, None] = None  # annotation fixed: default is None

    # for filtering policy
    priority_rules_fraction: float = 0.5
    rule_prob_threshold: float = 0.0
    top_rules: int = 50

    @staticmethod
    def from_dict(config_dict: Dict[str, Any]) -> "PolicyNetworkConfig":
        """Create a configuration from a plain dictionary."""
        return PolicyNetworkConfig(**config_dict)

    @staticmethod
    def from_yaml(file_path: str) -> "PolicyNetworkConfig":
        """Deserialize a YAML file into a configuration object."""
        with open(file_path, "r", encoding="utf-8") as file:
            config_dict = yaml.safe_load(file)
        return PolicyNetworkConfig.from_dict(config_dict)

    def _validate_params(self, params: Dict[str, Any]):
        """Validate parameter types and ranges; raise ValueError on failure."""

        if params["policy_type"] not in ["filtering", "ranking"]:
            raise ValueError("policy_type must be either 'filtering' or 'ranking'.")

        if not isinstance(params["vector_dim"], int) or params["vector_dim"] <= 0:
            raise ValueError("vector_dim must be a positive integer.")

        if not isinstance(params["batch_size"], int) or params["batch_size"] <= 0:
            raise ValueError("batch_size must be a positive integer.")

        if (
            not isinstance(params["num_conv_layers"], int)
            or params["num_conv_layers"] <= 0
        ):
            raise ValueError("num_conv_layers must be a positive integer.")

        if not isinstance(params["num_epoch"], int) or params["num_epoch"] <= 0:
            raise ValueError("num_epoch must be a positive integer.")

        if not isinstance(params["dropout"], float) or not (
            0.0 <= params["dropout"] <= 1.0
        ):
            raise ValueError("dropout must be a float between 0.0 and 1.0.")

        if (
            not isinstance(params["learning_rate"], float)
            or params["learning_rate"] <= 0.0
        ):
            raise ValueError("learning_rate must be a positive float.")

        if (
            not isinstance(params["priority_rules_fraction"], float)
            or params["priority_rules_fraction"] < 0.0
        ):
            # message fixed: was the contradictory "non-negative positive float"
            raise ValueError("priority_rules_fraction must be a non-negative float.")

        if (
            not isinstance(params["rule_prob_threshold"], float)
            or params["rule_prob_threshold"] < 0.0
        ):
            raise ValueError("rule_prob_threshold must be a non-negative float.")

        if not isinstance(params["top_rules"], int) or params["top_rules"] <= 0:
            raise ValueError("top_rules must be a positive integer.")
327
+
328
+
329
@dataclass
class ValueNetworkConfig(ConfigABC):
    """Configuration class for the value network.

    :param weights_path: Optional path to pretrained network weights.
    :param vector_dim: Dimension of the input vectors.
    :param batch_size: Number of samples per batch.
    :param dropout: Dropout rate for regularization.
    :param learning_rate: Learning rate for the optimizer.
    :param num_conv_layers: Number of convolutional layers in the network.
    :param num_epoch: Number of training epochs.
    """

    weights_path: Union[str, None] = None  # annotation fixed: default is None
    vector_dim: int = 256
    batch_size: int = 500
    dropout: float = 0.4
    learning_rate: float = 0.008
    num_conv_layers: int = 5
    num_epoch: int = 100

    @staticmethod
    def from_dict(config_dict: Dict[str, Any]) -> "ValueNetworkConfig":
        """Create a configuration from a plain dictionary."""
        return ValueNetworkConfig(**config_dict)

    @staticmethod
    def from_yaml(file_path: str) -> "ValueNetworkConfig":
        """Deserialize a YAML file into a configuration object."""
        with open(file_path, "r", encoding="utf-8") as file:
            config_dict = yaml.safe_load(file)
        return ValueNetworkConfig.from_dict(config_dict)

    # NOTE: the redundant to_yaml override (byte-identical to
    # ConfigABC.to_yaml) was removed; serialization is inherited.

    def _validate_params(self, params: Dict[str, Any]):
        """Validate parameter types and ranges; raise ValueError on failure."""

        if not isinstance(params["vector_dim"], int) or params["vector_dim"] <= 0:
            raise ValueError("vector_dim must be a positive integer.")

        if not isinstance(params["batch_size"], int) or params["batch_size"] <= 0:
            raise ValueError("batch_size must be a positive integer.")

        if (
            not isinstance(params["num_conv_layers"], int)
            or params["num_conv_layers"] <= 0
        ):
            raise ValueError("num_conv_layers must be a positive integer.")

        if not isinstance(params["num_epoch"], int) or params["num_epoch"] <= 0:
            raise ValueError("num_epoch must be a positive integer.")

        if not isinstance(params["dropout"], float) or not (
            0.0 <= params["dropout"] <= 1.0
        ):
            raise ValueError("dropout must be a float between 0.0 and 1.0.")

        if (
            not isinstance(params["learning_rate"], float)
            or params["learning_rate"] <= 0.0
        ):
            raise ValueError("learning_rate must be a positive float.")
390
+
391
+
392
@dataclass
class TuningConfig(ConfigABC):
    """Configuration class for the network training.

    :param batch_size: The number of targets per batch in the planning simulation step.
    :param num_simulations: The number of planning simulations.
    """

    batch_size: int = 100
    num_simulations: int = 1

    @staticmethod
    def from_dict(config_dict: Dict[str, Any]) -> "TuningConfig":
        """Create a configuration from a plain dictionary."""
        return TuningConfig(**config_dict)

    @staticmethod
    def from_yaml(file_path: str) -> "TuningConfig":
        """Deserialize a YAML file into a configuration object."""
        with open(file_path, "r", encoding="utf-8") as file:
            config_dict = yaml.safe_load(file)
        return TuningConfig.from_dict(config_dict)

    def _validate_params(self, params: Dict[str, Any]):
        """Validate parameter types and ranges; raise ValueError on failure."""

        if not isinstance(params["batch_size"], int) or params["batch_size"] <= 0:
            raise ValueError("batch_size must be a positive integer.")

        # previously unvalidated
        if (
            not isinstance(params["num_simulations"], int)
            or params["num_simulations"] <= 0
        ):
            raise ValueError("num_simulations must be a positive integer.")
417
+
418
+
419
@dataclass
class TreeConfig(ConfigABC):
    """Configuration class for the tree search algorithm.

    :param max_iterations: The number of iterations to run the algorithm
        for.
    :param max_tree_size: The maximum number of nodes in the tree.
    :param max_time: The time limit (in seconds) for the algorithm to
        run.
    :param max_depth: The maximum depth of the tree.
    :param ucb_type: Type of UCB used in the search algorithm. Options
        are "puct", "uct", "value", defaults to "uct".
    :param c_ucb: The exploration-exploitation balance coefficient used
        in Upper Confidence Bound (UCB).
    :param backprop_type: Type of backpropagation algorithm. Options are
        "muzero", "cumulative", defaults to "muzero".
    :param search_strategy: The strategy used for tree search. Options
        are "expansion_first", "evaluation_first".
    :param exclude_small: Whether to exclude small molecules during the
        search.
    :param evaluation_agg: Method for aggregating evaluation scores.
        Options are "max", "average", defaults to "max".
    :param evaluation_type: The method used for evaluating nodes.
        Options are "random", "rollout", "gcn".
    :param init_node_value: Initial value for a new node.
    :param epsilon: Chance of random selection of reaction rules during
        the selection stage (epsilon-greedy UCB estimation); balances
        exploration and exploitation. Must be in [0, 1].
    :param min_mol_size: Minimum size of a molecule that has to be
        synthesized. Molecules with this many or fewer heavy atoms are
        assumed to be building blocks by definition, defaults to 6.
    :param silent: Whether to suppress progress output.
    """

    max_iterations: int = 100
    max_tree_size: int = 1000000
    max_time: float = 600
    max_depth: int = 6
    ucb_type: str = "uct"
    c_ucb: float = 0.1
    backprop_type: str = "muzero"
    search_strategy: str = "expansion_first"
    exclude_small: bool = True
    evaluation_agg: str = "max"
    evaluation_type: str = "gcn"
    init_node_value: float = 0.0
    epsilon: float = 0.0
    min_mol_size: int = 6
    silent: bool = False

    @staticmethod
    def from_dict(config_dict: Dict[str, Any]) -> "TreeConfig":
        """Create a configuration from a plain dictionary."""
        return TreeConfig(**config_dict)

    @staticmethod
    def from_yaml(file_path: str) -> "TreeConfig":
        """Deserialize a YAML file into a configuration object."""
        with open(file_path, "r", encoding="utf-8") as file:
            config_dict = yaml.safe_load(file)
        return TreeConfig.from_dict(config_dict)

    def _validate_params(self, params):
        """Validate parameter types and ranges; raise on failure."""
        if params["ucb_type"] not in ["puct", "uct", "value"]:
            raise ValueError(
                "Invalid ucb_type. Allowed values are 'puct', 'uct', 'value'."
            )
        if params["backprop_type"] not in ["muzero", "cumulative"]:
            raise ValueError(
                "Invalid backprop_type. Allowed values are 'muzero', 'cumulative'."
            )
        if params["evaluation_type"] not in ["random", "rollout", "gcn"]:
            raise ValueError(
                "Invalid evaluation_type. Allowed values are 'random', 'rollout', 'gcn'."
            )
        if params["evaluation_agg"] not in ["max", "average"]:
            raise ValueError(
                "Invalid evaluation_agg. Allowed values are 'max', 'average'."
            )
        if not isinstance(params["c_ucb"], float):
            raise TypeError("c_ucb must be a float.")
        if not isinstance(params["max_depth"], int) or params["max_depth"] < 1:
            raise ValueError("max_depth must be a positive integer.")
        if not isinstance(params["max_tree_size"], int) or params["max_tree_size"] < 1:
            raise ValueError("max_tree_size must be a positive integer.")
        if (
            not isinstance(params["max_iterations"], int)
            or params["max_iterations"] < 1
        ):
            raise ValueError("max_iterations must be a positive integer.")
        # Fixed: the annotation is float, but the check rejected float values;
        # accept any positive number.
        if (
            not isinstance(params["max_time"], (int, float))
            or params["max_time"] <= 0
        ):
            raise ValueError("max_time must be a positive number.")
        if not isinstance(params["exclude_small"], bool):
            raise TypeError("exclude_small must be a boolean.")
        if not isinstance(params["silent"], bool):
            raise TypeError("silent must be a boolean.")
        if not isinstance(params["init_node_value"], float):
            raise TypeError("init_node_value must be a float if provided.")
        if params["search_strategy"] not in ["expansion_first", "evaluation_first"]:
            raise ValueError(
                f"Invalid search_strategy: {params['search_strategy']}: "
                f"Allowed values are 'expansion_first', 'evaluation_first'"
            )
        # Fixed: the original condition `0 >= epsilon >= 1` can never be true,
        # so the [0, 1] range was never enforced (and the message was garbled).
        if not isinstance(params["epsilon"], float) or not (
            0.0 <= params["epsilon"] <= 1.0
        ):
            raise ValueError("epsilon must be a float between 0 and 1.")
        if not isinstance(params["min_mol_size"], int) or params["min_mol_size"] < 0:
            raise ValueError("min_mol_size must be a non-negative integer.")
528
+
529
+
530
def convert_config_to_dict(config_attr: ConfigABC, config_type) -> Dict | None:
    """Normalize a configuration attribute to a plain dictionary.

    :param config_attr: The configuration attribute to be converted.
    :param config_type: The configuration class to check against.
    :return: The attribute as a dictionary; None when it is neither a dict
        nor an instance of ``config_type``.
    """
    if isinstance(config_attr, dict):
        return config_attr
    return config_attr.to_dict() if isinstance(config_attr, config_type) else None
synplan/utils/files.py ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Module containing classes and functions needed for reactions/molecules data
2
+ reading/writing."""
3
+
4
+ from os.path import splitext
5
+ from pathlib import Path
6
+ from typing import Iterable, Union
7
+
8
+ from CGRtools import smiles
9
+ from CGRtools.containers import CGRContainer, MoleculeContainer, ReactionContainer
10
+ from CGRtools.files.RDFrw import RDFRead, RDFWrite
11
+ from CGRtools.files.SDFrw import SDFRead, SDFWrite
12
+
13
+
14
class FileHandler:
    """General class to handle chemical files."""

    def __init__(self, filename: Union[str, Path], **kwargs):
        """Determine the file type from the filename extension.

        :param filename: The path and name of the file.
        :return: None.
        """
        # Subclasses are responsible for opening the underlying file object.
        self._file = None
        extension = splitext(filename)[1]
        known_types = {".smi": "SMI", ".smiles": "SMI", ".rdf": "RDF", ".sdf": "SDF"}
        if extension not in known_types:
            raise ValueError("I don't know the file extension,", extension)
        self._file_type = known_types[extension]

    def close(self):
        """Close the underlying file object."""
        self._file.close()

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()
36
+
37
+
38
class Reader(FileHandler):
    """General class to read reactions/molecules data files.

    Delegates iteration and length to the concrete reader object stored in
    ``self._file`` by subclasses (SMILES/RDF/SDF readers).
    """

    def __init__(self, filename: Union[str, Path], **kwargs):
        """General class to read reactions/molecules data files.

        :param filename: The path and name of the file.
        :return: None.
        """
        super().__init__(filename, **kwargs)

    def __enter__(self):
        # NOTE: unlike Writer, entering a Reader yields the wrapped reader
        # object itself, not this wrapper.
        return self._file

    def __iter__(self):
        return iter(self._file)

    def __next__(self):
        return next(self._file)

    def __len__(self):
        return len(self._file)
58
+
59
+
60
class SMILESRead:
    """Line-oriented reader for SMILES files (one record per line)."""

    def __init__(self, filename: Union[str, Path], **kwargs):
        """Simplified class to read files containing a SMILES (Molecules or Reaction)
        string per line.

        :param filename: The path and name of the SMILES file to parse.
        :return: None.
        """
        # strict=True: fail fast if the file does not exist.
        filename = str(Path(filename).resolve(strict=True))
        self._file = open(filename, "r", encoding="utf-8")
        self._data = self.__data()

    def __data(
        self,
    ) -> Iterable[Union[ReactionContainer, CGRContainer, MoleculeContainer]]:
        # Lazily parse the file line by line; the raw input SMILES is kept
        # in each parsed object's meta under "init_smiles".
        # NOTE(review): lines parsing to other types appear to be silently
        # skipped — confirm this is intended.
        for line in iter(self._file.readline, ""):
            line = line.strip()
            x = smiles(line)
            if isinstance(x, (ReactionContainer, CGRContainer, MoleculeContainer)):
                x.meta["init_smiles"] = line
                yield x

    def __enter__(self):
        return self

    def read(self):
        """Parse the whole SMILES file.

        :return: List of parsed molecules or reactions.
        """
        return list(iter(self))

    def __iter__(self):
        return (x for x in self._data)

    def __next__(self):
        return next(iter(self))

    def close(self):
        """Close the underlying file object."""
        self._file.close()

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()
103
+
104
+
105
class Writer(FileHandler):
    """General class to write chemical files."""

    def __init__(self, filename: Union[str, Path], mapping: bool = True, **kwargs):
        """General class to write chemical files.

        :param filename: The path and name of the file.
        :param mapping: Whenever to save mapping or not.
        :return: None.
        """
        super().__init__(filename, **kwargs)
        # Whether atom-atom mapping should be preserved on write.
        self._mapping = mapping

    def __enter__(self):
        return self
118
+
119
+
120
class ReactionReader(Reader):
    """Reader for reaction files (SMILES or RDF)."""

    def __init__(self, filename: Union[str, Path], **kwargs):
        """Class to read reaction files.

        :param filename: The path and name of the file.
        :return: None.
        """
        super().__init__(filename, **kwargs)
        if self._file_type == "RDF":
            self._file = RDFRead(filename, indexable=True, **kwargs)
            return
        if self._file_type == "SMI":
            self._file = SMILESRead(filename, **kwargs)
            return
        raise ValueError("File type incompatible -", filename)
134
+
135
+
136
class ReactionWriter(Writer):
    """Writer for reaction files (SMILES or RDF)."""

    def __init__(self, filename: Union[str, Path], mapping: bool = True, **kwargs):
        """Class to write reaction files.

        :param filename: The path and name of the file.
        :param mapping: Whenever to save mapping or not.
        :return: None.
        """
        super().__init__(filename, mapping, **kwargs)
        if self._file_type == "RDF":
            self._file = RDFWrite(filename, append=False, **kwargs)
            return
        if self._file_type == "SMI":
            self._file = open(filename, "w", encoding="utf-8", **kwargs)
            return
        raise ValueError("File type incompatible -", filename)

    def write(self, reaction: ReactionContainer):
        """Write one reaction record to the file.

        :param reaction: The reaction to be written.
        :return: None.
        """
        if self._file_type == "RDF":
            self._file.write(reaction)
        elif self._file_type == "SMI":
            record = to_reaction_smiles_record(reaction)
            self._file.write(record + "\n")
163
+
164
+
165
class MoleculeReader(Reader):
    """Reader for molecule files (SMILES or SDF)."""

    def __init__(self, filename: Union[str, Path], **kwargs):
        """Class to read molecule files.

        :param filename: The path and name of the file.
        :return: None.
        """
        super().__init__(filename, **kwargs)
        if self._file_type == "SDF":
            self._file = SDFRead(filename, indexable=True, **kwargs)
            return
        if self._file_type == "SMI":
            self._file = SMILESRead(filename, ignore=True, **kwargs)
            return
        raise ValueError("File type incompatible -", filename)
179
+
180
+
181
class MoleculeWriter(Writer):
    """Writer for molecule files (SMILES or SDF)."""

    def __init__(self, filename: Union[str, Path], mapping: bool = True, **kwargs):
        """Class to write molecule files.

        :param filename: The path and name of the file.
        :param mapping: Whenever to save mapping or not.
        :return: None.
        """
        super().__init__(filename, mapping, **kwargs)
        if self._file_type == "SDF":
            self._file = SDFWrite(filename, append=False, **kwargs)
            return
        if self._file_type == "SMI":
            self._file = open(filename, "w", encoding="utf-8", **kwargs)
            return
        raise ValueError("File type incompatible -", filename)

    def write(self, molecule: MoleculeContainer):
        """Write one molecule record to the file.

        :param molecule: The molecule to be written.
        :return: None.
        """
        if self._file_type == "SDF":
            self._file.write(molecule)
        elif self._file_type == "SMI":
            self._file.write(str(molecule) + "\n")
208
+
209
+
210
def to_reaction_smiles_record(reaction: ReactionContainer) -> str:
    """Converts the reaction to the SMILES record. Needed for reaction/molecule writers.

    The record is a tab-separated line: the mapped reaction SMILES followed by
    the metadata values sorted by key, each flattened to a single line
    (newlines replaced by ';').

    :param reaction: The reaction to be written (strings pass through unchanged).
    :return: The SMILES record to be written.
    """

    if isinstance(reaction, str):
        return reaction

    reaction_record = [format(reaction, "m")]
    # Bug fix: meta_info was previously reset to "" inside the loop, so every
    # metadata column was written empty; the actual value is now written.
    for _, meta_info in sorted(reaction.meta.items(), key=lambda x: x[0]):
        reaction_record.append(";".join(str(meta_info).split("\n")))
    return "\t".join(reaction_record)
synplan/utils/loading.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Module containing functions for loading reaction rules, building blocks and
2
+ retrosynthetic models."""
3
+
4
+ import functools
5
+ import pickle
6
+ import zipfile
7
+ from pathlib import Path
8
+ from typing import List, Set, Union
9
+
10
+ from CGRtools.reactor.reactor import Reactor
11
+ from torch import device
12
+ from huggingface_hub import hf_hub_download, snapshot_download
13
+ from tqdm import tqdm
14
+
15
+ from synplan.ml.networks.policy import PolicyNetwork
16
+ from synplan.ml.networks.value import ValueNetwork
17
+ from synplan.utils.files import MoleculeReader
18
+
19
+
20
def download_unpack_data(filename, subfolder, save_to="."):
    """Download a file from the SynPlanner Hugging Face repository, extracting it
    when it is a zip archive.

    :param filename: Name of the file in the repository.
    :param subfolder: Repository subfolder containing the file.
    :param save_to: Local directory to download into (created if missing).
    :return: Path to the downloaded (and, for archives, extracted) file.
    """
    # Bug fix: the directory was only created when save_to was a str; a Path
    # argument skipped mkdir entirely. Normalize any path-like input.
    save_to = Path(save_to).resolve()
    save_to.mkdir(parents=True, exist_ok=True)

    # Download the (possibly zipped) file from the repository.
    file_path = Path(
        hf_hub_download(
            repo_id="Laboratoire-De-Chemoinformatique/SynPlanner",
            filename=filename,
            subfolder=subfolder,
            local_dir=save_to,
        )
    )

    if file_path.suffix != ".zip":
        return file_path

    with zipfile.ZipFile(file_path, "r") as zip_ref:
        # Extract the single file in the zip.
        zip_ref.extractall(save_to)
        extracted_file = save_to / zip_ref.namelist()[0]

    # Remove the archive once its contents have been extracted.
    file_path.unlink()
    return extracted_file
45
+
46
+
47
def download_all_data(save_to="."):
    """Download the full SynPlanner data snapshot and extract every zip archive,
    skipping files that are already present on disk.

    :param save_to: Local directory to download the snapshot into.
    """
    snapshot_dir = Path(
        snapshot_download(
            repo_id="Laboratoire-De-Chemoinformatique/SynPlanner", local_dir=save_to
        )
    ).resolve()

    for archive in snapshot_dir.rglob("*.zip"):
        with zipfile.ZipFile(archive, "r") as zip_ref:
            for member in zip_ref.namelist():
                target = archive.parent / member
                # Only extract members that do not already exist locally.
                if target.exists():
                    continue
                zip_ref.extract(member, archive.parent)
                print(f"Extracted {member} to {archive.parent}")
63
+
64
+
65
@functools.lru_cache(maxsize=None)
def load_reaction_rules(file: str) -> List[Reactor]:
    """Loads the reaction rules from a pickle file and converts them into a list of
    Reactor objects if necessary.

    :param file: The path to the pickle file that stores the reaction rules.
    :return: A list of reaction rules as Reactor objects.
    """

    with open(file, "rb") as f:
        rules = pickle.load(f)

    # Legacy rule files store (template, metadata) tuples; wrap the templates
    # into Reactor objects. Files already holding Reactors are returned as-is.
    if isinstance(rules[0][0], Reactor):
        return rules
    return [Reactor(template) for template, _ in rules]
81
+
82
+
83
@functools.lru_cache(maxsize=None)
def load_building_blocks(
    building_blocks_path: Union[str, Path], standardize: bool = True
) -> Set[str]:
    """Loads building blocks data from a file and returns a set of building
    blocks SMILES.

    :param building_blocks_path: The path to the file containing the building blocks.
    :param standardize: Flag if building blocks have to be standardized before loading. Default=True.
    :return: The set of building blocks smiles.
    :raises ValueError: If the file is not a .smi/.smiles file.
    """

    building_blocks_path = Path(building_blocks_path).resolve()
    # Explicit check instead of `assert`, which is silently stripped when
    # Python runs with optimizations (-O).
    if building_blocks_path.suffix not in (".smi", ".smiles"):
        raise ValueError(
            f"Building blocks file must be .smi or .smiles, got: {building_blocks_path}"
        )

    building_blocks_smiles = set()
    if standardize:
        with MoleculeReader(building_blocks_path) as molecules:
            for mol in tqdm(
                molecules,
                desc="Number of building blocks processed: ",
                bar_format="{desc}{n} [{elapsed}]",
            ):
                try:
                    mol.canonicalize()
                    mol.clean_stereo()
                    building_blocks_smiles.add(str(mol))
                except Exception:  # mol.canonicalize() / InvalidAromaticRing
                    continue
    else:
        with open(building_blocks_path, "r", encoding="utf-8") as inp:
            for line in inp:
                fields = line.split()
                # Skip blank lines (previously raised IndexError); the first
                # whitespace-separated field is the SMILES string.
                if fields:
                    building_blocks_smiles.add(fields[0])

    return building_blocks_smiles
122
+
123
+
124
def load_value_net(
    model_class: ValueNetwork, value_network_path: Union[str, Path]
) -> ValueNetwork:
    """Loads the value network.

    :param value_network_path: The path to the file storing value network weights.
    :param model_class: The model class to be loaded.
    :return: The loaded value network.
    """

    # Force loading onto CPU so no GPU is required at inference time.
    map_location = device("cpu")
    # NOTE(review): map_location is passed positionally; Lightning's
    # load_from_checkpoint expects it as a keyword argument — confirm this
    # binds to the intended parameter.
    return model_class.load_from_checkpoint(value_network_path, map_location)
136
+
137
+
138
def load_policy_net(
    model_class: PolicyNetwork, policy_network_path: Union[str, Path]
) -> PolicyNetwork:
    """Loads the policy network.

    :param policy_network_path: The path to the file storing policy network weights.
    :param model_class: The model class to be loaded.
    :return: The loaded policy network.
    """

    # Force loading onto CPU so no GPU is required at inference time.
    map_location = device("cpu")
    # NOTE(review): map_location is passed positionally; Lightning's
    # load_from_checkpoint expects it as a keyword argument — confirm this
    # binds to the intended parameter. batch_size=1 overrides the stored
    # hyperparameter for single-sample inference.
    return model_class.load_from_checkpoint(
        policy_network_path, map_location, batch_size=1
    )
synplan/utils/logging.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Generic logging helpers for scripts, notebooks and Ray clusters.
3
+ """
4
+
5
+ from __future__ import annotations
6
+ import logging, sys, os, warnings
7
+ from pathlib import Path
8
+ from datetime import datetime
9
+ from typing import Iterable, Optional
10
+ from IPython import get_ipython
11
+
12
+
13
+ # --------------------------------------------------------------------------- #
14
+ # Helper classes #
15
+ # --------------------------------------------------------------------------- #
16
+
17
+
18
class DisableLogger:
    """Context manager that mutes *all* logging while its block executes."""

    def __enter__(self):
        # Disabling at CRITICAL suppresses every standard logging level.
        logging.disable(logging.CRITICAL)

    def __exit__(self, exc_type, exc_value, traceback):
        # Lift the global threshold again on exit.
        logging.disable(logging.NOTSET)
26
+
27
+
28
class HiddenPrints:
    """Context manager that discards anything printed to stdout in its scope."""

    def __enter__(self):
        # Keep a handle on the real stdout, then point prints at the null device.
        self._orig = sys.stdout
        sys.stdout = open(os.devnull, "w")

    def __exit__(self, exc_type, exc_value, traceback):
        # Close the devnull stream and restore the original stdout.
        sys.stdout.close()
        sys.stdout = self._orig
38
+
39
+
40
+ # --------------------------------------------------------------------------- #
41
+ # Notebook‑aware console handler #
42
+ # --------------------------------------------------------------------------- #
43
+
44
+
45
def _in_notebook() -> bool:
    """Return True when running inside a Jupyter (ZMQ) notebook kernel."""
    shell = get_ipython()
    if not shell:
        return False
    return shell.__class__.__name__ == "ZMQInteractiveShell"
48
+
49
+
50
class TqdmHandler(logging.StreamHandler):
    """StreamHandler variant that routes records through ``tqdm.write``.

    This keeps active tqdm progress bars intact; when tqdm is not installed
    it degrades to the plain StreamHandler behaviour.
    """

    def emit(self, record):
        try:
            from tqdm import tqdm
        except ModuleNotFoundError:
            # No tqdm available: fall back to the ordinary stream emit.
            super().emit(record)
        else:
            tqdm.write(self.format(record), end=self.terminator)
60
+
61
+
62
+ # --------------------------------------------------------------------------- #
63
+ # Public initialisation API #
64
+ # --------------------------------------------------------------------------- #
65
+
66
+
67
def init_logger(
    *,
    name: str = "app",
    console_level: str | int = "ERROR",
    file_level: str | int = "INFO",
    log_dir: str | os.PathLike = ".",
    redirect_tqdm: bool = True,
) -> tuple[logging.Logger, str | None]:
    """
    Initialise (or fetch) a namespaced logger that works in scripts &
    notebooks. Idempotent - safe to call multiple times.

    :param name: Logger namespace passed to ``logging.getLogger``.
    :param console_level: Threshold for the console/notebook handler.
    :param file_level: Threshold for the per-session file handler.
    :param log_dir: Directory in which the timestamped log file is created.
    :param redirect_tqdm: Route console output through ``tqdm.write`` when
        tqdm is active, so log lines do not break progress bars.

    Returns
    -------
    tuple[logging.Logger, str | None]
        The configured logger and the path of its log file. On repeated
        calls the existing file handler's path is returned (``None`` if the
        logger was configured elsewhere without a file handler).
    """
    logger = logging.getLogger(name)
    if logger.handlers:
        # Already configured: return the same (logger, log_file_path) shape
        # as the first call. Previously this branch returned the bare logger
        # while the main path returned a tuple, breaking tuple unpacking on
        # the second call.
        for handler in logger.handlers:
            if isinstance(handler, logging.FileHandler):
                return logger, handler.baseFilename
        return logger, None

    logger.setLevel("DEBUG")  # capture everything; handlers filter

    # console / notebook handler
    if _in_notebook() or (redirect_tqdm and "tqdm" in sys.modules):
        ch: logging.Handler = TqdmHandler()
    else:
        ch = logging.StreamHandler(sys.stderr)
    ch.setLevel(console_level)
    ch.setFormatter(
        logging.Formatter(
            "%(asctime)s | %(levelname)-8s | %(message)s",
            datefmt="%H:%M:%S",
        )
    )
    logger.addHandler(ch)

    # file handler (one timestamped file per session)
    Path(log_dir).mkdir(parents=True, exist_ok=True)
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    fh = logging.FileHandler(Path(log_dir) / f"{name}_{stamp}.log", encoding="utf-8")
    fh.setLevel(file_level)
    fh.setFormatter(
        logging.Formatter(
            "%(asctime)s | %(name)s | %(levelname)-8s | %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )
    )
    logger.addHandler(fh)

    log_file_path = fh.baseFilename
    logger.info("Logging initialised → %s", log_file_path)
    return logger, log_file_path
121
+
122
+
123
+ # --------------------------------------------------------------------------- #
124
+ # Optional Ray‑specific configuration helpers #
125
+ # --------------------------------------------------------------------------- #
126
+
127
+
128
def init_ray_logging(
    *,
    python_level: str | int = "ERROR",
    backend_level: str = "error",
    log_to_driver: bool = False,
    filter_userwarnings: bool = True,
) -> "ray.LoggingConfig":
    """
    Prepare environment + Ray LoggingConfig **before** `ray.init()`.

    Returns
    -------
    ray.LoggingConfig
        Pass as `logging_config=` argument to `ray.init()`.
    """
    # Quiet the C++ backend (raylet / plasma); the env var must be in place
    # before ray itself is imported.
    os.environ.setdefault("RAY_BACKEND_LOG_LEVEL", backend_level)

    # Optionally hide noisy UserWarnings raised inside workers.
    if filter_userwarnings:
        warnings.filterwarnings("ignore", category=UserWarning)

    import ray  # deferred so ray stays an optional dependency

    # Raise the threshold of every Ray-related Python logger.
    noisy_loggers: Iterable[str] = (
        "ray",
        "ray.worker",
        "ray.runtime",
        "ray.dashboard",
        "ray.tune",
        "ray.serve",
    )
    for logger_name in noisy_loggers:
        logging.getLogger(logger_name).setLevel(python_level)

    # Build a LoggingConfig that propagates these settings to workers.
    return ray.LoggingConfig(
        log_to_driver=log_to_driver,
        log_level=python_level,
    )
169
+
170
+
171
def silence_logger(
    logger_name: str,
    level: int | str = logging.ERROR,
):
    """Raise the logging threshold of a chatty third-party logger.

    Call at the *top* of every ``@ray.remote`` function or actor
    ``__init__`` so the silencing takes effect inside the worker process.
    """
    target = logging.getLogger(logger_name)
    target.setLevel(level)
synplan/utils/visualisation.py ADDED
@@ -0,0 +1,1365 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Module containing functions for analysis and visualization of the built tree."""
2
+
3
+ import base64
4
+ from itertools import count, islice
5
+ from collections import deque
6
+ from typing import Any, Dict, List, Union
7
+
8
+ from CGRtools.containers.molecule import MoleculeContainer
9
+ from CGRtools import smiles as read_smiles
10
+
11
+ from synplan.chem.reaction_routes.visualisation import (
12
+ cgr_display,
13
+ depict_custom_reaction,
14
+ )
15
+ from synplan.chem.reaction_routes.io import make_dict
16
+ from synplan.mcts.tree import Tree
17
+
18
+ from IPython.display import display, HTML
19
+
20
+
21
def get_child_nodes(
    tree: Tree,
    molecule: MoleculeContainer,
    graph: Dict[MoleculeContainer, List[MoleculeContainer]],
) -> Dict[str, Any]:
    """Recursively extracts the child nodes of the given molecule.

    :param tree: The built tree.
    :param molecule: The molecule in the tree from which to extract child nodes.
    :param graph: The relationship between the given molecule and child nodes.
    :return: A reaction dict with the extracted child nodes, or an empty list
        when the molecule has no successors in ``graph`` (leaf).
    """
    if molecule not in graph:
        return []

    children = []
    for precursor in graph[molecule]:
        entry = {
            "smiles": str(precursor),
            "type": "mol",
            "in_stock": str(precursor) in tree.building_blocks,
        }
        # Recurse into this precursor; leaves yield an empty (falsy) result.
        subtree = get_child_nodes(tree, precursor, graph)
        if subtree:
            entry["children"] = [subtree]
        children.append(entry)
    return {"type": "reaction", "children": children}
50
+
51
+
52
def extract_routes(
    tree: Tree, extended: bool = False, min_mol_size: int = 0
) -> List[Dict[str, Any]]:
    """Takes the target and the tree of successors and returns a list of
    nested route dictionaries rooted at the target molecule.

    :param tree: The built tree.
    :param extended: If True, generates the extended route representation
        (every solved node, not only the registered winning nodes).
    :param min_mol_size: If the size of the Precursor is equal or smaller than
        min_mol_size it is automatically classified as building block.
    :return: A list of dictionaries. Each dictionary contains a target, a list
        of children, and a boolean indicating whether the target is in
        building_blocks.
    """
    target = tree.nodes[1].precursors_to_expand[0].molecule
    target_in_stock = tree.nodes[1].curr_precursor.is_building_block(
        tree.building_blocks, min_mol_size
    )

    # Collect node ids of the routes to export.
    if extended:
        winning_nodes = [i for i, node in tree.nodes.items() if node.is_solved()]
    else:
        winning_nodes = tree.winning_nodes

    if not winning_nodes:
        # No solved routes: emit the bare target with no children.
        return [
            {
                "type": "mol",
                "smiles": str(target),
                "in_stock": target_in_stock,
                "children": [],
            }
        ]

    routes_block = []
    for winning_node in winning_nodes:
        # Reconstruct the molecule-level successor graph along the route.
        # (The former unused `pred` predecessor map was dead code and was
        # removed.)
        nodes = tree.route_to_node(winning_node)
        graph: Dict[Any, List[Any]] = {}
        for parent, child in zip(nodes, nodes[1:]):
            graph[parent.curr_precursor.molecule] = [
                x.molecule for x in child.new_precursors
            ]

        routes_block.append(
            {
                "type": "mol",
                "smiles": str(target),
                "in_stock": target_in_stock,
                "children": [get_child_nodes(tree, target, graph)],
            }
        )
    return routes_block
109
+
110
+
111
def render_svg(pred, columns, box_colors):
    """
    Renders an SVG representation of a retrosynthetic route.

    This function takes the predicted reaction steps, the molecules organized
    into columns representing reaction stages, and a mapping of molecule status
    to box colors, and generates an SVG string visualizing the route. It
    calculates positions for molecules and arrows, and constructs the SVG
    elements.

    Args:
        pred (tuple): A tuple of tuples representing the predicted reaction
                      steps. Each inner tuple is (source_molecule_index,
                      target_molecule_index). The indices correspond to the
                      flattened list of molecules across all columns.
        columns (list): A list of lists, where each inner list contains
                        Molecule objects for a specific stage (column) in the
                        retrosynthetic route.
        box_colors (dict): A dictionary mapping molecule status strings (e.g.,
                           'target', 'mulecule', 'instock') to SVG color strings
                           for the boxes around the molecules.

    Returns:
        str: A string containing the complete SVG code for the retrosynthetic
             route visualization.
    """
    # NOTE(review): this function mutates each molecule's private `_plane`
    # coordinate dict in place while laying out columns.
    x_shift = 0.0
    c_max_x = 0.0  # running right-most x over all columns
    c_max_y = 0.0  # tallest column height seen so far
    render = []
    # cx/cy both enumerate molecules in flattened column order; cx is consumed
    # in the first per-molecule pass, cy in the second (box/arrow) pass.
    cx = count()
    cy = count()
    arrow_points = {}  # flat molecule index -> [min_x, max_x, y_center, ...]
    for ms in columns:
        heights = []
        for m in ms:
            m.clean2d()
            # X-shift for target: translate the molecule so its left edge
            # starts at the current column offset.
            min_x = min(x for x, y in m._plane.values()) - x_shift
            min_y = min(y for x, y in m._plane.values())
            m._plane = {n: (x - min_x, y - min_y) for n, (x, y) in m._plane.items()}
            max_x = max(x for x, y in m._plane.values())

            c_max_x = max(c_max_x, max_x)

            arrow_points[next(cx)] = [x_shift, max_x]
            heights.append(max(y for x, y in m._plane.values()))

        x_shift = c_max_x + 5.0  # between columns gap
        # calculate Y-shift: total column height including 3.0 gaps
        y_shift = sum(heights) + 3.0 * (len(heights) - 1)

        c_max_y = max(c_max_y, y_shift)

        # Center the column vertically around zero, then walk downwards.
        y_shift /= 2.0
        for m, h in zip(ms, heights):
            m._plane = {n: (x, y - y_shift) for n, (x, y) in m._plane.items()}

            # calculate coordinates for boxes (with padding around the atoms)
            max_x = max(x for x, y in m._plane.values()) + 0.9  # max x
            min_x = min(x for x, y in m._plane.values()) - 0.6  # min x
            max_y = -(max(y for x, y in m._plane.values()) + 0.45)  # max y
            min_y = -(min(y for x, y in m._plane.values()) - 0.45)  # min y
            x_delta = abs(max_x - min_x)
            y_delta = abs(max_y - min_y)
            # Rounded status-colored rectangle behind the molecule.
            box = (
                f'<rect x="{min_x}" y="{max_y}" rx="{y_delta * 0.1}" ry="{y_delta * 0.1}" width="{x_delta}" height="{y_delta}"'
                f' stroke="black" stroke-width=".0025" fill="{box_colors[m.meta["status"]]}" fill-opacity="0.30"/>'
            )
            # Third slot of arrow_points: vertical center of this molecule.
            arrow_points[next(cy)].append(y_shift - h / 2.0)
            y_shift -= h + 3.0
            depicted_molecule = list(m.depict(embedding=True))[:3]
            depicted_molecule.append(box)
            render.append(depicted_molecule)

    # calculate mid-X coordinate to draw square arrows
    graph = {}
    for s, p in pred:
        try:
            graph[s].append(p)
        except KeyError:
            graph[s] = [p]
    for s, ps in graph.items():
        # Use one shared elbow x-coordinate for all arrows into node s.
        mid_x = float("-inf")
        for p in ps:
            s_min_x, s_max, s_y = arrow_points[s][:3]  # s
            p_min_x, p_max, p_y = arrow_points[p][:3]  # p
            p_max += 1
            mid = p_max + (s_min_x - p_max) / 3
            mid_x = max(mid_x, mid)
        for p in ps:
            arrow_points[p].append(mid_x)

    config = MoleculeContainer._render_config
    font_size = config["font_size"]
    font125 = 1.25 * font_size
    width = c_max_x + 4.0 * font_size  # 3.0 by default
    height = c_max_y + 3.5 * font_size  # 2.5 by default
    box_y = height / 2.0
    svg = [
        f'<svg width="{0.6 * width:.2f}cm" height="{0.6 * height:.2f}cm" '
        f'viewBox="{-font125:.2f} {-box_y:.2f} {width:.2f} '
        f'{height:.2f}" xmlns="http://www.w3.org/2000/svg" version="1.1">',
        ' <defs>\n <marker id="arrow" markerWidth="10" markerHeight="10" '
        'refX="0" refY="3" orient="auto">\n <path d="M0,0 L0,6 L9,3"/>\n </marker>\n </defs>',
    ]

    # Square (elbow) arrows from each precursor box to its product box.
    for s, p in pred:
        s_min_x, s_max, s_y = arrow_points[s][:3]
        p_min_x, p_max, p_y = arrow_points[p][:3]
        p_max += 1
        mid_x = arrow_points[p][-1]  # p_max + (s_min_x - p_max) / 3
        arrow = f""" <polyline points="{p_max:.2f} {p_y:.2f}, {mid_x:.2f} {p_y:.2f}, {mid_x:.2f} {s_y:.2f}, {s_min_x - 1.:.2f} {s_y:.2f}"
 fill="none" stroke="black" stroke-width=".04" marker-end="url(#arrow)"/>"""
        if p_y != s_y:
            # Mark the junction with a dot when the arrow bends vertically.
            arrow += f' <circle cx="{mid_x}" cy="{p_y}" r="0.1"/>'
        svg.append(arrow)
    for atoms, bonds, masks, box in render:
        molecule_svg = MoleculeContainer._graph_svg(
            atoms, bonds, masks, -font125, -box_y, width, height
        )
        # Insert the background box right after the opening group element.
        molecule_svg.insert(1, box)
        svg.extend(molecule_svg)
    svg.append("</svg>")
    return "\n".join(svg)
236
+
237
def get_route_svg_mod(tree: Tree, node_id: int) -> str:
    """
    Visualizes the full retrosynthetic route from the target to a given node.

    This function generates an SVG image for the synthetic path from the target
    molecule to the specified node_id. It correctly handles paths that have not
    been fully resolved to building blocks. The layout follows standard
    retrosynthetic analysis, with the target on the right and precursors
    arranged in columns to the left.

    :param tree: The built MCTS tree.
    :param node_id: The ID of the node to which the route should be visualized.
    :return: A string containing the SVG visualization of the route.
    """
    # Box colors for molecule status
    box_colors = {
        "target": "#98EEFF",  # Light Blue for the main target
        "mulecule": "#F0AB90",  # Peach for intermediates not in stock
        "instock": "#9BFAB3",  # Light Green for building blocks
    }

    # Obtain the sequence of reaction steps in retrosynthetic order
    retro_reactions = list(reversed(tree.synthesis_route(node_id)))

    # Handle the case of the root node with no preceding reactions
    if not retro_reactions:
        target_node = tree.nodes.get(node_id)
        if not target_node:
            return ""
        molecule = target_node.curr_precursor.molecule
        molecule.meta["status"] = "target"
        return render_svg(tuple(), [[molecule]], box_colors)

    # Map all unique molecule SMILES to their MoleculeContainer objects.
    # NOTE(review): molecules are deduplicated by SMILES string here; two
    # occurrences of the same structure share one container and one status.
    mol_map = {str(m): m for r in retro_reactions for m in r.reactants + r.products}

    # Set the status for each unique molecule
    for smiles, molecule in mol_map.items():
        molecule.meta["status"] = "instock" if smiles in tree.building_blocks else "mulecule"

    # The final target is the product of the first retrosynthetic reaction
    target_molecule = retro_reactions[0].products[0]
    target_molecule.meta["status"] = "target"
    mol_map[str(target_molecule)] = target_molecule

    # --- Build columns from left to right based on reaction dependencies ---
    columns = []
    # Identify molecules that are products in any reaction step
    products_smiles = {str(p) for r in retro_reactions for p in r.products}

    # The leftmost column consists of reactants that are not products of any other step in the path
    leftmost_smiles = {str(m) for r in retro_reactions for m in r.reactants} - products_smiles

    if not leftmost_smiles:  # Fallback for simple A->B routes
        leftmost_smiles = {str(m) for m in retro_reactions[-1].reactants}

    columns.append([mol_map[s] for s in leftmost_smiles])
    placed_smiles = set(leftmost_smiles)

    # Iteratively build the next columns
    while len(placed_smiles) < len(mol_map):
        next_products = set()
        for r in retro_reactions:
            # If all reactants for a reaction have been placed in previous columns...
            if all(str(reactant) in placed_smiles for reactant in r.reactants):
                # ...then its products belong in the next column.
                for product in r.products:
                    if str(product) not in placed_smiles:
                        next_products.add(str(product))

        if not next_products:
            break  # Safety break if no new column can be formed

        columns.append([mol_map[s] for s in next_products])
        placed_smiles.update(next_products)

    # --- Prepare data for rendering ---
    # Flatten the columns to get a single list of molecules for indexing
    flat_mols = [mol for col in columns for mol in col]
    mol_to_idx = {str(mol): i for i, mol in enumerate(flat_mols)}

    # Define the connections (precursor -> product) for the SVG rendering
    # The arrow in render_svg points from 'p' to 's'
    pred = []
    for reaction in retro_reactions:
        for product in reaction.products:
            if str(product) in mol_to_idx:
                s_idx = mol_to_idx[str(product)]  # 's' is the product (on the right)
                for reactant in reaction.reactants:
                    if str(reactant) in mol_to_idx:
                        p_idx = mol_to_idx[str(reactant)]  # 'p' is the reactant (on the left)
                        pred.append((s_idx, p_idx))

    return render_svg(tuple(pred), columns, box_colors)
331
+
332
+
333
def get_route_svg(tree: Tree, node_id: int) -> str:
    """Visualizes the retrosynthetic route.

    :param tree: The built tree.
    :param node_id: The id of the node from which to visualize the route.
    :return: The SVG string.
    """
    nodes = tree.route_to_node(node_id)
    # Set up node_id types for different box colors
    for n in nodes:
        for precursor in n.new_precursors:
            precursor.molecule.meta["status"] = (
                "instock"
                if precursor.is_building_block(tree.building_blocks)
                else "mulecule"
            )
    nodes[0].curr_precursor.molecule.meta["status"] = "target"
    # Box colors
    box_colors = {
        "target": "#98EEFF",  # 152, 238, 255
        "mulecule": "#F0AB90",  # 240, 171, 144
        "instock": "#9BFAB3",  # 155, 250, 179
    }

    # first column is target
    # second column are first new precursor_to_expand
    columns = [
        [nodes[0].curr_precursor.molecule],
        [x.molecule for x in nodes[1].new_precursors],
    ]
    # pred maps a molecule's flat index to its parent's flat index.
    pred = {x: 0 for x in range(1, len(columns[1]) + 1)}
    # cx holds indices of molecules still awaiting expansion (not in stock).
    cx = [
        n
        for n, x in enumerate(nodes[1].new_precursors, 1)
        if not x.is_building_block(tree.building_blocks)
    ]
    size = len(cx)
    nodes = iter(nodes[2:])
    # cy continues the flat numbering started by the first two columns.
    cy = count(len(columns[1]) + 1)
    while size:
        # Each iteration consumes one tree node per unexpanded molecule and
        # produces the next column layer.
        layer = []
        for s in islice(nodes, size):
            n = cx.pop(0)
            for x in s.new_precursors:
                layer.append(x)
                m = next(cy)
                if not x.is_building_block(tree.building_blocks):
                    cx.append(m)
                pred[m] = n
        size = len(cx)
        columns.append([x.molecule for x in layer])

    columns = [
        columns[::-1] for columns in columns[::-1]
    ]  # Reverse array to make retrosynthetic graph
    # Mirror the flat indices to match the reversed column order.
    pred = tuple(  # Change dict to tuple to make multiple precursor_to_expand available
        (abs(source - len(pred)), abs(target - len(pred)))
        for target, source in pred.items()
    )
    svg = render_svg(pred, columns, box_colors)
    return svg
394
+
395
+
396
def get_route_svg_from_json(routes_json: dict, route_id: int) -> str:
    """
    Visualizes the retrosynthetic route described in routes_json[route_id].

    :param routes_json: A dict mapping route IDs to nested JSON trees of molecules/reactions.
    :param route_id: The id of the route from which to visualize the route.
    :return: The SVG string.
    """
    # Resolve the route root, accepting either the given key or its str form.
    if route_id in routes_json:
        root = routes_json[route_id]
    else:
        try:
            root = routes_json[str(route_id)]
        except KeyError:
            raise ValueError(f"Route ID {route_id} not found in routes_json.")

    # Breadth-first traversal: group molecule dicts by depth and remember
    # each molecule's parent molecule dict.
    levels = []       # levels[d] = list of mol-dicts at depth d
    parent_of = {}    # id(mol_dict) -> parent mol_dict (None for the root)
    queue = deque([(root, 0, None)])
    while queue:
        node, depth, parent = queue.popleft()
        if node.get("type") != "mol":
            continue
        while len(levels) <= depth:
            levels.append([])
        levels[depth].append(node)
        parent_of[id(node)] = parent
        for reaction in node.get("children", []):
            if reaction.get("type") == "reaction":
                for mol_child in reaction.get("children", []):
                    if mol_child.get("type") == "mol":
                        queue.append((mol_child, depth + 1, node))

    # Build MoleculeContainer objects and assign status metadata:
    # target at depth 0, otherwise instock vs mulecule.
    mol_container = {}
    for depth, mols in enumerate(levels):
        for mol in mols:
            container = read_smiles(mol["smiles"])
            if depth == 0:
                container.meta["status"] = "target"
            elif mol.get("in_stock", False):
                container.meta["status"] = "instock"
            else:
                container.meta["status"] = "mulecule"
            mol_container[id(mol)] = container

    # Mirror the depth levels so the target ends up on the right-hand side.
    json_columns = levels[::-1]

    # Flatten the mirrored levels and map each JSON node id to its flat index.
    flat_index = {
        node_id: idx
        for idx, node_id in enumerate(
            id(mol) for lvl in json_columns for mol in lvl
        )
    }

    # Arrow connections for render_svg: (parent_index, child_index) pairs.
    pred = tuple(
        (flat_index[id(parent)], flat_index[child_id])
        for child_id, parent in parent_of.items()
        if parent is not None
    )

    # Layout columns of MoleculeContainer objects matching the mirrored levels.
    columns = [[mol_container[id(mol)] for mol in lvl] for lvl in json_columns]

    box_colors = {
        "target": "#98EEFF",
        "mulecule": "#F0AB90",
        "instock": "#9BFAB3",
    }

    return render_svg(pred, columns, box_colors)
469
+
470
+
471
def generate_results_html(
    tree: Tree, html_path: str, aam: bool = False, extended: bool = False
) -> None:
    """Writes an HTML page with the synthesis routes in SVG format and corresponding
    reactions in SMILES format.

    :param tree: The built tree.
    :param html_path: The path to the file where to store resulting HTML.
        If None, the assembled HTML table (without the surrounding
        <html>/<body> template) is returned instead of being written.
    :param aam: If True, depict atom-to-atom mapping.
    :param extended: If True, generates the extended route representation
        (every solved node, not only the registered winning nodes).
    :return: None when html_path is given; otherwise the HTML table string.
    """
    if aam:
        MoleculeContainer.depict_settings(aam=True)
    else:
        MoleculeContainer.depict_settings(aam=False)

    routes = []
    if extended:
        # Gather paths
        for idx, node in tree.nodes.items():
            if node.is_solved():
                routes.append(idx)
    else:
        routes = tree.winning_nodes
    # HTML Tags
    th = '<th style="text-align: left; background-color:#978785; border: 1px solid black; border-spacing: 0">'
    td = '<td style="text-align: left; border: 1px solid black; border-spacing: 0">'
    font_red = "<font color='red' style='font-weight: bold'>"
    font_green = "<font color='light-green' style='font-weight: bold'>"
    font_head = "<font style='font-weight: bold; font-size: 18px'>"
    font_normal = "<font style='font-weight: normal; font-size: 18px'>"
    font_close = "</font>"

    template_begin = """
    <!doctype html>
    <html lang="en">
    <head>
    <link href="https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/css/bootstrap.min.css"
    rel="stylesheet"
    integrity="sha384-1BmE4kWBq78iYhFldvKuhfTAU6auU8tT94WrHftjDbrCEXSU1oBoqyl2QvZ6jIW3"
    crossorigin="anonymous">
    <script
    src="https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/js/bootstrap.bundle.min.js"
    integrity="sha384-ka7Sk0Gln4gmtz2MlQnikT1wXgYsOg+OMhuP+IlRH9sENBO0LRn5q+8nbTov4+1p"
    crossorigin="anonymous">
    </script>
    <meta charset="utf-8">
    <meta name="viewport" content="width=device-width, initial-scale=1">
    <title>Predicted Paths Report</title>
    <meta name="description" content="A simple HTML5 Template for new projects.">
    <meta name="author" content="SitePoint">
    </head>
    <body>
    """
    template_end = """
    </body>
    </html>
    """
    # SVG Template
    box_mark = """
    <svg width="30" height="30" viewBox="0 0 1 1" xmlns="http://www.w3.org/2000/svg">
    <circle cx="0.5" cy="0.5" r="0.5" fill="rgb()" fill-opacity="0.35" />
    </svg>
    """
    # table = f"<table><thead><{th}>Retrosynthetic Routes</th></thead><tbody>"
    table = """
    <table class="table table-striped table-hover caption-top">
    <caption><h3>Retrosynthetic Routes Report</h3></caption>
    <tbody>"""

    # Gather path data
    table += f"<tr>{td}{font_normal}Target Molecule: {str(tree.nodes[1].curr_precursor)}{font_close}</td></tr>"
    table += f"<tr>{td}{font_normal}Tree Size: {len(tree)}{font_close} nodes</td></tr>"
    table += f"<tr>{td}{font_normal}Number of visited nodes: {len(tree.visited_nodes)}{font_close}</td></tr>"
    table += f"<tr>{td}{font_normal}Found paths: {len(routes)}{font_close}</td></tr>"
    table += f"<tr>{td}{font_normal}Time: {round(tree.curr_time, 4)}{font_close} seconds</td></tr>"
    # Legend row: colored circles explaining the molecule box colors.
    table += f"""
    <tr>{td}
    <div>
    {box_mark.replace("rgb()", "rgb(152, 238, 255)")}
    Target Molecule
    {box_mark.replace("rgb()", "rgb(240, 171, 144)")}
    Molecule Not In Stock
    {box_mark.replace("rgb()", "rgb(155, 250, 179)")}
    Molecule In Stock
    </div>
    </td></tr>
    """

    for route in routes:
        svg = get_route_svg(tree, route)  # get SVG
        full_route = tree.synthesis_route(route)  # get route
        # write SMILES of all reactions in synthesis path
        step = 1
        reactions = ""
        for synth_step in full_route:
            reactions += f"<b>Step {step}:</b> {str(synth_step)}<br>"
            step += 1
        # Concatenate all content of path
        route_score = round(tree.route_score(route), 3)
        table += (
            f'<tr style="line-height: 250%">{td}{font_head}Route {route}; '
            f"Steps: {len(full_route)}; "
            f"Cumulated nodes' value: {route_score}{font_close}</td></tr>"
        )
        # f"Cumulated nodes' value: {node._probabilities[path]}{font_close}</td></tr>"
        table += f"<tr>{td}{svg}</td></tr>"
        table += f"<tr>{td}{reactions}</td></tr>"
    table += "</tbody>"
    if html_path is None:
        return table
    with open(html_path, "w", encoding="utf-8") as html_file:
        html_file.write(template_begin)
        html_file.write(table)
        html_file.write(template_end)
587
+
588
+
589
def html_top_routes_cluster(clusters: dict, tree: Tree, target_smiles: str) -> str:
    """Build a styled, self-contained HTML report showing the best route of each cluster.

    The "best" route of a cluster is taken to be the first entry of its
    ``node_ids`` list. Route and ReducedRouteCGR depictions are embedded as
    base64-encoded SVG ``<img>`` sources so the report is a single file.

    Args:
        clusters: Mapping of cluster index to cluster data. Each value is
            expected to provide a ``"node_ids"`` list and, optionally, an
            ``"sb_cgr"`` object depictable via ``cgr_display``.
        tree: Search tree used to render each best route with ``get_route_svg``.
        target_smiles: SMILES of the target molecule, shown in the header.

    Returns:
        str: A complete standalone HTML document (Bootstrap-styled).
    """
    # Summary counts shown in the report header.
    total_routes = sum(len(data.get("node_ids", [])) for data in clusters.values())
    total_clusters = len(clusters)

    # Build styled HTML report using Bootstrap.
    html = []

    html.append("<!doctype html><html lang='en'><head>")
    html.append(
        "<meta charset='utf-8'><meta name='viewport' content='width=device-width, initial-scale=1'>"
    )
    html.append(
        "<link href='https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/css/bootstrap.min.css' rel='stylesheet'>"
    )
    html.append("<title>Clustering Results Report</title>")
    html.append(
        "<style> svg{max-width:100%;height:auto;} .report-table th,.report-table td{vertical-align:top;border:1px solid #dee2e6;} </style>"
    )
    html.append("</head><body><div class='container my-4'>")
    # Report header
    html.append("<h1 class='mb-3'>Best route from each cluster</h1>")
    html.append(f"<p><strong>Target molecule (SMILES):</strong> {target_smiles}</p>")
    html.append(f"<p><strong>Total number of routes:</strong> {total_routes}</p>")
    html.append(f"<p><strong>Total number of clusters:</strong> {total_clusters}</p>")
    # Table header. FIX: the previous markup opened <colgroup> twice and
    # declared five <col> entries for a four-column table; now a single
    # well-formed <colgroup> with one <col> per column.
    html.append(
        "<table class='table report-table'><colgroup><col style='width:5%'><col style='width:5%'><col style='width:15%'><col style='width:75%'></colgroup><thead><tr>"
    )
    html.append(
        "<th>Cluster index</th><th>Size</th><th>ReducedRouteCGR</th><th>Best Route</th>"
    )
    html.append("</tr></thead><tbody>")

    # One row per cluster; clusters without routes are skipped.
    for cluster_num, group_data in clusters.items():
        node_ids = group_data.get("node_ids", [])
        if not node_ids:
            continue
        node_id = node_ids[0]  # best route = first listed node id
        # Get SVGs
        svg = get_route_svg(tree, node_id)
        r_cgr = group_data.get("sb_cgr")
        r_cgr_svg = None
        if r_cgr:
            r_cgr.clean2d()
            r_cgr_svg = cgr_display(r_cgr)
        # Start row
        html.append(f"<tr><td>{cluster_num}</td>")
        html.append(f"<td>{len(node_ids)}</td>")
        # ReducedRouteCGR cell (left empty when no CGR is available)
        html.append("<td>")
        if r_cgr_svg:
            b64_r = base64.b64encode(r_cgr_svg.encode("utf-8")).decode()
            html.append(
                f"<img src='data:image/svg+xml;base64,{b64_r}' alt='ReducedRouteCGR' class='img-fluid'/>"
            )
        html.append("</td>")
        # Best Route cell
        html.append("<td>")
        if svg:
            b64_svg = base64.b64encode(svg.encode("utf-8")).decode()
            html.append(
                f"<img src='data:image/svg+xml;base64,{b64_svg}' alt='Route {node_id}' class='img-fluid'/>"
            )
        html.append("</td></tr>")

    # Close table and HTML
    html.append("</tbody></table>")
    html.append("</div></body></html>")

    report_html = "".join(html)
    return report_html
665
+
666
+
667
def routes_clustering_report(
    source: Union[Tree, dict],
    clusters: dict,
    group_index: str,
    sb_cgrs_dict: dict,
    aam: bool = False,
    html_path: str = None,
) -> str:
    """
    Generates an HTML report visualizing a cluster of retrosynthetic routes.

    This function takes a source of retrosynthetic routes (either a Tree object
    or a dictionary representing routes in JSON format), cluster information,
    and a dictionary of ReducedRouteCGRs, and produces a comprehensive HTML report.
    The report includes details about the cluster, a representative ReducedRouteCGR,
    and SVG visualizations of each route within the specified cluster.

    Args:
        source (Union[Tree, dict]): The source of retrosynthetic routes. Either a
            Tree object containing the full search tree, or a dictionary loaded
            from a routes JSON file.
        clusters (dict): Clustering results; each cluster entry is expected to
            provide a 'node_ids' list.
        group_index (str): Key identifying the cluster within `clusters` for
            which the report is generated.
        sb_cgrs_dict (dict): Maps route IDs to ReducedRouteCGR objects; used to
            display a representative ReducedRouteCGR for the cluster.
        aam (bool, optional): Enable atom-atom mapping in molecule depictions.
            Defaults to False.
        html_path (str, optional): If given, the report is written to this file
            and a confirmation message is returned; otherwise the HTML string is
            returned. Defaults to None.

    Returns:
        str: The HTML report, or a confirmation message when `html_path` is
        provided, or an error page when inputs are invalid / the group is
        missing.
    """
    # --- Depict settings (best effort; ignored if the backend lacks the option) ---
    try:
        MoleculeContainer.depict_settings(aam=bool(aam))
    except Exception:
        pass

    # --- Figure out what `source` is: a search Tree or a routes-JSON dict ---
    using_tree = False
    if hasattr(source, "nodes") and hasattr(source, "route_to_node"):
        tree = source
        using_tree = True
    elif isinstance(source, dict):
        routes_json = source
        tree = None
    else:
        return "<html><body>Error: first argument must be a Tree or a routes_json dict.</body></html>"

    # --- Validate clusters ---
    if not isinstance(clusters, dict):
        return "<html><body>Error: clusters must be a dict.</body></html>"

    group = clusters.get(group_index)
    if group is None:
        return f"<html><body>Error: no group with index {group_index!r}.</body></html>"

    cluster_node_ids = group.get("node_ids", [])
    # Keep only routes that can actually be rendered.
    valid_routes = []

    if using_tree:
        for nid in cluster_node_ids:
            if nid in tree.nodes and tree.nodes[nid].is_solved():
                valid_routes.append(nid)
    else:
        # JSON mode: check if the node ID exists in the routes_dict
        routes_dict = make_dict(routes_json)
        for nid in cluster_node_ids:
            if nid in routes_dict.keys():
                valid_routes.append(nid)
    if not valid_routes:
        return f"""
        <!doctype html><html><body>
        <h3>Cluster {group_index} Report</h3>
        <p>No valid routes found in this cluster.</p>
        </body></html>
        """

    # --- Target molecule shown in the report header ---
    if using_tree:
        try:
            target_smiles = str(tree.nodes[1].curr_precursor)
        except Exception:
            target_smiles = "N/A"
    else:
        # JSON mode: take the root smiles of the first route.
        # NOTE(review): the key is stringified here while `valid_routes` was
        # matched against `make_dict` keys — confirm both use the same key type.
        target_smiles = routes_json[str(valid_routes[0])]["smiles"]

    # --- Reusable HTML tags (unused `th`/`font_head` removed) ---
    td = '<td style="text-align: left; border: 1px solid black; border-spacing: 0">'
    font_normal = "<font style='font-weight: normal; font-size: 18px'>"
    font_close = "</font>"

    template_begin = f"""
    <!doctype html>
    <html lang="en">
    <head>
        <link href="https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/css/bootstrap.min.css"
              rel="stylesheet"
              integrity="sha384-1BmE4kWBq78iYhFldvKuhfTAU6auU8tT94WrHftjDbrCEXSU1oBoqyl2QvZ6jIW3"
              crossorigin="anonymous">
        <meta charset="utf-8">
        <meta name="viewport" content="width=device-width, initial-scale=1">
        <title>Cluster {group_index} Routes Report</title>
        <style>
            /* Optional: Add some basic styling */
            .table {{ border-collapse: collapse; width: 100%; }}
            th, td {{ border: 1px solid #ddd; padding: 8px; text-align: left; }}
            tr:nth-child(even) {{ background-color: #ffffff; }}
            caption {{ caption-side: top; font-size: 1.5em; margin: 1em 0; }}
            svg {{ max-width: 100%; height: auto; }}
        </style>
    </head>
    <body>
    <div class="container"> """

    template_end = """
    </div> <script
    src="https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/js/bootstrap.bundle.min.js"
    integrity="sha384-ka7Sk0Gln4gmtz2MlQnikT1wXgYsOg+OMhuP+IlRH9sENBO0LRn5q+8nbTov4+1p"
    crossorigin="anonymous">
    </script>
    </body>
    </html>
    """

    # Colored-circle legend marker; "rgb()" is replaced per legend entry below.
    box_mark = """
    <svg width="30" height="30" viewBox="0 0 1 1" xmlns="http://www.w3.org/2000/svg" style="vertical-align: middle; margin-right: 5px;">
        <circle cx="0.5" cy="0.5" r="0.5" fill="rgb()" fill-opacity="0.35" />
    </svg>
    """

    # --- Build HTML Table ---
    table = f"""
    <table class="table table-hover caption-top">
    <caption><h3>Retrosynthetic Routes Report - Cluster {group_index}</h3></caption>
    <tbody>"""

    table += (
        f"<tr>{td}{font_normal}Target Molecule: {target_smiles}{font_close}</td></tr>"
    )
    table += f"<tr>{td}{font_normal}Group index: {group_index}{font_close}</td></tr>"
    table += f"<tr>{td}{font_normal}Size of Cluster: {len(valid_routes)} routes{font_close} </td></tr>"

    # --- Add ReducedRouteCGR Image (representative = first valid route) ---
    first_route_id = valid_routes[0] if valid_routes else None

    if first_route_id and sb_cgrs_dict:
        try:
            sb_cgr = sb_cgrs_dict[first_route_id]
            sb_cgr.clean2d()
            sb_cgr_svg = cgr_display(sb_cgr)

            if sb_cgr_svg.strip().startswith("<svg"):
                table += f"<tr>{td}{font_normal}Identified Strategic Bonds{font_close}<br>{sb_cgr_svg}</td></tr>"
            else:
                table += f"<tr>{td}{font_normal}Cluster Representative ReducedRouteCGR (from Route {first_route_id}):{font_close}<br><i>Invalid SVG format retrieved.</i></td></tr>"
                print(
                    f"Warning: Expected SVG for ReducedRouteCGR of node {first_route_id}, but got: {sb_cgr_svg[:100]}..."
                )
        except Exception as e:
            table += f"<tr>{td}{font_normal}Cluster Representative ReducedRouteCGR (from Route {first_route_id}):{font_close}<br><i>Error retrieving/displaying ReducedRouteCGR: {e}</i></td></tr>"
    else:
        if first_route_id:
            table += f"<tr>{td}{font_normal}Cluster Representative ReducedRouteCGR (from Route {first_route_id}):{font_close}<br><i>Not found in provided ReducedRouteCGR dictionary.</i></td></tr>"
        else:
            table += f"<tr>{td}{font_normal}Cluster Representative ReducedRouteCGR:{font_close}<br><i>No valid routes in cluster to select from.</i></td></tr>"

    # Legend row: color coding of molecule boxes in the route depictions.
    table += f"""
    <tr>{td}
    <div style="display: flex; align-items: center; flex-wrap: wrap; gap: 15px;">
        <span>{box_mark.replace("rgb()", "rgb(152, 238, 255)")} Target Molecule</span>
        <span>{box_mark.replace("rgb()", "rgb(240, 171, 144)")} Molecule Not In Stock</span>
        <span>{box_mark.replace("rgb()", "rgb(155, 250, 179)")} Molecule In Stock</span>
    </div>
    </td></tr>
    """
    for route_id in valid_routes:
        if using_tree:
            # 1) SVG from Tree
            svg = get_route_svg(tree, route_id)
            # 2) Reaction steps & score
            steps = tree.synthesis_route(route_id)
            score = round(tree.route_score(route_id), 3)
            # build reaction list
            reac_html = "".join(
                f"<b>Step {i+1}:</b> {str(r)}<br>" for i, r in enumerate(steps)
            )
            header = f"Route {route_id} — {len(steps)} steps, score={score}"
            table += f"<tr><td><b>{header}</b></td></tr>"
            table += f"<tr><td>{svg}</td></tr>"
            table += f"<tr><td>{reac_html}</td></tr>"

        else:
            # 1) SVG from JSON
            svg = get_route_svg_from_json(routes_json, route_id)
            # NOTE(review): steps appears to map step index -> reaction; verify
            # against make_dict's output shape.
            steps = routes_dict[route_id]
            reac_html = "".join(
                f"<b>Step {i+1}:</b> {str(r)}<br>" for i, r in steps.items()
            )

            header = f"Route {route_id} — {len(steps)} steps"
            table += f"<tr><td><b>{header}</b></td></tr>"
            table += f"<tr><td>{svg}</td></tr>"
            table += f"<tr><td>{reac_html}</td></tr>"

    table += "</tbody></table>"

    html = template_begin + table + template_end

    if html_path:
        with open(html_path, "w", encoding="utf-8") as f:
            f.write(html)
        return f"Written to {html_path}"
    return html
915
+
916
+
917
def lg_table_2_html(subcluster, nodes_to_display=None, if_display=True):
    """
    Generates an HTML table visualizing leaving groups (X) 'marks' for routes within a subcluster.

    Each row represents a route from the subcluster (or from the requested
    subset of nodes); columns are the unique 'marks' found across all nodes.
    Cells contain the SVG depiction of the corresponding mark for that node.

    Args:
        subcluster (dict): Subcluster data with a 'nodes_data' key mapping node
            IDs to dictionaries of marks; each mark value is a sequence whose
            first element has a `.depict()` method returning SVG.
        nodes_to_display (list, optional): Specific node IDs to include. When
            None or empty, all nodes in `subcluster["nodes_data"]` are shown.
            Defaults to None. (FIX: was a mutable default `[]`, which is shared
            between calls.)
        if_display (bool, optional): If True, the HTML is also rendered via
            `display(HTML(...))`. Defaults to True.

    Returns:
        str: The generated HTML table.
    """
    if nodes_to_display is None:
        nodes_to_display = []

    # Create HTML table header
    html = "<table style='border-collapse: collapse;'><tr><th style='border: 1px solid black; padding: 4px;'>Route ID</th>"

    # Extract all unique marks across all nodes to form consistent columns
    all_marks = set()
    for node_data in subcluster["nodes_data"].values():
        all_marks.update(node_data.keys())
    all_marks = sorted(all_marks)  # sort for consistent ordering

    # Add marks as headers
    for mark in all_marks:
        html += f"<th style='border: 1px solid black; padding: 4px;'>{mark}</th>"
    html += "</tr>"

    def _render_row(node_id, node_data):
        # One table row: node id followed by the SVG of each mark (empty cell
        # when the node lacks that mark).
        row = f"<tr><td style='border: 1px solid black; padding: 4px;'>{node_id}</td>"
        for mark in all_marks:
            row += "<td style='border: 1px solid black; padding: 4px;'>"
            if mark in node_data:
                row += node_data[mark][0].depict()  # SVG data as string
            row += "</td>"
        row += "</tr>"
        return row

    # Fill in the rows (previously two near-identical loops; now one helper)
    if len(nodes_to_display) == 0:
        for node_id, node_data in subcluster["nodes_data"].items():
            html += _render_row(node_id, node_data)
    else:
        for node_id in nodes_to_display:
            # Check if the node_id exists in the subcluster data
            if node_id in subcluster["nodes_data"]:
                html += _render_row(node_id, subcluster["nodes_data"][node_id])
            else:
                # Note that the node_id was not found
                html += f"<tr><td colspan='{len(all_marks)+1}' style='border: 1px solid black; padding: 4px; color:red;'>Route ID {node_id} not found.</td></tr>"

    html += "</table>"

    if if_display:
        display(HTML(html))

    return html
992
+
993
+
994
def group_lg_table_2_html_fixed(
    grouped: dict,
    groups_to_display=None,
    if_display=False,
    max_group_col_width: int = 200,
) -> str:
    """
    Renders grouped leaving-group (X) 'marks' as a fixed-layout HTML table.

    Each row corresponds to one group of routes: the first column lists the
    route IDs belonging to the group, and the remaining columns show the
    depiction (or string form) of each mark for the group's representative.

    Args:
        grouped (dict): Maps a group key (an iterable of route IDs) to the
            representative 'marks' dictionary for that group. Mark values
            either expose a `.depict()` method or are converted with `str()`.
        groups_to_display (list, optional): Subset of group keys to render;
            when None, every group in `grouped` is shown. Defaults to None.
        if_display (bool, optional): If True, the table is also rendered via
            `display(HTML(...))`. Defaults to False.
        max_group_col_width (int, optional): Pixel cap on the width of the
            route-IDs column. Defaults to 200.

    Returns:
        str: The generated HTML table.
    """
    # Which groups appear, preserving the caller's requested order.
    if groups_to_display is None:
        selected = list(grouped.keys())
    else:
        selected = [key for key in groups_to_display if key in grouped]

    # The union of all mark names across every representative defines the columns.
    columns = sorted({mark for rep in grouped.values() for mark in rep.keys()})

    # Cell styles: the group column wraps long ID lists; mark cells are centered.
    group_cell_css = (
        "border:1px solid #ccc; padding:4px; "
        "white-space: normal; overflow-wrap: break-word; "
        f"max-width:{max_group_col_width}px;"
    )
    mark_cell_css = (
        "border:1px solid #ccc; padding:4px; text-align:center; vertical-align:middle;"
    )

    parts = [
        "<table style='width:100%; table-layout:auto; border-collapse: collapse;'>",
        "<thead><tr>",
        "<th style='border:1px solid #ccc; padding:4px;'>Route IDs</th>",
    ]
    for mark in columns:
        parts.append(
            f"<th style='border:1px solid #ccc; padding:4px; text-align:center;'>{mark}</th>"
        )
    parts.append("</tr></thead><tbody>")

    for key in selected:
        representative = grouped[key]
        route_ids = ",".join(str(node) for node in key)
        cells = [f"<td style='{group_cell_css}'>{route_ids}</td>"]
        for mark in columns:
            pieces = ["<td style='" + mark_cell_css + "'>"]
            if mark in representative:
                value = representative[mark]
                if hasattr(value, "depict"):
                    pieces.append(value.depict())
                else:
                    pieces.append(str(value))
            pieces.append("</td>")
            cells.append("".join(pieces))
        parts.append("<tr>" + "".join(cells) + "</tr>")

    parts.append("</tbody></table>")
    rendered = "".join(parts)

    if if_display:
        display(HTML(rendered))

    return rendered
1084
+
1085
+
1086
def routes_subclustering_report(
    source: Union[Tree, dict],
    subcluster: dict,
    group_index: str,
    cluster_num: int,
    sb_cgrs_dict: dict,
    if_lg_group: bool = False,
    aam: bool = False,
    html_path: str = None,
) -> str:
    """
    Generates an HTML report visualizing a specific subcluster of retrosynthetic routes.

    Produces a detailed HTML report for the subcluster, including general
    cluster information, a representative ReducedRouteCGR, a synthon pseudo
    reaction, a table of leaving groups (per node, or grouped when
    `if_lg_group` is True), and SVG visualizations of each valid route.

    Args:
        source (Union[Tree, dict]): The source of retrosynthetic routes. Either
            a Tree object containing the full search tree, or a dictionary
            loaded from a routes JSON file.
        subcluster (dict): Data for the subcluster. Expected keys include
            'nodes_data' (mapping node IDs to mark data), 'synthon_reaction',
            and optionally 'group_lgs' when `if_lg_group` is True.
        group_index (str): Index of the main cluster this subcluster belongs
            to. Used for report titling.
        cluster_num (int): Number of the subcluster within its group. Used for
            report titling.
        sb_cgrs_dict (dict): Maps route IDs to ReducedRouteCGR objects; used to
            display a representative ReducedRouteCGR for the cluster.
        if_lg_group (bool, optional): If True, show grouped leaving groups from
            `subcluster['group_lgs']`; otherwise show per-node leaving groups
            from `subcluster['nodes_data']`. Defaults to False.
        aam (bool, optional): Enable atom-atom mapping in molecule depictions.
            Defaults to False.
        html_path (str, optional): If given, the report is written to this file
            and a confirmation message is returned; otherwise the HTML string
            is returned. Defaults to None.

    Returns:
        str: The HTML report, or a confirmation message when `html_path` is
        provided, or a minimal page when the subcluster has no valid routes,
        or an error page when inputs are invalid.
    """
    # --- Depict settings (best effort; ignored if the backend lacks the option) ---
    try:
        MoleculeContainer.depict_settings(aam=bool(aam))
    except Exception:
        pass

    # --- Figure out what `source` is: a search Tree or a routes-JSON dict ---
    using_tree = False
    if hasattr(source, "nodes") and hasattr(source, "route_to_node"):
        tree = source
        using_tree = True
    elif isinstance(source, dict):
        routes_json = source
        tree = None
    else:
        return "<html><body>Error: first argument must be a Tree or a routes_json dict.</body></html>"

    # --- Validate groups ---
    if not isinstance(subcluster, dict):
        return "<html><body>Error: groups must be a dict.</body></html>"

    subcluster_node_ids = list(subcluster["nodes_data"].keys())
    # Keep only routes that can actually be rendered.
    valid_routes = []

    if using_tree:
        for nid in subcluster_node_ids:
            if nid in tree.nodes and tree.nodes[nid].is_solved():
                valid_routes.append(nid)
    else:
        # JSON mode: just keep those IDs present in the JSON
        for nid in subcluster_node_ids:
            if nid in routes_json:
                valid_routes.append(nid)
        routes_dict = make_dict(routes_json)

    if not valid_routes:
        # Return a minimal HTML page indicating no valid routes
        return f"""
        <!doctype html><html lang="en"><head><meta charset="utf-8">
        <title>Cluster {group_index}.{cluster_num} Report</title></head><body>
        <h3>Cluster {group_index}.{cluster_num} Report</h3>
        <p>No valid/solved routes found for this cluster.</p>
        </body></html>"""

    # --- Target molecule shown in the report header ---
    if using_tree:
        try:
            target_smiles = str(tree.nodes[1].curr_precursor)
        except Exception:
            target_smiles = "N/A"
    else:
        # JSON mode: take the root smiles of the first route.
        # NOTE(review): unlike routes_clustering_report, the key is NOT
        # stringified here — confirm the JSON key type is consistent.
        target_smiles = routes_json[valid_routes[0]]["smiles"]

    # --- Reusable HTML tags (unused `th`/`font_head` removed) ---
    td = '<td style="text-align: left; border: 1px solid black; border-spacing: 0">'
    font_normal = "<font style='font-weight: normal; font-size: 18px'>"
    font_close = "</font>"

    template_begin = f"""
    <!doctype html>
    <html lang="en">
    <head>
        <link href="https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/css/bootstrap.min.css"
              rel="stylesheet"
              integrity="sha384-1BmE4kWBq78iYhFldvKuhfTAU6auU8tT94WrHftjDbrCEXSU1oBoqyl2QvZ6jIW3"
              crossorigin="anonymous">
        <meta charset="utf-8">
        <meta name="viewport" content="width=device-width, initial-scale=1">
        <title>SubCluster {group_index}.{cluster_num} Routes Report</title>
        <style>
            /* Optional: Add some basic styling */
            .table {{ border-collapse: collapse; width: 100%; }}
            th, td {{ border: 1px solid #ddd; padding: 8px; text-align: left; }}
            tr:nth-child(even) {{ background-color: #ffffff; }}
            caption {{ caption-side: top; font-size: 1.5em; margin: 1em 0; }}
            svg {{ max-width: 100%; height: auto; }}
        </style>
    </head>
    <body>
    <div class="container"> """

    template_end = """
    </div> <script
    src="https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/js/bootstrap.bundle.min.js"
    integrity="sha384-ka7Sk0Gln4gmtz2MlQnikT1wXgYsOg+OMhuP+IlRH9sENBO0LRn5q+8nbTov4+1p"
    crossorigin="anonymous">
    </script>
    </body>
    </html>
    """

    # Colored-circle legend marker; "rgb()" is replaced per legend entry below.
    box_mark = """
    <svg width="30" height="30" viewBox="0 0 1 1" xmlns="http://www.w3.org/2000/svg" style="vertical-align: middle; margin-right: 5px;">
        <circle cx="0.5" cy="0.5" r="0.5" fill="rgb()" fill-opacity="0.35" />
    </svg>
    """

    # --- Build HTML Table ---
    table = f"""
    <table class="table table-hover caption-top">
    <caption><h3>Retrosynthetic Routes Report - Cluster {group_index}.{cluster_num}</h3></caption>
    <tbody>"""

    table += (
        f"<tr>{td}{font_normal}Target Molecule: {target_smiles}{font_close}</td></tr>"
    )
    table += f"<tr>{td}{font_normal}Group index: {group_index}{font_close}</td></tr>"
    table += f"<tr>{td}{font_normal}Cluster Number: {cluster_num}{font_close}</td></tr>"
    table += f"<tr>{td}{font_normal}Size of Cluster: {len(valid_routes)} routes{font_close} </td></tr>"

    # --- Add ReducedRouteCGR Image (representative = first valid route) ---
    first_route_id = valid_routes[0] if valid_routes else None

    if first_route_id and sb_cgrs_dict:
        try:
            sb_cgr = sb_cgrs_dict[first_route_id]
            sb_cgr.clean2d()
            sb_cgr_svg = cgr_display(sb_cgr)

            if sb_cgr_svg.strip().startswith("<svg"):
                table += f"<tr>{td}{font_normal}Identified Strategic Bonds{font_close}<br>{sb_cgr_svg}</td></tr>"
            else:
                table += f"<tr>{td}{font_normal}Cluster Representative ReducedRouteCGR (from Route {first_route_id}):{font_close}<br><i>Invalid SVG format retrieved.</i></td></tr>"
                print(
                    f"Warning: Expected SVG for ReducedRouteCGR of node {first_route_id}, but got: {sb_cgr_svg[:100]}..."
                )
        except Exception as e:
            table += f"<tr>{td}{font_normal}Cluster Representative ReducedRouteCGR (from Route {first_route_id}):{font_close}<br><i>Error retrieving/displaying ReducedRouteCGR: {e}</i></td></tr>"
    else:
        if first_route_id:
            table += f"<tr>{td}{font_normal}Cluster Representative ReducedRouteCGR (from Route {first_route_id}):{font_close}<br><i>Not found in provided ReducedRouteCGR dictionary.</i></td></tr>"
        else:
            table += f"<tr>{td}{font_normal}Cluster Representative ReducedRouteCGR:{font_close}<br><i>No valid routes in cluster to select from.</i></td></tr>"

    # --- Synthon pseudo reaction (errors are reported inline, not raised) ---
    try:
        synthon_reaction = subcluster["synthon_reaction"]
        synthon_reaction.clean2d()
        synthon_svg = depict_custom_reaction(synthon_reaction)

        extra_synthon = f"<tr>{td}{font_normal}Synthon pseudo reaction:{font_close}<br>{synthon_svg}</td></tr>"
        table += extra_synthon
    except Exception as e:
        table += f"<tr><td colspan='1' style='color: red;'>Error displaying synthon reaction: {e}</td></tr>"

    # --- Leaving groups table: grouped or per-node depending on `if_lg_group` ---
    try:
        if if_lg_group:
            grouped_lgs = subcluster["group_lgs"]
            lg_table_html = group_lg_table_2_html_fixed(grouped_lgs, if_display=False)
        else:
            lg_table_html = lg_table_2_html(subcluster, if_display=False)
        extra_lg = f"<tr>{td}{font_normal}Leaving Groups table:{font_close}<br>{lg_table_html}</td></tr>"
        table += extra_lg
    except Exception as e:
        table += f"<tr><td colspan='1' style='color: red;'>Error displaying leaving groups: {e}</td></tr>"

    # Legend row: color coding of molecule boxes in the route depictions.
    table += f"""
    <tr>{td}
    <div style="display: flex; align-items: center; flex-wrap: wrap; gap: 15px;">
        <span>{box_mark.replace("rgb()", "rgb(152, 238, 255)")} Target Molecule</span>
        <span>{box_mark.replace("rgb()", "rgb(240, 171, 144)")} Molecule Not In Stock</span>
        <span>{box_mark.replace("rgb()", "rgb(155, 250, 179)")} Molecule In Stock</span>
    </div>
    </td></tr>
    """
    for route_id in valid_routes:
        if using_tree:
            # 1) SVG from Tree
            svg = get_route_svg(tree, route_id)
            # 2) Reaction steps & score
            steps = tree.synthesis_route(route_id)
            score = round(tree.route_score(route_id), 3)
            # build reaction list
            reac_html = "".join(
                f"<b>Step {i+1}:</b> {str(r)}<br>" for i, r in enumerate(steps)
            )
            header = f"Route {route_id} — {len(steps)} steps, score={score}"
            table += f"<tr><td><b>{header}</b></td></tr>"
            table += f"<tr><td>{svg}</td></tr>"
            table += f"<tr><td>{reac_html}</td></tr>"

        else:
            # 1) SVG from JSON
            svg = get_route_svg_from_json(routes_json, route_id)
            # NOTE(review): steps appears to map step index -> reaction; verify
            # against make_dict's output shape.
            steps = routes_dict[route_id]
            reac_html = "".join(
                f"<b>Step {i+1}:</b> {str(r)}<br>" for i, r in steps.items()
            )

            header = f"Route {route_id} — {len(steps)} steps"
            table += f"<tr><td><b>{header}</b></td></tr>"
            table += f"<tr><td>{svg}</td></tr>"
            table += f"<tr><td>{reac_html}</td></tr>"

    table += "</tbody></table>"

    html = template_begin + table + template_end

    if html_path:
        with open(html_path, "w", encoding="utf-8") as f:
            f.write(html)
        return f"Written to {html_path}"
    return html