Spaces:
Running
Running
| """Module containing functions for analysis and visualization of the built tree.""" | |
| import base64 | |
| from itertools import count, islice | |
| from collections import deque | |
| from typing import Any, Dict, List, Union | |
| from datetime import datetime | |
| import random | |
| from CGRtools.containers.molecule import MoleculeContainer | |
| from CGRtools import smiles as read_smiles | |
| from synplan.chem.reaction_routes.visualisation import ( | |
| cgr_display, | |
| depict_custom_reaction, | |
| ) | |
| from synplan.chem.reaction_routes.io import make_dict | |
| from synplan.mcts.tree import Tree | |
| from IPython.display import display, HTML | |
| # Data classes for structured data passing | |
class _GeometryData:
    """Mutable registry of per-molecule geometry collected during layout.

    Each dictionary is keyed by the molecule index handed out while the
    columns are laid out (see ``_prepare_molecule_data``).
    """

    def __init__(self):
        # index -> [column x offset, molecule max x, centre y, optional bend x]
        self.arrow_points = dict()
        # index -> value of molecule.meta["status"]
        self.mol_status = dict()
        # index -> value of molecule.meta["label"]
        self.mol_labels = dict()
class _LayoutData:
    """Result of the molecule layout pass handed to the SVG builders.

    :param render_components: Per-molecule depiction pieces plus background box.
    :param arrow_points: Mapping of molecule index to arrow anchor coordinates.
    :param mol_status: Mapping of molecule index to its status string.
    :param mol_labels: Mapping of molecule index to its label string.
    :param width: Total drawing width.
    :param height: Total drawing height.
    """

    def __init__(self, render_components, arrow_points, mol_status,
                 mol_labels, width, height):
        (
            self.render_components,
            self.arrow_points,
            self.mol_status,
            self.mol_labels,
            self.width,
            self.height,
        ) = (render_components, arrow_points, mol_status, mol_labels, width, height)
class _ColumnLayoutResult:
    """Outcome of laying out one column of molecules.

    :param render_components: Render pieces for every molecule in the column.
    :param column_max_x: Largest x coordinate reached by the column.
    :param column_max_y: Total height taken by the column.
    :param next_x_shift: Horizontal offset at which the next column starts.
    """

    def __init__(self, render_components, column_max_x, column_max_y, next_x_shift):
        self.render_components, self.column_max_x = render_components, column_max_x
        self.column_max_y, self.next_x_shift = column_max_y, next_x_shift
class _MoleculeData:
    """A molecule together with its bounds and its flat index.

    :param molecule: The molecule object being laid out.
    :param min_x: Offset subtracted from the x coordinates during normalisation.
    :param max_x: Largest x coordinate after normalisation.
    :param min_y: Offset subtracted from the y coordinates during normalisation.
    :param max_y: Largest y coordinate after normalisation.
    :param index: Flat index of the molecule in the drawing.
    """

    def __init__(self, molecule, min_x, max_x, min_y, max_y, index):
        self.molecule, self.index = molecule, index
        self.min_x, self.max_x = min_x, max_x
        self.min_y, self.max_y = min_y, max_y
def get_child_nodes(
    tree: Tree,
    molecule: MoleculeContainer,
    graph: Dict[MoleculeContainer, List[MoleculeContainer]],
) -> Union[Dict[str, Any], List[Any]]:
    """Recursively extract the child nodes of the given molecule.

    :param tree: The built tree; only ``tree.building_blocks`` is consulted.
    :param molecule: The molecule in the tree from which to extract child nodes.
    :param graph: Mapping of a molecule to the precursors it is made from.
    :return: A ``{"type": "reaction", "children": [...]}`` dict describing the
        precursors of ``molecule``, or an empty list when ``molecule`` has no
        entry in ``graph`` (i.e. it is a leaf of the route).
    """
    # Leaves have no entry in the graph; callers rely on the falsy return.
    if molecule not in graph:
        return []

    children = []
    for precursor in graph[molecule]:
        smiles = str(precursor)
        entry = {
            "smiles": smiles,
            "type": "mol",
            "in_stock": smiles in tree.building_blocks,
        }
        subtree = get_child_nodes(tree, precursor, graph)
        if subtree:
            entry["children"] = [subtree]
        children.append(entry)
    return {"type": "reaction", "children": children}
def extract_routes(
    tree: Tree, extended: bool = False, min_mol_size: int = 0
) -> List[Dict[str, Any]]:
    """Encode every found route of the tree as a nested dictionary.

    :param tree: The built tree.
    :param extended: If True, every node for which ``is_solved()`` holds is
        used as a route end point; otherwise ``tree.winning_nodes`` is used.
    :param min_mol_size: If the size of the Precursor is equal or smaller than
        min_mol_size it is automatically classified as building block.
    :return: A list of dictionaries. Each dictionary contains the target, a list
        of children, and a boolean indicating whether the target is in
        building_blocks. When no route exists the list holds a single entry for
        the target with no children.
    """
    root = tree.nodes[1]
    target = root.precursors_to_expand[0].molecule
    target_in_stock = root.curr_precursor.is_building_block(
        tree.building_blocks, min_mol_size
    )

    if extended:
        winning_nodes = [i for i, node in tree.nodes.items() if node.is_solved()]
    else:
        winning_nodes = tree.winning_nodes

    if not winning_nodes:
        return [
            {
                "type": "mol",
                "smiles": str(target),
                "in_stock": target_in_stock,
                "children": [],
            }
        ]

    routes_block = []
    for winning_node in winning_nodes:
        # Map each expanded molecule to the precursors produced from it
        nodes = tree.route_to_node(winning_node)
        graph = {}
        for before, after in zip(nodes, nodes[1:]):
            graph[before.curr_precursor.molecule] = [
                x.molecule for x in after.new_precursors
            ]
        routes_block.append(
            {
                "type": "mol",
                "smiles": str(target),
                "in_stock": target_in_stock,
                "children": [get_child_nodes(tree, target, graph)],
            }
        )
    return routes_block
def render_svg(pred, columns, box_colors, labeled=False):
    """Render a retrosynthetic route as an SVG string.

    :param pred: Iterable of ``(source index, precursor index)`` pairs.
    :param columns: Molecules grouped into columns, in drawing order.
    :param box_colors: Mapping of molecule status to background colour.
    :param labeled: If True, label boxes are added to the drawing.
    :return: The SVG markup.
    """
    # Molecule placement first: the arrow geometry depends on it
    layout = _compute_molecule_layout(columns, box_colors)

    reaction_graph = _build_reaction_graph(pred)
    arrows = _compute_arrow_geometry(reaction_graph, layout.arrow_points)

    pieces = _build_svg_components(layout, arrows, reaction_graph, box_colors, labeled)
    return _assemble_svg(pieces, layout.width, layout.height)
def _order_columns_by_barycenter(columns, reactions, sweeps=2):
    """Reorder molecules inside each column to reduce edge crossings.

    Molecules are identified by ``str(molecule)``. Every sweep sorts each
    column by the mean position of its neighbours in the adjacent column,
    first left to right and then right to left. Molecules without a neighbour
    in the adjacent column go last, ordered by their string form.

    :param columns: Sequence of molecule columns.
    :param reactions: Objects exposing ``reactants`` and ``products``.
    :param sweeps: Number of alternating sweeps (at least one is performed).
    :return: A new list of lists; fewer than two columns are copied unchanged.
    """
    if not columns or len(columns) < 2:
        return [list(column) for column in columns]

    # Adjacency keyed by string form: reactant -> products and product -> reactants
    left_neighbors, right_neighbors = {}, {}
    for reaction in reactions:
        product_names = [str(product) for product in reaction.products]
        for reactant in reaction.reactants:
            reactant_name = str(reactant)
            right_neighbors.setdefault(reactant_name, set()).update(product_names)
            for product_name in product_names:
                left_neighbors.setdefault(product_name, set()).add(reactant_name)

    def sort_against(column, reference, adjacency):
        """Sort ``column`` in place by neighbour barycenter within ``reference``."""
        slot = {str(mol): position for position, mol in enumerate(reference)}

        def barycenter_key(mol):
            name = str(mol)
            hits = [slot[n] for n in adjacency.get(name, ()) if n in slot]
            if hits:
                return (0, sum(hits) / len(hits), name)
            return (1, name)

        column.sort(key=barycenter_key)

    ordered = [sorted(column, key=str) for column in columns]
    last = len(ordered) - 1
    for _ in range(max(1, sweeps)):
        # Left -> right
        for i in range(1, last + 1):
            sort_against(ordered[i], ordered[i - 1], left_neighbors)
        # Right -> left
        for i in range(last - 1, -1, -1):
            sort_against(ordered[i], ordered[i + 1], right_neighbors)
    return ordered
def _compute_molecule_layout(columns, box_colors):
    """Lay out every column of molecules and gather the resulting geometry.

    :param columns: Molecules grouped into columns, in drawing order.
    :param box_colors: Mapping of molecule status to background colour.
    :return: A ``_LayoutData`` with render components, geometry and drawing size.
    """
    geometry = _GeometryData()
    numbering = count()  # shared counter so indices run across all columns
    components = []
    x_offset = 0.0
    widest = 0.0
    tallest = 0.0

    for column in columns:
        placed = _layout_column(column, x_offset, numbering, geometry, box_colors)
        components += placed.render_components
        widest = max(widest, placed.column_max_x)
        tallest = max(tallest, placed.column_max_y)
        x_offset = placed.next_x_shift

    # Pad the drawing by a margin proportional to the depiction font size
    font_size = MoleculeContainer._render_config["font_size"]
    return _LayoutData(
        render_components=components,
        arrow_points=geometry.arrow_points,
        mol_status=geometry.mol_status,
        mol_labels=geometry.mol_labels,
        width=widest + 4.0 * font_size,
        height=tallest + 3.5 * font_size,
    )
def _layout_column(molecules, x_shift, molecule_index, geometry_data, box_colors):
    """Lay out a single column of molecules.

    :param molecules: Molecules of the column, top to bottom.
    :param x_shift: Horizontal offset of the column.
    :param molecule_index: Iterator yielding the next flat molecule index.
    :param geometry_data: ``_GeometryData`` updated in place.
    :param box_colors: Mapping of molecule status to background colour.
    :return: A ``_ColumnLayoutResult`` for the column.
    """
    gap = 3.0  # vertical spacing between neighbouring molecules

    # Pass 1: normalise every molecule and record its bounds
    prepared = [
        _prepare_molecule_data(molecule, x_shift, next(molecule_index), geometry_data)
        for molecule in molecules
    ]
    column_max_x = max([0.0, *(data.max_x for data in prepared)])

    # Pass 2: stack the molecules, centred vertically around y = 0
    heights = [data.max_y for data in prepared]
    total_height = sum(heights) + gap * (len(heights) - 1)

    components = []
    cursor = total_height / 2.0
    for data, height in zip(prepared, heights):
        components.append(
            _position_molecule_vertically(data, cursor, height, geometry_data, box_colors)
        )
        cursor -= height + gap

    return _ColumnLayoutResult(
        render_components=components,
        column_max_x=column_max_x,
        column_max_y=total_height,
        next_x_shift=column_max_x + 5.0,
    )
def _prepare_molecule_data(molecule, x_shift, index, geometry_data):
    """Clean the 2D layout of a molecule and move it into its column.

    The coordinates are shifted so that the smallest x lands on ``x_shift`` and
    the smallest y on zero. The molecule's arrow anchors, status and label are
    registered in ``geometry_data`` under ``index``.

    :return: A ``_MoleculeData`` describing the shifted molecule.
    """
    molecule.clean2d()

    original = molecule._plane
    min_x = min(x for x, _ in original.values()) - x_shift
    min_y = min(y for _, y in original.values())
    molecule._plane = {
        atom: (x - min_x, y - min_y) for atom, (x, y) in original.items()
    }

    shifted = molecule._plane
    max_x = max(x for x, _ in shifted.values())
    max_y = max(y for _, y in shifted.values())

    # The centre y is appended later, once the column is stacked vertically
    geometry_data.arrow_points[index] = [x_shift, max_x]
    meta = molecule.meta
    geometry_data.mol_status[index] = meta.get("status", "instock")
    geometry_data.mol_labels[index] = meta.get("label", "uspto")

    return _MoleculeData(molecule, min_x, max_x, min_y, max_y, index)
def _position_molecule_vertically(mol_data, y_shift, height, geometry_data, box_colors):
    """Shift a molecule to its vertical slot and build its render components.

    :param mol_data: ``_MoleculeData`` of the molecule to place.
    :param y_shift: Amount subtracted from every y coordinate.
    :param height: Height reserved for the molecule.
    :param geometry_data: ``_GeometryData``; the centre y is appended to the
        molecule's ``arrow_points`` entry.
    :param box_colors: Mapping of molecule status to background colour.
    :return: The first three items yielded by ``molecule.depict(embedding=True)``
        followed by the background box markup.
    """
    molecule = mol_data.molecule
    index = mol_data.index

    molecule._plane = {
        atom: (x, y - y_shift) for atom, (x, y) in molecule._plane.items()
    }

    # Padded bounding box; y values are negated for the box coordinates
    xs = [x for x, _ in molecule._plane.values()]
    ys = [y for _, y in molecule._plane.values()]
    box_right = max(xs) + 0.9
    box_left = min(xs) - 0.6
    box_top = -(max(ys) + 0.45)
    box_bottom = -(min(ys) - 0.45)

    status = geometry_data.mol_status[index]
    box_svg = _create_background_box(
        box_left, box_right, box_top, box_bottom, status, box_colors
    )

    # Vertical anchor used by the arrows
    geometry_data.arrow_points[index].append(y_shift - height / 2.0)

    return [*molecule.depict(embedding=True)][:3] + [box_svg]
def _create_background_box(min_x, max_x, max_y, min_y, status, box_colors):
    """Build the SVG ``<rect>`` drawn behind a molecule.

    :param min_x: Left edge; used as the rectangle's ``x``.
    :param max_x: Right edge.
    :param max_y: Used as the rectangle's ``y``.
    :param min_y: Opposite vertical edge.
    :param status: Molecule status used to look up the fill colour.
    :param box_colors: Mapping of status to colour; white when the status is unknown.
    :return: The ``<rect>`` markup.
    """
    box_width = abs(max_x - min_x)
    box_height = abs(max_y - min_y)
    radius = box_height * 0.1
    attributes = (
        ("x", min_x),
        ("y", max_y),
        ("rx", radius),
        ("ry", radius),
        ("width", box_width),
        ("height", box_height),
        ("stroke", "black"),
        ("stroke-width", ".0025"),
        ("fill", box_colors.get(status, "#FFFFFF")),
        ("fill-opacity", "0.30"),
    )
    return "<rect " + " ".join(f'{key}="{value}"' for key, value in attributes) + "/>"
def _build_reaction_graph(pred):
    """Group the second element of each pair under its first element.

    :param pred: Iterable of ``(source index, target index)`` pairs.
    :return: Dict mapping each source index to the list of its target indices,
        both in order of first appearance.
    """
    graph = {}
    for source_idx, target_idx in pred:
        if source_idx in graph:
            graph[source_idx].append(target_idx)
        else:
            graph[source_idx] = [target_idx]
    return graph
def _compute_arrow_geometry(graph, arrow_points):
    """Add a shared bend x coordinate to the precursors of every source.

    For each source the bend is the largest value of
    ``start + (source_min_x - start) / 3`` over its precursors, where ``start``
    is the precursor's max x plus one. It is appended as the fourth element of
    a precursor's entry unless that entry already has one.

    :param graph: Mapping of source index to list of precursor indices.
    :param arrow_points: Mapping of index to coordinate list; updated in place.
    :return: The same ``arrow_points`` mapping.
    """
    for source_idx, precursors in graph.items():
        source_min_x, _, _ = arrow_points[source_idx][:3]

        bends = []
        for precursor_idx in precursors:
            _, precursor_max, _ = arrow_points[precursor_idx][:3]
            start = precursor_max + 1
            bends.append(start + (source_min_x - start) / 3)
        mid_x = max([float("-inf"), *bends])

        for precursor_idx in precursors:
            entry = arrow_points[precursor_idx]
            if len(entry) < 4:
                entry.append(mid_x)
    return arrow_points
def _build_svg_components(layout_data, arrow_data, graph, box_colors, labeled):
    """Collect every SVG fragment of the drawing in output order.

    The order is: header and marker definitions, arrows, molecules and, when
    ``labeled`` is true, the label boxes.

    :param layout_data: ``_LayoutData`` produced by the layout pass.
    :param arrow_data: Arrow anchor mapping including bend coordinates.
    :param graph: Mapping of source index to precursor indices.
    :param box_colors: Not used here; kept for the call signature.
    :param labeled: Whether label boxes are emitted.
    :return: List of SVG fragments.
    """
    width, height = layout_data.width, layout_data.height

    header = _create_svg_header(width, height)
    arrows = _create_arrows(graph, arrow_data)
    molecules = _render_molecules(layout_data.render_components, width, height)

    svg_components = [*header, *arrows, *molecules]
    if labeled:
        svg_components += _create_labels(
            graph, arrow_data, layout_data.mol_status, layout_data.mol_labels
        )
    return svg_components
def _create_svg_header(width, height):
    """Build the opening ``<svg>`` tag and the arrow-head marker definition.

    The view box starts 1.25 font sizes left of zero and half the height above
    zero; the physical size is 0.6 cm per drawing unit.

    :param width: Drawing width.
    :param height: Drawing height.
    :return: Two-item list: the ``<svg ...>`` tag and the ``<defs>`` block.
    """
    font_size = MoleculeContainer._render_config["font_size"]
    view_x = -(1.25 * font_size)
    view_y = -(height / 2.0)

    svg_open = (
        '<svg width="{:.2f}cm" height="{:.2f}cm" '
        'viewBox="{:.2f} {:.2f} {:.2f} {:.2f}" '
        'xmlns="http://www.w3.org/2000/svg" version="1.1">'
    ).format(0.6 * width, 0.6 * height, view_x, view_y, width, height)

    marker_defs = "\n".join(
        (
            ' <defs>',
            ' <marker id="arrow" markerWidth="10" markerHeight="10" '
            'refX="0" refY="3" orient="auto">',
            ' <path d="M0,0 L0,6 L9,3 z" fill="black"/>',
            ' </marker>',
            ' </defs>',
        )
    )
    return [svg_open, marker_defs]
def _create_arrows(graph, arrow_points):
    """Build one arrow fragment per (source, precursor) pair of the graph.

    :param graph: Mapping of source index to precursor indices.
    :param arrow_points: Mapping of index to arrow anchor coordinates.
    :return: List of SVG fragments, in graph order.
    """
    return [
        _create_single_arrow(source_idx, precursor_idx, arrow_points)
        for source_idx, precursors in graph.items()
        for precursor_idx in precursors
    ]
def _create_single_arrow(source_idx, precursor_idx, arrow_points):
    """Create the SVG for a single arrow from a precursor to its source.

    The arrow leaves the precursor one unit right of its max x, runs
    horizontally to the bend x, vertically to the source's y and ends one unit
    left of the source's min x.

    :param source_idx: Index of the molecule the arrow points at.
    :param precursor_idx: Index of the molecule the arrow starts from.
    :param arrow_points: Mapping of index to ``[min_x, max_x, y, bend_x]``.
    :return: The arrow markup, or an empty string when the source entry has
        fewer than three values or the precursor entry has no bend x.
    """
    source_data = arrow_points[source_idx]
    precursor_data = arrow_points[precursor_idx]
    if len(source_data) < 3 or len(precursor_data) < 4:
        return ""

    s_min_x, _, s_y = source_data[:3]
    # The guard above guarantees the bend x is present, so no fallback is needed
    _, p_max, p_y, mid_x = precursor_data[:4]
    p_max += 1

    arrow_svg = (
        f' <polyline points="{p_max:.2f} {p_y:.2f}, {mid_x:.2f} {p_y:.2f}, '
        f'{mid_x:.2f} {s_y:.2f}, {s_min_x - 1.:.2f} {s_y:.2f}" '
        f'fill="none" stroke="black" stroke-width=".04" marker-end="url(#arrow)"/>'
    )
    # Add connection dot for non-straight arrows
    if p_y != s_y:
        arrow_svg += f' <circle cx="{mid_x}" cy="{p_y}" r="0.1" fill="black"/>'
    return arrow_svg
def _render_molecules(render_components, width, height):
    """Render every molecule and slot its background box into the markup.

    :param render_components: Iterable of ``(atoms, bonds, masks, box)`` items.
    :param width: Drawing width.
    :param height: Drawing height.
    :return: Flat list of SVG fragments for all molecules.
    """
    font_size = MoleculeContainer._render_config["font_size"]
    # Same view-box origin as the SVG header
    origin_x = -(1.25 * font_size)
    origin_y = -(height / 2.0)

    rendered = []
    for atoms, bonds, masks, box in render_components:
        fragment = MoleculeContainer._graph_svg(
            atoms, bonds, masks, origin_x, origin_y, width, height
        )
        # Insert the box at position 1 of the fragment list
        fragment.insert(1, box)
        rendered += fragment
    return rendered
def _create_labels(graph, arrow_points, mol_status, mol_labels):
    """Create label boxes for sources whose status is "mulecule" or "target".

    Each label is a white rectangle with centred text placed to the left of
    the point where the arrow reaches the source molecule.

    :param graph: Mapping of source index to precursor indices.
    :param arrow_points: Mapping of index to arrow anchor coordinates.
    :param mol_status: Mapping of index to status string.
    :param mol_labels: Mapping of index to label text.
    :return: List of ``<g>`` fragments, one per labelled source.
    """
    gap = 0.40          # distance between arrow tip and label box
    text_pad = 0.20     # horizontal padding inside the box
    stroke = 0.04       # box outline width
    labelled_statuses = {"mulecule", "target"}

    font_size = MoleculeContainer._render_config["font_size"]
    font_px = font_size * 0.80
    approx_char_width = 0.60 * font_px
    rect_height = 1.20 * font_px

    labels = []
    for source_idx in graph:
        if mol_status.get(source_idx, "instock") not in labelled_statuses:
            continue

        source_data = arrow_points[source_idx]
        if len(source_data) < 3:
            continue
        s_min_x, _, tip_y = source_data[:3]

        label_text = mol_labels.get(source_idx, "")
        if not label_text:
            continue

        # Box sits left of the arrow tip, vertically centred on it
        rect_width = len(label_text) * approx_char_width + 2 * text_pad
        rect_x = (s_min_x - 1.0) - gap - rect_width
        rect_y = tip_y - rect_height / 2.0
        text_x = rect_x + rect_width / 2.0

        rect_element = (
            f'<rect x="{rect_x:.2f}" y="{rect_y:.2f}" width="{rect_width:.2f}" '
            f'height="{rect_height:.2f}" fill="white" stroke="black" '
            f'stroke-width="{stroke:.2f}"/>'
        )
        text_element = (
            f'<text x="{text_x:.2f}" y="{tip_y:.2f}" '
            f'font-size="{font_px:.2f}" text-anchor="middle" dominant-baseline="middle" '
            f'fill="black">{label_text}</text>'
        )
        labels.append(f' <g>{rect_element}{text_element}</g>')
    return labels
def _assemble_svg(svg_components, width, height):
    """Join the SVG fragments with newlines and close the document.

    :param svg_components: Iterable of SVG fragments, header first.
    :param width: Not used here; kept for the call signature.
    :param height: Not used here; kept for the call signature.
    :return: The complete SVG string ending with ``</svg>``.
    """
    return "\n".join([*svg_components, "</svg>"])
def get_route_svg_mod(tree: Tree, node_id: int, sweeps: int = 2) -> str:
    """Visualize the full retrosynthetic route from the target to a given node.

    Columns are built from the reactions of the route and then reordered with
    a barycenter heuristic to reduce arrow overlaps.

    :param tree: The built MCTS tree.
    :param node_id: The ID of the node to which the route should be visualized.
    :param sweeps: Number of ordering sweeps to reduce edge crossings.
    :return: A string containing the SVG visualization of the route; an empty
        string when there are no reactions and the node is not in the tree.
    """
    box_colors = {
        "target": "#98EEFF",  # Light Blue for the main target
        "mulecule": "#F0AB90",  # Peach for intermediates not in stock
        "instock": "#9BFAB3",  # Light Green for building blocks
    }

    # Reaction steps in retrosynthetic order
    retro_reactions = list(reversed(tree.synthesis_route(node_id)))

    # No reactions: draw the node's own molecule as a lone target
    if not retro_reactions:
        lone_node = tree.nodes.get(node_id)
        if not lone_node:
            return ""
        lone_molecule = lone_node.curr_precursor.molecule
        lone_molecule.meta["status"] = "target"
        return render_svg(tuple(), [[lone_molecule]], box_colors)

    # One molecule object per unique SMILES (later occurrences win)
    mol_map = {}
    for reaction in retro_reactions:
        for mol in reaction.reactants + reaction.products:
            mol_map[str(mol)] = mol

    for smiles, mol in mol_map.items():
        if smiles in tree.building_blocks:
            mol.meta["status"] = "instock"
        else:
            mol.meta["status"] = "mulecule"

    # The final target is the product of the first retrosynthetic reaction
    target_molecule = retro_reactions[0].products[0]
    target_molecule.meta["status"] = "target"
    mol_map[str(target_molecule)] = target_molecule

    # --- Build columns from left to right based on reaction dependencies ---
    all_products = {str(p) for r in retro_reactions for p in r.products}
    all_reactants = {str(m) for r in retro_reactions for m in r.reactants}
    leftmost = all_reactants - all_products
    if not leftmost:  # Fallback for simple A->B routes
        leftmost = {str(m) for m in retro_reactions[-1].reactants}

    first_column = sorted(leftmost)
    columns = [[mol_map[s] for s in first_column]]
    placed = set(first_column)

    while len(placed) < len(mol_map):
        # Products of reactions whose reactants have all been placed already
        ready = set()
        for reaction in retro_reactions:
            if any(str(reactant) not in placed for reactant in reaction.reactants):
                continue
            for product in reaction.products:
                name = str(product)
                if name not in placed:
                    ready.add(name)
        if not ready:
            break  # Safety break if no new column can be formed
        column_order = sorted(ready)
        columns.append([mol_map[s] for s in column_order])
        placed.update(column_order)

    # Reorder within columns to reduce edge crossings/overlaps
    columns = _order_columns_by_barycenter(columns, retro_reactions, sweeps=sweeps)

    # --- Prepare data for rendering ---
    mol_to_idx = {}
    for position, mol in enumerate(mol for column in columns for mol in column):
        mol_to_idx[str(mol)] = position

    # Connections as (product index, reactant index) pairs
    pred = []
    for reaction in retro_reactions:
        for product in reaction.products:
            product_idx = mol_to_idx.get(str(product))
            if product_idx is None:
                continue
            for reactant in reaction.reactants:
                reactant_idx = mol_to_idx.get(str(reactant))
                if reactant_idx is None:
                    continue
                pred.append((product_idx, reactant_idx))

    return render_svg(tuple(pred), columns, box_colors)
def get_route_svg(tree: Tree, node_id: int, labeled: bool = False) -> Union[str, None]:
    """Visualize the retrosynthetic route that ends at a winning node.

    :param tree: The built tree.
    :param node_id: The id of the node from which to visualize the route.
    :param labeled: If True, the rule label stored in ``tree.nodes_rule_label``
        for each step is attached to the molecule it was applied to and drawn
        next to it.
    :return: The SVG string, or None when ``node_id`` is not in
        ``tree.winning_nodes``.
    """
    if node_id not in tree.winning_nodes:
        return None

    # --- 1. Reconstruct route as node IDs and Node objects (root -> leaf) ---
    path_ids = []
    nid = node_id
    while nid:
        path_ids.append(nid)
        nid = tree.parents[nid]
    path_ids.reverse()
    nodes = [tree.nodes[i] for i in path_ids]

    # --- 2. Clear any old "label" metadata on molecules in this route ---
    for node in nodes:
        cp = getattr(node, "curr_precursor", None)
        # curr_precursor can be a tuple for some nodes -> guard it
        if cp is not None and hasattr(cp, "molecule"):
            cp.molecule.meta.pop("label", None)
        for prec in getattr(node, "new_precursors", ()):
            if hasattr(prec, "molecule"):
                prec.molecule.meta.pop("label", None)

    # --- 3. Assign labels from the rule used to expand each parent ---
    # The rule recorded for a child node labels the parent's current molecule.
    if labeled:
        for parent_node, child_id in zip(nodes, path_ids[1:]):
            rule_label = tree.nodes_rule_label.get(child_id)  # "priority" or "uspto"
            cp = getattr(parent_node, "curr_precursor", None)
            if rule_label and cp is not None and hasattr(cp, "molecule"):
                cp.molecule.meta["label"] = rule_label

    # --- 4. Status colouring ---
    for node in nodes:
        for precursor in node.new_precursors:
            precursor.molecule.meta["status"] = (
                "instock"
                if precursor.is_building_block(tree.building_blocks)
                else "mulecule"
            )
    target_molecule = nodes[0].curr_precursor.molecule
    target_molecule.meta["status"] = "target"

    box_colors = {
        "target": "#98EEFF",  # blue
        "mulecule": "#F0AB90",  # red/orange
        "instock": "#9BFAB3",  # green
    }

    # A route made of the root alone has no steps: draw just the target
    if len(nodes) < 2:
        return render_svg(tuple(), [[target_molecule]], box_colors, labeled=labeled)

    # --- 5. Column / predecessor construction ---
    # Molecules are numbered in discovery order: 0 is the target, then every
    # precursor layer by layer. pred maps a molecule to the one it came from.
    first_layer = nodes[1].new_precursors
    columns = [[target_molecule], [x.molecule for x in first_layer]]
    pred = {x: 0 for x in range(1, len(columns[1]) + 1)}

    # Indices of molecules that still have to be expanded by a later node
    pending = deque(
        n
        for n, x in enumerate(first_layer, 1)
        if not x.is_building_block(tree.building_blocks)
    )
    nodes_iter = iter(nodes[2:])
    next_index = count(len(columns[1]) + 1)
    while pending:
        layer = []
        expanded = 0
        for step in islice(nodes_iter, len(pending)):
            parent_index = pending.popleft()
            expanded += 1
            for x in step.new_precursors:
                layer.append(x)
                m = next(next_index)
                if not x.is_building_block(tree.building_blocks):
                    pending.append(m)
                pred[m] = parent_index
        if not expanded:
            # Route nodes are exhausted while molecules are still pending;
            # stop instead of looping forever on an unchanged queue.
            break
        columns.append([x.molecule for x in layer])

    # Mirror columns and their contents so that the target ends up last
    columns = [column[::-1] for column in reversed(columns)]
    # Mirroring maps index i to len(pred) - i
    total = len(pred)
    edges = tuple(
        (abs(source - total), abs(target - total)) for target, source in pred.items()
    )
    return render_svg(edges, columns, box_colors, labeled=labeled)
def _get_root(routes_json: dict, route_id: int) -> dict:
    """Return the root tree stored for ``route_id``.

    The key is tried as given first and then in its string form.

    :param routes_json: Mapping of route id to route tree.
    :param route_id: Route identifier.
    :return: The stored route tree.
    :raises ValueError: If neither key form is present.
    """
    for key in (route_id, str(route_id)):
        if key in routes_json:
            return routes_json[key]
    raise ValueError(f"Route ID {route_id} not found in routes_json.")
def _extract_levels_and_parents(root: dict):
    """Walk the route tree breadth-first, grouping molecule nodes by depth.

    Only dicts with ``"type": "mol"`` are recorded; their children are reached
    through dicts with ``"type": "reaction"``. Anything else is skipped.

    :param root: Root node of the route tree.
    :return: ``(levels, parent_of)`` where ``levels[d]`` lists the molecule
        dicts at depth ``d`` and ``parent_of[id(node)]`` is the parent molecule
        dict, or None for the root.
    """

    def has_type(candidate, expected):
        return isinstance(candidate, dict) and candidate.get("type") == expected

    levels, parent_of = [], {}
    pending = deque([(root, 0, None)])
    while pending:
        node, depth, parent = pending.popleft()
        if not has_type(node, "mol"):
            continue

        # Grow the per-depth lists on demand
        while len(levels) <= depth:
            levels.append([])
        levels[depth].append(node)
        parent_of[id(node)] = parent

        for reaction in node.get("children") or []:
            if not has_type(reaction, "reaction"):
                continue
            pending.extend(
                (child, depth + 1, node)
                for child in reaction.get("children") or []
                if has_type(child, "mol")
            )
    return levels, parent_of
def get_route_svg_from_json(routes_json: dict, route_id: int) -> str:
    """Visualize the retrosynthetic route ``routes_json[route_id]`` as an SVG.

    :param routes_json: Mapping of route id (int or str) to route tree.
    :param route_id: Identifier of the route to draw.
    :return: The SVG markup.
    :raises ValueError: If the route id is not present.
    """
    root = _get_root(routes_json, route_id)
    levels, parent_of = _extract_levels_and_parents(root)

    # One molecule container per JSON node, keyed by the node's identity
    containers = {}
    for depth, level in enumerate(levels):
        for entry in level:
            container = read_smiles(entry["smiles"])
            if depth == 0:
                status = "target"
            elif entry.get("in_stock"):
                status = "instock"
            else:
                status = "mulecule"
            container.meta["status"] = status
            containers[id(entry)] = container

    # Deepest level first, so the target becomes the last column
    mirrored = levels[::-1]
    position = {
        id(entry): idx
        for idx, entry in enumerate(entry for level in mirrored for entry in level)
    }

    # Edges as (parent index, child index) in the flattened ordering
    pred = tuple(
        (position[id(parent)], position[child_id])
        for child_id, parent in parent_of.items()
        if parent is not None
    )

    columns = [[containers[id(entry)] for entry in level] for level in mirrored]
    box_colors = {
        "target": "#98EEFF",
        "mulecule": "#F0AB90",
        "instock": "#9BFAB3",
    }
    return render_svg(pred, columns, box_colors)
def generate_results_html(
    tree: Tree, html_path: str, aam: bool = False, extended: bool = False
) -> Union[str, None]:
    """Build an HTML report with the synthesis routes in SVG format and the
    corresponding reactions in SMILES format.

    :param tree: The built tree.
    :param html_path: The path to the file where to store resulting HTML. When
        None, nothing is written and the table markup is returned instead.
    :param aam: If True, depict atom-to-atom mapping.
    :param extended: If True, every node for which ``is_solved()`` holds is
        reported; otherwise only ``tree.winning_nodes``.
    :return: The table markup when ``html_path`` is None, otherwise None.
    """
    MoleculeContainer.depict_settings(aam=bool(aam))

    if extended:
        routes = [idx for idx, node in tree.nodes.items() if node.is_solved()]
    else:
        routes = tree.winning_nodes

    # HTML Tags
    td = '<td style="text-align: left; border: 1px solid black; border-spacing: 0">'
    font_head = "<font style='font-weight: bold; font-size: 18px'>"
    font_normal = "<font style='font-weight: normal; font-size: 18px'>"
    font_close = "</font>"

    template_begin = """
<!doctype html>
<html lang="en">
<head>
<link href="https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/css/bootstrap.min.css"
rel="stylesheet"
integrity="sha384-1BmE4kWBq78iYhFldvKuhfTAU6auU8tT94WrHftjDbrCEXSU1oBoqyl2QvZ6jIW3"
crossorigin="anonymous">
<script
src="https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/js/bootstrap.bundle.min.js"
integrity="sha384-ka7Sk0Gln4gmtz2MlQnikT1wXgYsOg+OMhuP+IlRH9sENBO0LRn5q+8nbTov4+1p"
crossorigin="anonymous">
</script>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1">
<title>Predicted Paths Report</title>
<meta name="description" content="A simple HTML5 Template for new projects.">
<meta name="author" content="SitePoint">
</head>
<body>
"""
    template_end = """
</body>
</html>
"""
    # SVG template of a legend dot; "rgb()" is replaced by the actual colour
    box_mark = """
<svg width="30" height="30" viewBox="0 0 1 1" xmlns="http://www.w3.org/2000/svg">
<circle cx="0.5" cy="0.5" r="0.5" fill="rgb()" fill-opacity="0.35" />
</svg>
"""

    # Fragments are collected and joined once at the end
    parts = [
        """
<table class="table table-striped table-hover caption-top">
<caption><h3>Retrosynthetic Routes Report</h3></caption>
<tbody>"""
    ]

    # Summary rows
    parts.append(
        f"<tr>{td}{font_normal}Target Molecule: "
        f"{str(tree.nodes[1].curr_precursor)}{font_close}</td></tr>"
    )
    parts.append(
        f"<tr>{td}{font_normal}Tree Size: {len(tree)}{font_close} nodes</td></tr>"
    )
    parts.append(
        f"<tr>{td}{font_normal}Number of visited nodes: "
        f"{len(tree.visited_nodes)}{font_close}</td></tr>"
    )
    parts.append(
        f"<tr>{td}{font_normal}Found paths: {len(routes)}{font_close}</td></tr>"
    )
    parts.append(
        f"<tr>{td}{font_normal}Time: {round(tree.curr_time, 4)}{font_close} "
        f"seconds</td></tr>"
    )

    # Colour legend
    legend_entries = (
        ("rgb(152, 238, 255)", "Target Molecule"),
        ("rgb(240, 171, 144)", "Molecule Not In Stock"),
        ("rgb(155, 250, 179)", "Molecule In Stock"),
    )
    legend_body = "\n".join(
        box_mark.replace("rgb()", color) + "\n" + caption
        for color, caption in legend_entries
    )
    parts.append(f"\n<tr>{td}\n<div>\n{legend_body}\n</div>\n</td></tr>\n")

    for route in routes:
        # get_route_svg returns None for nodes outside tree.winning_nodes;
        # emit an empty cell rather than the literal text "None".
        svg = get_route_svg(tree, route) or ""
        full_route = tree.synthesis_route(route)

        # SMILES of all reactions in the synthesis path
        reactions = "".join(
            f"<b>Step {step}:</b> {str(synth_step)}<br>"
            for step, synth_step in enumerate(full_route, 1)
        )

        route_score = round(tree.route_score(route), 3)
        parts.append(
            f'<tr style="line-height: 250%">{td}{font_head}Route {route}; '
            f"Steps: {len(full_route)}; "
            f"Cumulated nodes' value: {route_score}{font_close}</td></tr>"
        )
        parts.append(f"<tr>{td}{svg}</td></tr>")
        parts.append(f"<tr>{td}{reactions}</td></tr>")

    # Close both the body and the table itself
    parts.append("</tbody></table>")
    table = "".join(parts)

    if html_path is None:
        return table

    with open(html_path, "w", encoding="utf-8") as html_file:
        html_file.write(template_begin)
        html_file.write(table)
        html_file.write(template_end)
    return None
def html_top_routes_cluster(
    clusters: dict, tree: Tree, target_smiles: str, html_path: str = None
) -> str:
    """Build a Bootstrap-styled HTML report with one route per cluster.

    For every cluster that has a non-empty ``route_ids`` list, one table row is
    emitted containing the cluster key, the number of routes in the cluster,
    the cluster SB-CGR depiction (when ``sb_cgr`` is present) and the depiction
    of the first route in ``route_ids`` (labelled "Best Route" in the report).
    Both depictions are embedded as base64-encoded SVG ``<img>`` tags.

    Args:
        clusters: Mapping of cluster key to a dict that may contain
            ``route_ids`` (list of route IDs) and ``sb_cgr`` (an object
            accepted by ``cgr_display``). Clusters without route IDs are
            counted in the summary but get no table row.
        tree: The tree passed to ``get_route_svg`` to draw each route.
        target_smiles: Target molecule SMILES, printed verbatim in the header.
        html_path: Optional output file path. When given, the report is
            written there (UTF-8).

    Returns:
        The HTML document as a string, or ``"Written to <html_path>"`` when
        ``html_path`` is provided.
    """
    # Summary numbers shown above the table.
    total_routes = sum(len(data.get("route_ids", [])) for data in clusters.values())
    total_clusters = len(clusters)
    created_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    # Document head.
    parts = [
        "<!doctype html><html lang='en'><head>",
        "<meta charset='utf-8'><meta name='viewport' content='width=device-width, initial-scale=1'>",
        "<link href='https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/css/bootstrap.min.css' rel='stylesheet'>",
        "<title>Clustering Results Report</title>",
        "<style> svg{max-width:100%;height:auto;} .report-table th,.report-table td{vertical-align:top;border:1px solid #dee2e6;} </style>",
        "</head><body><div class='container my-4'>",
    ]

    # Report header with creation timestamp.
    parts.append(
        f"""
<div class="d-flex justify-content-between align-items-center mb-3">
<h1 class="mb-0">Best route from each cluster</h1>
<div class="text-end" style="min-width:180px;">
<p class="mb-1" style="font-size: 1rem;">Report created time:</p>
<p class="mb-0" style="font-size: 1rem;">{created_time}</p>
</div>
</div>
"""
    )
    parts.append(f"<p><strong>Target molecule (SMILES):</strong> {target_smiles}</p>")
    parts.append(f"<p><strong>Total number of routes:</strong> {total_routes}</p>")
    parts.append(f"<p><strong>Total number of clusters:</strong> {total_clusters}</p>")

    # Table header: exactly one <colgroup> holding one <col> per header cell.
    # (The previous markup opened a second, never-closed <colgroup>.)
    parts.append(
        "<table class='table report-table'><colgroup><col style='width:5%'><col style='width:5%'><col style='width:15%'><col style='width:75%'></colgroup><thead><tr>"
    )
    parts.append("<th>Cluster index</th><th>Size</th><th>SB-CGR</th><th>Best Route</th>")
    parts.append("</tr></thead><tbody>")

    # One row per non-empty cluster.
    for cluster_num, group_data in clusters.items():
        route_ids = group_data.get("route_ids", [])
        if not route_ids:
            continue
        route_id = route_ids[0]

        route_svg = get_route_svg(tree, route_id)
        sb_cgr = group_data.get("sb_cgr")
        sb_cgr_svg = None
        if sb_cgr:
            # clean2d() is called on the stored object itself before depiction.
            sb_cgr.clean2d()
            sb_cgr_svg = cgr_display(sb_cgr)

        parts.append(f"<tr><td>{cluster_num}</td>")
        parts.append(f"<td>{len(route_ids)}</td>")

        # SB-CGR cell (left empty when no depiction is available).
        parts.append("<td>")
        if sb_cgr_svg:
            b64_cgr = base64.b64encode(sb_cgr_svg.encode("utf-8")).decode()
            parts.append(
                f"<img src='data:image/svg+xml;base64,{b64_cgr}' alt='SB-CGR' class='img-fluid'/>"
            )
        parts.append("</td>")

        # Best Route cell (left empty when no depiction is available).
        parts.append("<td>")
        if route_svg:
            b64_route = base64.b64encode(route_svg.encode("utf-8")).decode()
            parts.append(
                f"<img src='data:image/svg+xml;base64,{b64_route}' alt='Route {route_id}' class='img-fluid'/>"
            )
        parts.append("</td></tr>")

    # Close table and document.
    parts.append("</tbody></table>")
    parts.append("</div></body></html>")
    report_html = "".join(parts)

    if html_path:
        with open(html_path, "w", encoding="utf-8") as f:
            f.write(report_html)
        return f"Written to {html_path}"
    return report_html
def routes_clustering_report(
    source: Union[Tree, dict],
    clusters: dict,
    group_index: str,
    sb_cgrs_dict: dict,
    aam: bool = False,
    html_path: str = None,
) -> str:
    """
    Generate an HTML report visualizing one cluster of retrosynthetic routes.

    The report contains the target molecule, the cluster index and size, the
    SB-CGR of the first valid route of the cluster, a colour legend and, for
    every valid route, its depiction and the list of its reaction steps.

    Args:
        source (Union[Tree, dict]): Either a tree-like object (anything that
            exposes ``nodes`` and ``route_to_node``) or a routes JSON dict.
        clusters (dict): Clustering results; ``clusters[group_index]`` is
            expected to be a dict with a ``route_ids`` list.
        group_index (str): Key of the cluster to report on.
        sb_cgrs_dict (dict): Mapping of route ID to SB-CGR object; the entry
            of the first valid route is depicted.
        aam (bool, optional): Passed to ``MoleculeContainer.depict_settings``
            to toggle atom-atom mapping in depictions. Defaults to False.
        html_path (str, optional): If given, the report is written to this
            path (UTF-8) instead of being returned. Defaults to None.

    Returns:
        str: The HTML report; ``"Written to <html_path>"`` when ``html_path``
        is given; a short HTML error page when ``source`` or ``clusters`` has
        the wrong type or ``group_index`` is missing; or a minimal page when
        the cluster has no valid routes.
    """
    # Depiction settings are best effort: a failure here must not prevent the
    # report from being generated, so it is deliberately ignored.
    try:
        MoleculeContainer.depict_settings(aam=bool(aam))
    except Exception:
        pass

    # --- Figure out what `source` is ---
    using_tree = False
    if hasattr(source, "nodes") and hasattr(source, "route_to_node"):
        tree = source
        using_tree = True
    elif isinstance(source, dict):
        routes_json = source
        tree = None
    else:
        return "<html><body>Error: first argument must be a Tree or a routes_json dict.</body></html>"

    # --- Validate clusters ---
    if not isinstance(clusters, dict):
        return "<html><body>Error: clusters must be a dict.</body></html>"
    group = clusters.get(group_index)
    if group is None:
        return f"<html><body>Error: no group with index {group_index!r}.</body></html>"
    cluster_route_ids = group.get("route_ids", [])

    # --- Keep only routes that can actually be rendered ---
    if using_tree:
        valid_routes = [
            nid
            for nid in cluster_route_ids
            if nid in tree.nodes and tree.nodes[nid].is_solved()
        ]
    else:
        # JSON mode: a route is valid when make_dict() produced an entry for it.
        routes_dict = make_dict(routes_json)
        valid_routes = [nid for nid in cluster_route_ids if nid in routes_dict]

    if not valid_routes:
        return f"""
<!doctype html><html><body>
<h3>Cluster {group_index} Report</h3>
<p>No valid routes found in this cluster.</p>
</body></html>
"""

    # --- Target molecule ---
    if using_tree:
        try:
            target_smiles = str(tree.nodes[1].curr_precursor)
        except Exception:
            target_smiles = "N/A"
    else:
        # JSON mode: take the root smiles of the first route. Fall back to
        # "N/A" (as in tree mode) when the entry is missing or malformed; the
        # previous handler simply retried the failing lookup.
        try:
            target_smiles = routes_json[valid_routes[0]]["smiles"]
        except (KeyError, IndexError, TypeError):
            target_smiles = "N/A"

    # --- HTML Templates & Tags ---
    td = '<td style="text-align: left; border: 1px solid black; border-spacing: 0">'
    font_normal = "<font style='font-weight: normal; font-size: 18px'>"
    font_close = "</font>"
    template_begin = f"""
<!doctype html>
<html lang="en">
<head>
<link href="https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/css/bootstrap.min.css"
rel="stylesheet"
integrity="sha384-1BmE4kWBq78iYhFldvKuhfTAU6auU8tT94WrHftjDbrCEXSU1oBoqyl2QvZ6jIW3"
crossorigin="anonymous">
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1">
<title>Cluster {group_index} Routes Report</title>
<style>
/* Optional: Add some basic styling */
.table {{ border-collapse: collapse; width: 100%; }}
th, td {{ border: 1px solid #ddd; padding: 8px; text-align: left; }}
tr:nth-child(even) {{ background-color: #ffffff; }}
caption {{ caption-side: top; font-size: 1.5em; margin: 1em 0; }}
svg {{ max-width: 100%; height: auto; }}
</style>
</head>
<body>
<div class="container"> """
    template_end = """
</div> <script
src="https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/js/bootstrap.bundle.min.js"
integrity="sha384-ka7Sk0Gln4gmtz2MlQnikT1wXgYsOg+OMhuP+IlRH9sENBO0LRn5q+8nbTov4+1p"
crossorigin="anonymous">
</script>
</body>
</html>
"""
    # Legend marker; "rgb()" is a placeholder replaced with the actual colour.
    box_mark = """
<svg width="30" height="30" viewBox="0 0 1 1" xmlns="http://www.w3.org/2000/svg" style="vertical-align: middle; margin-right: 5px;">
<circle cx="0.5" cy="0.5" r="0.5" fill="rgb()" fill-opacity="0.35" />
</svg>
"""

    # --- Build HTML Table ---
    rows = [
        f"""
<table class="table table-hover caption-top">
<caption><h3>Retrosynthetic Routes Report - Cluster {group_index}</h3></caption>
<tbody>""",
        f"<tr>{td}{font_normal}Target Molecule: {target_smiles}{font_close}</td></tr>",
        f"<tr>{td}{font_normal}Group index: {group_index}{font_close}</td></tr>",
        f"<tr>{td}{font_normal}Size of Cluster: {len(valid_routes)} routes{font_close} </td></tr>",
    ]

    # --- Add SB-CGR Image ---
    # valid_routes is non-empty here. Membership is tested explicitly so that a
    # falsy route id (e.g. 0) is still handled and a missing id is reported as
    # "not found" rather than as a generic error.
    first_route_id = valid_routes[0]
    if sb_cgrs_dict and first_route_id in sb_cgrs_dict:
        try:
            sb_cgr = sb_cgrs_dict[first_route_id]
            sb_cgr.clean2d()
            sb_cgr_svg = cgr_display(sb_cgr)
            if sb_cgr_svg.strip().startswith("<svg"):
                rows.append(
                    f"<tr>{td}{font_normal}Identified Strategic Bonds{font_close}<br>{sb_cgr_svg}</td></tr>"
                )
            else:
                rows.append(
                    f"<tr>{td}{font_normal}Cluster Representative SB-CGR (from Route {first_route_id}):{font_close}<br><i>Invalid SVG format retrieved.</i></td></tr>"
                )
                print(
                    f"Warning: Expected SVG for SB-CGR of route {first_route_id}, but got: {sb_cgr_svg[:100]}..."
                )
        except Exception as e:
            # Depiction problems are reported inside the report, not raised.
            rows.append(
                f"<tr>{td}{font_normal}Cluster Representative SB-CGR (from Route {first_route_id}):{font_close}<br><i>Error retrieving/displaying SB-CGR: {e}</i></td></tr>"
            )
    else:
        rows.append(
            f"<tr>{td}{font_normal}Cluster Representative SB-CGR (from Route {first_route_id}):{font_close}<br><i>Not found in provided SB-CGR dictionary.</i></td></tr>"
        )

    # --- Colour legend ---
    rows.append(
        f"""
<tr>{td}
<div style="display: flex; align-items: center; flex-wrap: wrap; gap: 15px;">
<span>{box_mark.replace("rgb()", "rgb(152, 238, 255)")} Target Molecule</span>
<span>{box_mark.replace("rgb()", "rgb(240, 171, 144)")} Molecule Not In Stock</span>
<span>{box_mark.replace("rgb()", "rgb(155, 250, 179)")} Molecule In Stock</span>
</div>
</td></tr>
"""
    )

    # --- One header / depiction / step-list triple per route ---
    for route_id in valid_routes:
        if using_tree:
            svg = get_route_svg(tree, route_id)
            steps = tree.synthesis_route(route_id)
            score = round(tree.route_score(route_id), 3)
            reac_html = "".join(
                f"<b>Step {i+1}:</b> {str(r)}<br>" for i, r in enumerate(steps)
            )
            header = f"Route {route_id} — {len(steps)} steps, score={score}"
        else:
            svg = get_route_svg_from_json(routes_json, route_id)
            steps = routes_dict[route_id]
            # NOTE(review): the dict keys are used as step numbers (key + 1),
            # so they are assumed to be integers — confirm against make_dict.
            reac_html = "".join(
                f"<b>Step {i+1}:</b> {str(r)}<br>" for i, r in steps.items()
            )
            header = f"Route {route_id} — {len(steps)} steps"
        rows.append(f"<tr><td><b>{header}</b></td></tr>")
        rows.append(f"<tr><td>{svg}</td></tr>")
        rows.append(f"<tr><td>{reac_html}</td></tr>")
    rows.append("</tbody></table>")

    html = template_begin + "".join(rows) + template_end
    if html_path:
        with open(html_path, "w", encoding="utf-8") as f:
            f.write(html)
        return f"Written to {html_path}"
    return html
def lg_table_2_html(subcluster, routes_to_display=None, if_display=True):
    """
    Generate an HTML table of leaving-group (X) 'marks' for routes in a subcluster.

    Each row is one route and each column one mark. Columns are the sorted
    union of the marks found over *all* routes in ``subcluster["routes_data"]``
    (even when only a subset of routes is displayed), so the layout is stable.
    A cell holds the result of ``.depict()`` called on the first element of the
    route's entry for that mark, or is empty when the route lacks the mark.

    Args:
        subcluster (dict): Must contain ``routes_data``, a mapping of route ID
            to a dict ``{mark: sequence}`` whose first element has ``.depict()``.
        routes_to_display (list, optional): Route IDs to include, in the given
            order. IDs absent from ``routes_data`` produce a red "not found"
            row. ``None`` or an empty list means all routes. Defaults to None.
        if_display (bool, optional): If True, the table is also shown with
            ``display(HTML(...))``. Defaults to True.

    Returns:
        str: The generated HTML string for the table.
    """
    cell_style = "border: 1px solid black; padding: 4px;"
    routes_data = subcluster["routes_data"]

    # Union of marks across every route -> consistent, sorted columns.
    all_marks = set()
    for route_data in routes_data.values():
        all_marks.update(route_data.keys())
    all_marks = sorted(all_marks)

    # Header row.
    parts = [
        f"<table style='border-collapse: collapse;'><tr><th style='{cell_style}'>Route ID</th>"
    ]
    parts.extend(f"<th style='{cell_style}'>{mark}</th>" for mark in all_marks)
    parts.append("</tr>")

    # No explicit selection (None or empty) means every route of the subcluster.
    route_ids = routes_to_display if routes_to_display else routes_data

    for route_id in route_ids:
        if route_id not in routes_data:
            # Requested route is unknown: emit a full-width notice row.
            parts.append(
                f"<tr><td colspan='{len(all_marks)+1}' style='{cell_style} color:red;'>Route ID {route_id} not found.</td></tr>"
            )
            continue
        route_data = routes_data[route_id]
        parts.append(f"<tr><td style='{cell_style}'>{route_id}</td>")
        for mark in all_marks:
            parts.append(f"<td style='{cell_style}'>")
            if mark in route_data:
                parts.append(route_data[mark][0].depict())  # SVG data as string
            parts.append("</td>")
        parts.append("</tr>")
    parts.append("</table>")

    html = "".join(parts)
    if if_display:
        display(HTML(html))
    return html
def group_lg_table_2_html_fixed(
    grouped: dict,
    groups_to_display=None,
    if_display=False,
    max_group_col_width: int = 200,
) -> str:
    """
    Render grouped leaving-group (X) 'marks' as an HTML table, one row per group.

    The first column lists the members of the group key joined by commas; the
    remaining columns are the sorted union of mark names found in *all* values
    of ``grouped``. A cell contains ``value.depict()`` when the value offers a
    ``depict`` attribute, otherwise ``str(value)``; it is empty when the group
    representative has no such mark.

    Args:
        grouped (dict): Maps a group key (an iterable, e.g. a tuple of route
            IDs) to a dict ``{mark: value}`` for the group's representative.
        groups_to_display (list, optional): Group keys to include, in the given
            order; keys missing from ``grouped`` are silently skipped. ``None``
            includes every group. Defaults to None.
        if_display (bool, optional): If True, the table is also shown with
            ``display(HTML(...))``. Defaults to False.
        max_group_col_width (int, optional): ``max-width`` in pixels applied to
            the first (group label) column. Defaults to 200.

    Returns:
        str: The generated HTML string for the table.
    """
    # Which groups become rows.
    if groups_to_display is None:
        shown_groups = list(grouped)
    else:
        shown_groups = [key for key in groups_to_display if key in grouped]

    # Column set: every mark that occurs anywhere in `grouped`.
    mark_names = set()
    for marks in grouped.values():
        mark_names.update(marks.keys())
    mark_names = sorted(mark_names)

    # Opening tags and header row.
    chunks = [
        "<table style='width:100%; table-layout:auto; border-collapse: collapse;'>",
        "<thead><tr>",
        "<th style='border:1px solid #ccc; padding:4px;'>Route IDs</th>",
    ]
    for name in mark_names:
        chunks.append(
            f"<th style='border:1px solid #ccc; padding:4px; text-align:center;'>X<small>{name}</small></th>"
        )
    chunks.append("</tr></thead><tbody>")

    # Cell styles shared by every body row.
    label_style = (
        "border:1px solid #ccc; padding:4px; "
        "white-space: normal; overflow-wrap: break-word; "
        f"max-width:{max_group_col_width}px;"
    )
    mark_style = (
        "border:1px solid #ccc; padding:4px; text-align:center; vertical-align:middle;"
    )

    # Body rows.
    for key in shown_groups:
        marks = grouped[key]
        label = ",".join(map(str, key))
        cells = [f"<td style='{label_style}'>{label}</td>"]
        for name in mark_names:
            content = ""
            if name in marks:
                value = marks[name]
                if hasattr(value, "depict"):
                    content = value.depict()
                else:
                    content = str(value)
            # Plain concatenation keeps the requirement that depict() yields str.
            cells.append("<td style='" + mark_style + "'>" + content + "</td>")
        chunks.append("<tr>" + "".join(cells) + "</tr>")
    chunks.append("</tbody></table>")

    out = "".join(chunks)
    if if_display:
        display(HTML(out))
    return out
def routes_subclustering_report(
    source: Union[Tree, dict],
    subcluster: dict,
    group_index: str,
    cluster_num: int,
    sb_cgrs_dict: dict,
    if_lg_group: bool = False,
    aam: bool = False,
    html_path: str = None,
) -> str:
    """
    Generate an HTML report visualizing one subcluster of retrosynthetic routes.

    The report contains the target molecule, the cluster/subcluster identifiers
    and size, the SB-CGR of the first valid route, the synthon pseudo reaction,
    a leaving-groups table (per route or grouped), a colour legend and, for
    every valid route, its depiction and the list of its reaction steps.

    Args:
        source (Union[Tree, dict]): Either a tree-like object (anything that
            exposes ``nodes`` and ``route_to_node``) or a routes JSON dict.
        subcluster (dict): Subcluster data. ``routes_data`` (route ID -> mark
            data) is required; ``synthon_reaction`` and, when ``if_lg_group``
            is True, ``group_lgs`` are read as well.
        group_index (str): Index of the parent cluster, used in titles.
        cluster_num (int): Identifier of the subcluster, used in titles.
        sb_cgrs_dict (dict): Mapping of route ID to SB-CGR object; the entry
            of the first valid route is depicted.
        if_lg_group (bool, optional): If True, the leaving-groups table is
            built from ``subcluster['group_lgs']``; otherwise from
            ``subcluster['routes_data']``. Defaults to False.
        aam (bool, optional): Passed to ``MoleculeContainer.depict_settings``
            to toggle atom-atom mapping in depictions. Defaults to False.
        html_path (str, optional): If given, the report is written to this
            path (UTF-8) instead of being returned. Defaults to None.

    Returns:
        str: The HTML report; ``"Written to <html_path>"`` when ``html_path``
        is given; a short HTML error page when ``source`` or ``subcluster``
        has the wrong type; or a minimal page when the subcluster has no
        valid routes.

    Raises:
        KeyError: If ``subcluster`` has no ``routes_data`` key.
    """
    # Depiction settings are best effort: a failure here must not prevent the
    # report from being generated, so it is deliberately ignored.
    try:
        MoleculeContainer.depict_settings(aam=bool(aam))
    except Exception:
        pass

    # --- Figure out what `source` is ---
    using_tree = False
    if hasattr(source, "nodes") and hasattr(source, "route_to_node"):
        tree = source
        using_tree = True
    elif isinstance(source, dict):
        routes_json = source
        tree = None
    else:
        return "<html><body>Error: first argument must be a Tree or a routes_json dict.</body></html>"

    # --- Validate subcluster ---
    if not isinstance(subcluster, dict):
        return "<html><body>Error: subcluster must be a dict.</body></html>"
    subcluster_route_ids = list(subcluster["routes_data"].keys())

    # --- Keep only routes that can actually be rendered ---
    if using_tree:
        valid_routes = [
            nid
            for nid in subcluster_route_ids
            if nid in tree.nodes and tree.nodes[nid].is_solved()
        ]
    else:
        # JSON mode: just keep those IDs present in the JSON.
        valid_routes = [nid for nid in subcluster_route_ids if nid in routes_json]

    if not valid_routes:
        # Return a minimal HTML page indicating no valid routes.
        return f"""
<!doctype html><html lang="en"><head><meta charset="utf-8">
<title>Cluster {group_index}.{cluster_num} Report</title></head><body>
<h3>Cluster {group_index}.{cluster_num} Report</h3>
<p>No valid/solved routes found for this cluster.</p>
</body></html>"""

    if not using_tree:
        # Only needed for the per-route step lists, so it is built after the
        # early return above.
        routes_dict = make_dict(routes_json)

    # --- Target molecule ---
    if using_tree:
        try:
            target_smiles = str(tree.nodes[1].curr_precursor)
        except Exception:
            target_smiles = "N/A"
    else:
        # JSON mode: take the root smiles of the first route, falling back to
        # "N/A" (as in tree mode) when the entry has no usable "smiles" field.
        try:
            target_smiles = routes_json[valid_routes[0]]["smiles"]
        except (KeyError, IndexError, TypeError):
            target_smiles = "N/A"

    # --- HTML Templates & Tags ---
    td = '<td style="text-align: left; border: 1px solid black; border-spacing: 0">'
    font_normal = "<font style='font-weight: normal; font-size: 18px'>"
    font_close = "</font>"
    template_begin = f"""
<!doctype html>
<html lang="en">
<head>
<link href="https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/css/bootstrap.min.css"
rel="stylesheet"
integrity="sha384-1BmE4kWBq78iYhFldvKuhfTAU6auU8tT94WrHftjDbrCEXSU1oBoqyl2QvZ6jIW3"
crossorigin="anonymous">
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1">
<title>SubCluster {group_index}.{cluster_num} Routes Report</title>
<style>
/* Optional: Add some basic styling */
.table {{ border-collapse: collapse; width: 100%; }}
th, td {{ border: 1px solid #ddd; padding: 8px; text-align: left; }}
tr:nth-child(even) {{ background-color: #ffffff; }}
caption {{ caption-side: top; font-size: 1.5em; margin: 1em 0; }}
svg {{ max-width: 100%; height: auto; }}
</style>
</head>
<body>
<div class="container"> """
    template_end = """
</div> <script
src="https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/js/bootstrap.bundle.min.js"
integrity="sha384-ka7Sk0Gln4gmtz2MlQnikT1wXgYsOg+OMhuP+IlRH9sENBO0LRn5q+8nbTov4+1p"
crossorigin="anonymous">
</script>
</body>
</html>
"""
    # Legend marker; "rgb()" is a placeholder replaced with the actual colour.
    box_mark = """
<svg width="30" height="30" viewBox="0 0 1 1" xmlns="http://www.w3.org/2000/svg" style="vertical-align: middle; margin-right: 5px;">
<circle cx="0.5" cy="0.5" r="0.5" fill="rgb()" fill-opacity="0.35" />
</svg>
"""

    # --- Build HTML Table ---
    rows = [
        f"""
<table class="table table-hover caption-top">
<caption><h3>Retrosynthetic Routes Report - Cluster {group_index}.{cluster_num}</h3></caption>
<tbody>""",
        f"<tr>{td}{font_normal}Target Molecule: {target_smiles}{font_close}</td></tr>",
        f"<tr>{td}{font_normal}Group index: {group_index}{font_close}</td></tr>",
        f"<tr>{td}{font_normal}Cluster Number: {cluster_num}{font_close}</td></tr>",
        f"<tr>{td}{font_normal}Size of Cluster: {len(valid_routes)} routes{font_close} </td></tr>",
    ]

    # --- Add SB-CGR Image ---
    # valid_routes is non-empty here. Membership is tested explicitly so that a
    # falsy route id (e.g. 0) is still handled and a missing id is reported as
    # "not found" rather than as a generic error.
    first_route_id = valid_routes[0]
    if sb_cgrs_dict and first_route_id in sb_cgrs_dict:
        try:
            sb_cgr = sb_cgrs_dict[first_route_id]
            sb_cgr.clean2d()
            sb_cgr_svg = cgr_display(sb_cgr)
            if sb_cgr_svg.strip().startswith("<svg"):
                rows.append(
                    f"<tr>{td}{font_normal}Identified Strategic Bonds{font_close}<br>{sb_cgr_svg}</td></tr>"
                )
            else:
                rows.append(
                    f"<tr>{td}{font_normal}Cluster Representative SB-CGR (from Route {first_route_id}):{font_close}<br><i>Invalid SVG format retrieved.</i></td></tr>"
                )
                print(
                    f"Warning: Expected SVG for SB-CGR of route {first_route_id}, but got: {sb_cgr_svg[:100]}..."
                )
        except Exception as e:
            # Depiction problems are reported inside the report, not raised.
            rows.append(
                f"<tr>{td}{font_normal}Cluster Representative SB-CGR (from Route {first_route_id}):{font_close}<br><i>Error retrieving/displaying SB-CGR: {e}</i></td></tr>"
            )
    else:
        rows.append(
            f"<tr>{td}{font_normal}Cluster Representative SB-CGR (from Route {first_route_id}):{font_close}<br><i>Not found in provided SB-CGR dictionary.</i></td></tr>"
        )

    # --- Synthon pseudo reaction (errors are shown inline) ---
    try:
        synthon_reaction = subcluster["synthon_reaction"]
        synthon_reaction.clean2d()
        synthon_svg = depict_custom_reaction(synthon_reaction)
        rows.append(
            f"<tr>{td}{font_normal}Synthon pseudo reaction:{font_close}<br>{synthon_svg}</td></tr>"
        )
    except Exception as e:
        rows.append(
            f"<tr><td colspan='1' style='color: red;'>Error displaying synthon reaction: {e}</td></tr>"
        )

    # --- Leaving groups table (errors are shown inline) ---
    try:
        if if_lg_group:
            grouped_lgs = subcluster["group_lgs"]
            lg_table_html = group_lg_table_2_html_fixed(grouped_lgs, if_display=False)
        else:
            lg_table_html = lg_table_2_html(subcluster, if_display=False)
        rows.append(
            f"<tr>{td}{font_normal}Leaving Groups table:{font_close}<br>{lg_table_html}</td></tr>"
        )
    except Exception as e:
        rows.append(
            f"<tr><td colspan='1' style='color: red;'>Error displaying leaving groups: {e}</td></tr>"
        )

    # --- Colour legend ---
    rows.append(
        f"""
<tr>{td}
<div style="display: flex; align-items: center; flex-wrap: wrap; gap: 15px;">
<span>{box_mark.replace("rgb()", "rgb(152, 238, 255)")} Target Molecule</span>
<span>{box_mark.replace("rgb()", "rgb(240, 171, 144)")} Molecule Not In Stock</span>
<span>{box_mark.replace("rgb()", "rgb(155, 250, 179)")} Molecule In Stock</span>
</div>
</td></tr>
"""
    )

    # --- One header / depiction / step-list triple per route ---
    for route_id in valid_routes:
        if using_tree:
            svg = get_route_svg(tree, route_id)
            steps = tree.synthesis_route(route_id)
            score = round(tree.route_score(route_id), 3)
            reac_html = "".join(
                f"<b>Step {i+1}:</b> {str(r)}<br>" for i, r in enumerate(steps)
            )
            header = f"Route {route_id} — {len(steps)} steps, score={score}"
        else:
            svg = get_route_svg_from_json(routes_json, route_id)
            steps = routes_dict[route_id]
            # NOTE(review): the dict keys are used as step numbers (key + 1),
            # so they are assumed to be integers — confirm against make_dict.
            reac_html = "".join(
                f"<b>Step {i+1}:</b> {str(r)}<br>" for i, r in steps.items()
            )
            header = f"Route {route_id} — {len(steps)} steps"
        rows.append(f"<tr><td><b>{header}</b></td></tr>")
        rows.append(f"<tr><td>{svg}</td></tr>")
        rows.append(f"<tr><td>{reac_html}</td></tr>")
    rows.append("</tbody></table>")

    html = template_begin + "".join(rows) + template_end
    if html_path:
        with open(html_path, "w", encoding="utf-8") as f:
            f.write(html)
        return f"Written to {html_path}"
    return html