ipd commited on
Commit
197c331
1 Parent(s): 34a46ac
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .DS_Store +0 -0
  2. .idea/.gitignore +3 -0
  3. .idea/inspectionProfiles/Project_Default.xml +29 -0
  4. .idea/inspectionProfiles/profiles_settings.xml +6 -0
  5. .idea/materials.mhg-ged.iml +12 -0
  6. .idea/modules.xml +8 -0
  7. .idea/vcs.xml +6 -0
  8. README.md +78 -3
  9. __init__.py +5 -0
  10. __pycache__/__init__.cpython-310.pyc +0 -0
  11. __pycache__/load.cpython-310.pyc +0 -0
  12. graph_grammar/.DS_Store +0 -0
  13. graph_grammar/__init__.py +19 -0
  14. graph_grammar/__pycache__/__init__.cpython-310.pyc +0 -0
  15. graph_grammar/__pycache__/hypergraph.cpython-310.pyc +0 -0
  16. graph_grammar/algo/__init__.py +20 -0
  17. graph_grammar/algo/__pycache__/__init__.cpython-310.pyc +0 -0
  18. graph_grammar/algo/__pycache__/tree_decomposition.cpython-310.pyc +0 -0
  19. graph_grammar/algo/tree_decomposition.py +821 -0
  20. graph_grammar/graph_grammar/__init__.py +20 -0
  21. graph_grammar/graph_grammar/__pycache__/__init__.cpython-310.pyc +0 -0
  22. graph_grammar/graph_grammar/__pycache__/base.cpython-310.pyc +0 -0
  23. graph_grammar/graph_grammar/__pycache__/corpus.cpython-310.pyc +0 -0
  24. graph_grammar/graph_grammar/__pycache__/hrg.cpython-310.pyc +0 -0
  25. graph_grammar/graph_grammar/__pycache__/symbols.cpython-310.pyc +0 -0
  26. graph_grammar/graph_grammar/__pycache__/utils.cpython-310.pyc +0 -0
  27. graph_grammar/graph_grammar/base.py +30 -0
  28. graph_grammar/graph_grammar/corpus.py +152 -0
  29. graph_grammar/graph_grammar/hrg.py +1065 -0
  30. graph_grammar/graph_grammar/symbols.py +180 -0
  31. graph_grammar/graph_grammar/utils.py +130 -0
  32. graph_grammar/hypergraph.py +544 -0
  33. graph_grammar/io/__init__.py +20 -0
  34. graph_grammar/io/__pycache__/__init__.cpython-310.pyc +0 -0
  35. graph_grammar/io/__pycache__/smi.cpython-310.pyc +0 -0
  36. graph_grammar/io/smi.py +559 -0
  37. graph_grammar/nn/__init__.py +11 -0
  38. graph_grammar/nn/__pycache__/__init__.cpython-310.pyc +0 -0
  39. graph_grammar/nn/__pycache__/decoder.cpython-310.pyc +0 -0
  40. graph_grammar/nn/__pycache__/encoder.cpython-310.pyc +0 -0
  41. graph_grammar/nn/dataset.py +121 -0
  42. graph_grammar/nn/decoder.py +158 -0
  43. graph_grammar/nn/encoder.py +199 -0
  44. graph_grammar/nn/graph.py +313 -0
  45. images/mhg_example.png +0 -0
  46. images/mhg_example1.png +0 -0
  47. images/mhg_example2.png +0 -0
  48. load.py +83 -0
  49. mhg_gnn.egg-info/PKG-INFO +102 -0
  50. mhg_gnn.egg-info/SOURCES.txt +46 -0
.DS_Store ADDED
Binary file (10.2 kB). View file
 
.idea/.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ # Default ignored files
2
+ /shelf/
3
+ /workspace.xml
.idea/inspectionProfiles/Project_Default.xml ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <component name="InspectionProjectProfileManager">
2
+ <profile version="1.0">
3
+ <option name="myName" value="Project Default" />
4
+ <inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
5
+ <option name="ignoredPackages">
6
+ <value>
7
+ <list size="16">
8
+ <item index="0" class="java.lang.String" itemvalue="accelerate" />
9
+ <item index="1" class="java.lang.String" itemvalue="matplotlib" />
10
+ <item index="2" class="java.lang.String" itemvalue="torch-geometric" />
11
+ <item index="3" class="java.lang.String" itemvalue="torchinfo" />
12
+ <item index="4" class="java.lang.String" itemvalue="caikit" />
13
+ <item index="5" class="java.lang.String" itemvalue="pytorch-fast-transformers" />
14
+ <item index="6" class="java.lang.String" itemvalue="e3nn" />
15
+ <item index="7" class="java.lang.String" itemvalue="rdkit" />
16
+ <item index="8" class="java.lang.String" itemvalue="PyImpetus" />
17
+ <item index="9" class="java.lang.String" itemvalue="torch-scatter" />
18
+ <item index="10" class="java.lang.String" itemvalue="torch-nl" />
19
+ <item index="11" class="java.lang.String" itemvalue="torch-sparse" />
20
+ <item index="12" class="java.lang.String" itemvalue="mordred" />
21
+ <item index="13" class="java.lang.String" itemvalue="xgboost" />
22
+ <item index="14" class="java.lang.String" itemvalue="mamba-ssm" />
23
+ <item index="15" class="java.lang.String" itemvalue="evaluate" />
24
+ </list>
25
+ </value>
26
+ </option>
27
+ </inspection_tool>
28
+ </profile>
29
+ </component>
.idea/inspectionProfiles/profiles_settings.xml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ <component name="InspectionProjectProfileManager">
2
+ <settings>
3
+ <option name="USE_PROJECT_PROFILE" value="false" />
4
+ <version value="1.0" />
5
+ </settings>
6
+ </component>
.idea/materials.mhg-ged.iml ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <module type="PYTHON_MODULE" version="4">
3
+ <component name="NewModuleRootManager">
4
+ <content url="file://$MODULE_DIR$" />
5
+ <orderEntry type="inheritedJdk" />
6
+ <orderEntry type="sourceFolder" forTests="false" />
7
+ </component>
8
+ <component name="PyDocumentationSettings">
9
+ <option name="format" value="NUMPY" />
10
+ <option name="myDocStringFormat" value="NumPy" />
11
+ </component>
12
+ </module>
.idea/modules.xml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="ProjectModuleManager">
4
+ <modules>
5
+ <module fileurl="file://$PROJECT_DIR$/.idea/materials.mhg-ged.iml" filepath="$PROJECT_DIR$/.idea/materials.mhg-ged.iml" />
6
+ </modules>
7
+ </component>
8
+ </project>
.idea/vcs.xml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="VcsDirectoryMappings">
4
+ <mapping directory="" vcs="Git" />
5
+ </component>
6
+ </project>
README.md CHANGED
@@ -1,3 +1,78 @@
1
- ---
2
- license: apache-2.0
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ ---
4
+ # mhg-gnn
5
+
6
+ This repository provides PyTorch source code assosiated with our publication, "MHG-GNN: Combination of Molecular Hypergraph Grammar with Graph Neural Network"
7
+
8
+ **Paper:** [Arxiv Link](https://arxiv.org/pdf/2309.16374)
9
+
10
+ ![mhg-gnn](images/mhg_example1.png)
11
+
12
+ ## Introduction
13
+
14
+ We present MHG-GNN, an autoencoder architecture
15
+ that has an encoder based on GNN and a decoder based on a sequential model with MHG.
16
+ Since the encoder is a GNN variant, MHG-GNN can accept any molecule as input, and
17
+ demonstrate high predictive performance on molecular graph data.
18
+ In addition, the decoder inherits the theoretical guarantee of MHG on always generating a structurally valid molecule as output.
19
+
20
+ ## Table of Contents
21
+
22
+ 1. [Getting Started](#getting-started)
23
+ 1. [Pretrained Models and Training Logs](#pretrained-models-and-training-logs)
24
+ 2. [Installation](#installation)
25
+ 2. [Feature Extraction](#feature-extraction)
26
+
27
+ ## Getting Started
28
+
29
+ **This code and environment have been tested on Intel E5-2667 CPUs at 3.30GHz and NVIDIA A100 Tensor Core GPUs.**
30
+
31
+ ### Pretrained Models and Training Logs
32
+
33
+ We provide checkpoints of the MHG-GNN model pre-trained on a dataset of ~1.34M molecules curated from PubChem. (later) For model weights: [HuggingFace Link]()
34
+
35
+ Add the MHG-GNN `pre-trained weights.pt` to the `models/` directory according to your needs.
36
+
37
+ ### Installation
38
+
39
+ We recommend to create a virtual environment. For example:
40
+
41
+ ```
42
+ python3 -m venv .venv
43
+ . .venv/bin/activate
44
+ ```
45
+
46
+ Type the following command once the virtual environment is activated:
47
+
48
+ ```
49
+ git clone git@github.ibm.com:CMD-TRL/mhg-gnn.git
50
+ cd ./mhg-gnn
51
+ pip install .
52
+ ```
53
+
54
+ ## Feature Extraction
55
+
56
+ The example notebook [mhg-gnn_encoder_decoder_example.ipynb](notebooks/mhg-gnn_encoder_decoder_example.ipynb) contains code to load checkpoint files and use the pre-trained model for encoder and decoder tasks.
57
+
58
+ To load mhg-gnn, you can simply use:
59
+
60
+ ```python
61
+ import torch
62
+ import load
63
+
64
+ model = load.load()
65
+ ```
66
+
67
+ To encode SMILES into embeddings, you can use:
68
+
69
+ ```python
70
+ with torch.no_grad():
71
+ repr = model.encode(["CCO", "O=C=O", "OC(=O)c1ccccc1C(=O)O"])
72
+ ```
73
+
74
+ For decoder, you can use the function, so you can return from embeddings to SMILES strings:
75
+
76
+ ```python
77
+ orig = model.decode(repr)
78
+ ```
__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # -*- coding:utf-8 -*-
2
+ # Rhizome
3
+ # Version beta 0.0, August 2023
4
+ # Property of IBM Research, Accelerated Discovery
5
+ #
__pycache__/__init__.cpython-310.pyc ADDED
Binary file (214 Bytes). View file
 
__pycache__/load.cpython-310.pyc ADDED
Binary file (3.04 kB). View file
 
graph_grammar/.DS_Store ADDED
Binary file (8.2 kB). View file
 
graph_grammar/__init__.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # Rhizome
4
+ # Version beta 0.0, August 2023
5
+ # Property of IBM Research, Accelerated Discovery
6
+ #
7
+ """
8
+ PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
9
+ OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
10
+ THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
11
+ """
12
+
13
+ """ Title """
14
+
15
+ __author__ = "Hiroshi Kajino <KAJINO@jp.ibm.com>"
16
+ __copyright__ = "(c) Copyright IBM Corp. 2018"
17
+ __version__ = "0.1"
18
+ __date__ = "Jan 1 2018"
19
+
graph_grammar/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (666 Bytes). View file
 
graph_grammar/__pycache__/hypergraph.cpython-310.pyc ADDED
Binary file (15.3 kB). View file
 
graph_grammar/algo/__init__.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding:utf-8 -*-
3
+ # Rhizome
4
+ # Version beta 0.0, August 2023
5
+ # Property of IBM Research, Accelerated Discovery
6
+ #
7
+
8
+ """
9
+ PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
10
+ OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
11
+ THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
12
+ """
13
+
14
+ """ Title """
15
+
16
+ __author__ = "Hiroshi Kajino <KAJINO@jp.ibm.com>"
17
+ __copyright__ = "(c) Copyright IBM Corp. 2018"
18
+ __version__ = "0.1"
19
+ __date__ = "Jan 1 2018"
20
+
graph_grammar/algo/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (659 Bytes). View file
 
graph_grammar/algo/__pycache__/tree_decomposition.cpython-310.pyc ADDED
Binary file (19.5 kB). View file
 
graph_grammar/algo/tree_decomposition.py ADDED
@@ -0,0 +1,821 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # Rhizome
4
+ # Version beta 0.0, August 2023
5
+ # Property of IBM Research, Accelerated Discovery
6
+ #
7
+
8
+ """
9
+ PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
10
+ OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
11
+ THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
12
+ """
13
+
14
+ """ Title """
15
+
16
+ __author__ = "Hiroshi Kajino <KAJINO@jp.ibm.com>"
17
+ __copyright__ = "(c) Copyright IBM Corp. 2017"
18
+ __version__ = "0.1"
19
+ __date__ = "Dec 11 2017"
20
+
21
+ from copy import deepcopy
22
+ from itertools import combinations
23
+ from ..hypergraph import Hypergraph
24
+ import networkx as nx
25
+ import numpy as np
26
+
27
+
28
+ class CliqueTree(nx.Graph):
29
+ ''' clique tree object
30
+
31
+ Attributes
32
+ ----------
33
+ hg : Hypergraph
34
+ This hypergraph will be decomposed.
35
+ root_hg : Hypergraph
36
+ Hypergraph on the root node.
37
+ ident_node_dict : dict
38
+ ident_node_dict[key_node] gives a list of nodes that are identical (i.e., the adjacent hyperedges are common)
39
+ '''
40
+ def __init__(self, hg=None, **kwargs):
41
+ self.hg = deepcopy(hg)
42
+ if self.hg is not None:
43
+ self.ident_node_dict = self.hg.get_identical_node_dict()
44
+ else:
45
+ self.ident_node_dict = {}
46
+ super().__init__(**kwargs)
47
+
48
+ @property
49
+ def root_hg(self):
50
+ ''' return the hypergraph on the root node
51
+ '''
52
+ return self.nodes[0]['subhg']
53
+
54
+ @root_hg.setter
55
+ def root_hg(self, hypergraph):
56
+ ''' set the hypergraph on the root node
57
+ '''
58
+ self.nodes[0]['subhg'] = hypergraph
59
+
60
+ def insert_subhg(self, subhypergraph: Hypergraph) -> None:
61
+ ''' insert a subhypergraph, which is extracted from a root hypergraph, into the tree.
62
+
63
+ Parameters
64
+ ----------
65
+ subhg : Hypergraph
66
+ '''
67
+ num_nodes = self.number_of_nodes()
68
+ self.add_node(num_nodes, subhg=subhypergraph)
69
+ self.add_edge(num_nodes, 0)
70
+ adj_nodes = deepcopy(list(self.adj[0].keys()))
71
+ for each_node in adj_nodes:
72
+ if len(self.nodes[each_node]["subhg"].nodes.intersection(
73
+ self.nodes[num_nodes]["subhg"].nodes)\
74
+ - self.root_hg.nodes) != 0 and each_node != num_nodes:
75
+ self.remove_edge(0, each_node)
76
+ self.add_edge(each_node, num_nodes)
77
+
78
+ def to_irredundant(self) -> None:
79
+ ''' convert the clique tree to be irredundant
80
+ '''
81
+ for each_node in self.hg.nodes:
82
+ subtree = self.subgraph([
83
+ each_tree_node for each_tree_node in self.nodes()\
84
+ if each_node in self.nodes[each_tree_node]["subhg"].nodes]).copy()
85
+ leaf_node_list = [x for x in subtree.nodes() if subtree.degree(x)==1]
86
+ redundant_leaf_node_list = []
87
+ for each_leaf_node in leaf_node_list:
88
+ if len(self.nodes[each_leaf_node]["subhg"].adj_edges(each_node)) == 0:
89
+ redundant_leaf_node_list.append(each_leaf_node)
90
+ for each_red_leaf_node in redundant_leaf_node_list:
91
+ current_node = each_red_leaf_node
92
+ while subtree.degree(current_node) == 1 \
93
+ and len(subtree.nodes[current_node]["subhg"].adj_edges(each_node)) == 0:
94
+ self.nodes[current_node]["subhg"].remove_node(each_node)
95
+ remove_node = current_node
96
+ current_node = list(dict(subtree[remove_node]).keys())[0]
97
+ subtree.remove_node(remove_node)
98
+
99
+ fixed_node_set = deepcopy(self.nodes)
100
+ for each_node in fixed_node_set:
101
+ if self.nodes[each_node]["subhg"].num_edges == 0:
102
+ if len(self[each_node]) == 1:
103
+ self.remove_node(each_node)
104
+ elif len(self[each_node]) == 2:
105
+ self.add_edge(*self[each_node])
106
+ self.remove_node(each_node)
107
+ else:
108
+ pass
109
+ else:
110
+ pass
111
+
112
+ redundant = True
113
+ while redundant:
114
+ redundant = False
115
+ fixed_edge_set = deepcopy(self.edges)
116
+ remove_node_set = set()
117
+ for node_1, node_2 in fixed_edge_set:
118
+ if node_1 in remove_node_set or node_2 in remove_node_set:
119
+ pass
120
+ else:
121
+ if self.nodes[node_1]['subhg'].is_subhg(self.nodes[node_2]['subhg']):
122
+ redundant = True
123
+ adj_node_list = set(self.adj[node_1]) - {node_2}
124
+ self.remove_node(node_1)
125
+ remove_node_set.add(node_1)
126
+ for each_node in adj_node_list:
127
+ self.add_edge(node_2, each_node)
128
+
129
+ elif self.nodes[node_2]['subhg'].is_subhg(self.nodes[node_1]['subhg']):
130
+ redundant = True
131
+ adj_node_list = set(self.adj[node_2]) - {node_1}
132
+ self.remove_node(node_2)
133
+ remove_node_set.add(node_2)
134
+ for each_node in adj_node_list:
135
+ self.add_edge(node_1, each_node)
136
+
137
+ def node_update(self, key_node: str, subhg) -> None:
138
+ """ given a pair of a hypergraph, H, and its subhypergraph, sH, return a hypergraph H\sH.
139
+
140
+ Parameters
141
+ ----------
142
+ key_node : str
143
+ key node that must be removed.
144
+ subhg : Hypegraph
145
+ """
146
+ for each_edge in subhg.edges:
147
+ self.root_hg.remove_edge(each_edge)
148
+ self.root_hg.remove_nodes(self.ident_node_dict[key_node])
149
+
150
+ adj_node_list = list(subhg.nodes)
151
+ for each_node in subhg.nodes:
152
+ if each_node not in self.ident_node_dict[key_node]:
153
+ if set(self.root_hg.adj_edges(each_node)).issubset(subhg.edges):
154
+ self.root_hg.remove_node(each_node)
155
+ adj_node_list.remove(each_node)
156
+ else:
157
+ adj_node_list.remove(each_node)
158
+
159
+ for each_node_1, each_node_2 in combinations(adj_node_list, 2):
160
+ if not self.root_hg.is_adj(each_node_1, each_node_2):
161
+ self.root_hg.add_edge(set([each_node_1, each_node_2]), attr_dict=dict(tmp=True))
162
+
163
+ subhg.remove_edges_with_attr({'tmp' : True})
164
+ self.insert_subhg(subhg)
165
+
166
+ def update(self, subhg, remove_nodes=False):
167
+ """ given a pair of a hypergraph, H, and its subhypergraph, sH, return a hypergraph H\sH.
168
+
169
+ Parameters
170
+ ----------
171
+ subhg : Hypegraph
172
+ """
173
+ for each_edge in subhg.edges:
174
+ self.root_hg.remove_edge(each_edge)
175
+ if remove_nodes:
176
+ remove_edge_list = []
177
+ for each_edge in self.root_hg.edges:
178
+ if set(self.root_hg.nodes_in_edge(each_edge)).issubset(subhg.nodes)\
179
+ and self.root_hg.edge_attr(each_edge).get('tmp', False):
180
+ remove_edge_list.append(each_edge)
181
+ self.root_hg.remove_edges(remove_edge_list)
182
+
183
+ adj_node_list = list(subhg.nodes)
184
+ for each_node in subhg.nodes:
185
+ if self.root_hg.degree(each_node) == 0:
186
+ self.root_hg.remove_node(each_node)
187
+ adj_node_list.remove(each_node)
188
+
189
+ if len(adj_node_list) != 1 and not remove_nodes:
190
+ self.root_hg.add_edge(set(adj_node_list), attr_dict=dict(tmp=True))
191
+ '''
192
+ else:
193
+ for each_node_1, each_node_2 in combinations(adj_node_list, 2):
194
+ if not self.root_hg.is_adj(each_node_1, each_node_2):
195
+ self.root_hg.add_edge(
196
+ [each_node_1, each_node_2], attr_dict=dict(tmp=True))
197
+ '''
198
+ subhg.remove_edges_with_attr({'tmp':True})
199
+ self.insert_subhg(subhg)
200
+
201
+
202
+ def _get_min_deg_node(hg, ident_node_dict: dict, mode='mol'):
203
+ if mode == 'standard':
204
+ degree_dict = hg.degrees()
205
+ min_deg_node = min(degree_dict, key=degree_dict.get)
206
+ min_deg_subhg = hg.adj_subhg(min_deg_node, ident_node_dict)
207
+ return min_deg_node, min_deg_subhg
208
+ elif mode == 'mol':
209
+ degree_dict = hg.degrees()
210
+ min_deg = min(degree_dict.values())
211
+ min_deg_node_list = [each_node for each_node in hg.nodes if degree_dict[each_node]==min_deg]
212
+ min_deg_subhg_list = [hg.adj_subhg(each_min_deg_node, ident_node_dict)
213
+ for each_min_deg_node in min_deg_node_list]
214
+ best_score = np.inf
215
+ best_idx = -1
216
+ for each_idx in range(len(min_deg_subhg_list)):
217
+ if min_deg_subhg_list[each_idx].num_nodes < best_score:
218
+ best_idx = each_idx
219
+ return min_deg_node_list[each_idx], min_deg_subhg_list[each_idx]
220
+ else:
221
+ raise ValueError
222
+
223
+
224
+ def tree_decomposition(hg, irredundant=True):
225
+ """ compute a tree decomposition of the input hypergraph
226
+
227
+ Parameters
228
+ ----------
229
+ hg : Hypergraph
230
+ hypergraph to be decomposed
231
+ irredundant : bool
232
+ if True, irredundant tree decomposition will be computed.
233
+
234
+ Returns
235
+ -------
236
+ clique_tree : nx.Graph
237
+ each node contains a subhypergraph of `hg`
238
+ """
239
+ org_hg = hg.copy()
240
+ ident_node_dict = hg.get_identical_node_dict()
241
+ clique_tree = CliqueTree(org_hg)
242
+ clique_tree.add_node(0, subhg=org_hg)
243
+ while True:
244
+ degree_dict = org_hg.degrees()
245
+ min_deg_node = min(degree_dict, key=degree_dict.get)
246
+ min_deg_subhg = org_hg.adj_subhg(min_deg_node, ident_node_dict)
247
+ if org_hg.nodes == min_deg_subhg.nodes:
248
+ break
249
+
250
+ # org_hg and min_deg_subhg are divided
251
+ clique_tree.node_update(min_deg_node, min_deg_subhg)
252
+
253
+ clique_tree.root_hg.remove_edges_with_attr({'tmp' : True})
254
+
255
+ if irredundant:
256
+ clique_tree.to_irredundant()
257
+ return clique_tree
258
+
259
+
260
+ def tree_decomposition_with_hrg(hg, hrg, irredundant=True, return_root=False):
261
+ ''' compute a tree decomposition given a hyperedge replacement grammar.
262
+ the resultant clique tree should induce a less compact HRG.
263
+
264
+ Parameters
265
+ ----------
266
+ hg : Hypergraph
267
+ hypergraph to be decomposed
268
+ hrg : HyperedgeReplacementGrammar
269
+ current HRG
270
+ irredundant : bool
271
+ if True, irredundant tree decomposition will be computed.
272
+
273
+ Returns
274
+ -------
275
+ clique_tree : nx.Graph
276
+ each node contains a subhypergraph of `hg`
277
+ '''
278
+ org_hg = hg.copy()
279
+ ident_node_dict = hg.get_identical_node_dict()
280
+ clique_tree = CliqueTree(org_hg)
281
+ clique_tree.add_node(0, subhg=org_hg)
282
+ root_node = 0
283
+
284
+ # construct a clique tree using HRG
285
+ success_any = True
286
+ while success_any:
287
+ success_any = False
288
+ for each_prod_rule in hrg.prod_rule_list:
289
+ org_hg, success, subhg = each_prod_rule.revert(org_hg, True)
290
+ if success:
291
+ if each_prod_rule.is_start_rule: root_node = clique_tree.number_of_nodes()
292
+ success_any = True
293
+ subhg.remove_edges_with_attr({'terminal' : False})
294
+ clique_tree.root_hg = org_hg
295
+ clique_tree.insert_subhg(subhg)
296
+
297
+ clique_tree.root_hg = org_hg
298
+
299
+ for each_edge in deepcopy(org_hg.edges):
300
+ if not org_hg.edge_attr(each_edge)['terminal']:
301
+ node_list = org_hg.nodes_in_edge(each_edge)
302
+ org_hg.remove_edge(each_edge)
303
+
304
+ for each_node_1, each_node_2 in combinations(node_list, 2):
305
+ if not org_hg.is_adj(each_node_1, each_node_2):
306
+ org_hg.add_edge([each_node_1, each_node_2], attr_dict=dict(tmp=True))
307
+
308
+ # construct a clique tree using the existing algorithm
309
+ degree_dict = org_hg.degrees()
310
+ if degree_dict:
311
+ while True:
312
+ min_deg_node, min_deg_subhg = _get_min_deg_node(org_hg, ident_node_dict)
313
+ if org_hg.nodes == min_deg_subhg.nodes: break
314
+
315
+ # org_hg and min_deg_subhg are divided
316
+ clique_tree.node_update(min_deg_node, min_deg_subhg)
317
+
318
+ clique_tree.root_hg.remove_edges_with_attr({'tmp' : True})
319
+ if irredundant:
320
+ clique_tree.to_irredundant()
321
+
322
+ if return_root:
323
+ if root_node == 0 and 0 not in clique_tree.nodes:
324
+ root_node = clique_tree.number_of_nodes()
325
+ while root_node not in clique_tree.nodes:
326
+ root_node -= 1
327
+ elif root_node not in clique_tree.nodes:
328
+ while root_node not in clique_tree.nodes:
329
+ root_node -= 1
330
+ else:
331
+ pass
332
+ return clique_tree, root_node
333
+ else:
334
+ return clique_tree
335
+
336
+
337
+ def tree_decomposition_from_leaf(hg, irredundant=True):
338
+ """ compute a tree decomposition of the input hypergraph
339
+
340
+ Parameters
341
+ ----------
342
+ hg : Hypergraph
343
+ hypergraph to be decomposed
344
+ irredundant : bool
345
+ if True, irredundant tree decomposition will be computed.
346
+
347
+ Returns
348
+ -------
349
+ clique_tree : nx.Graph
350
+ each node contains a subhypergraph of `hg`
351
+ """
352
+ def apply_normal_decomposition(clique_tree):
353
+ degree_dict = clique_tree.root_hg.degrees()
354
+ min_deg_node = min(degree_dict, key=degree_dict.get)
355
+ min_deg_subhg = clique_tree.root_hg.adj_subhg(min_deg_node, clique_tree.ident_node_dict)
356
+ if clique_tree.root_hg.nodes == min_deg_subhg.nodes:
357
+ return clique_tree, False
358
+ clique_tree.node_update(min_deg_node, min_deg_subhg)
359
+ return clique_tree, True
360
+
361
+ def apply_min_edge_deg_decomposition(clique_tree):
362
+ edge_degree_dict = clique_tree.root_hg.edge_degrees()
363
+ non_tmp_edge_list = [each_edge for each_edge in clique_tree.root_hg.edges \
364
+ if not clique_tree.root_hg.edge_attr(each_edge).get('tmp')]
365
+ if not non_tmp_edge_list:
366
+ return clique_tree, False
367
+ min_deg_edge = None
368
+ min_deg = np.inf
369
+ for each_edge in non_tmp_edge_list:
370
+ if min_deg > edge_degree_dict[each_edge]:
371
+ min_deg_edge = each_edge
372
+ min_deg = edge_degree_dict[each_edge]
373
+ node_list = clique_tree.root_hg.nodes_in_edge(min_deg_edge)
374
+ min_deg_subhg = clique_tree.root_hg.get_subhg(
375
+ node_list, [min_deg_edge], clique_tree.ident_node_dict)
376
+ if clique_tree.root_hg.nodes == min_deg_subhg.nodes:
377
+ return clique_tree, False
378
+ clique_tree.update(min_deg_subhg)
379
+ return clique_tree, True
380
+
381
+ org_hg = hg.copy()
382
+ clique_tree = CliqueTree(org_hg)
383
+ clique_tree.add_node(0, subhg=org_hg)
384
+
385
+ success = True
386
+ while success:
387
+ clique_tree, success = apply_min_edge_deg_decomposition(clique_tree)
388
+ if not success:
389
+ clique_tree, success = apply_normal_decomposition(clique_tree)
390
+
391
+ clique_tree.root_hg.remove_edges_with_attr({'tmp' : True})
392
+ if irredundant:
393
+ clique_tree.to_irredundant()
394
+ return clique_tree
395
+
396
+ def topological_tree_decomposition(
397
+ hg, irredundant=True, rip_labels=True, shrink_cycle=False, contract_cycles=False):
398
+ ''' compute a tree decomposition of the input hypergraph
399
+
400
+ Parameters
401
+ ----------
402
+ hg : Hypergraph
403
+ hypergraph to be decomposed
404
+ irredundant : bool
405
+ if True, irredundant tree decomposition will be computed.
406
+
407
+ Returns
408
+ -------
409
+ clique_tree : CliqueTree
410
+ each node contains a subhypergraph of `hg`
411
+ '''
412
+ def _contract_tree(clique_tree):
413
+ ''' contract a single leaf
414
+
415
+ Parameters
416
+ ----------
417
+ clique_tree : CliqueTree
418
+
419
+ Returns
420
+ -------
421
+ CliqueTree, bool
422
+ bool represents whether this operation succeeds or not.
423
+ '''
424
+ edge_degree_dict = clique_tree.root_hg.edge_degrees()
425
+ leaf_edge_list = [each_edge for each_edge in clique_tree.root_hg.edges \
426
+ if (not clique_tree.root_hg.edge_attr(each_edge).get('tmp'))\
427
+ and edge_degree_dict[each_edge] == 1]
428
+ if not leaf_edge_list:
429
+ return clique_tree, False
430
+ min_deg_edge = leaf_edge_list[0]
431
+ node_list = clique_tree.root_hg.nodes_in_edge(min_deg_edge)
432
+ min_deg_subhg = clique_tree.root_hg.get_subhg(
433
+ node_list, [min_deg_edge], clique_tree.ident_node_dict)
434
+ if clique_tree.root_hg.nodes == min_deg_subhg.nodes:
435
+ return clique_tree, False
436
+ clique_tree.update(min_deg_subhg)
437
+ return clique_tree, True
438
+
439
+ def _rip_labels_from_cycles(clique_tree, org_hg):
440
+ ''' rip hyperedge-labels off
441
+
442
+ Parameters
443
+ ----------
444
+ clique_tree : CliqueTree
445
+ org_hg : Hypergraph
446
+
447
+ Returns
448
+ -------
449
+ CliqueTree, bool
450
+ bool represents whether this operation succeeds or not.
451
+ '''
452
+ ident_node_dict = clique_tree.ident_node_dict #hg.get_identical_node_dict()
453
+ for each_edge in clique_tree.root_hg.edges:
454
+ if each_edge in org_hg.edges:
455
+ if org_hg.in_cycle(each_edge):
456
+ node_list = clique_tree.root_hg.nodes_in_edge(each_edge)
457
+ subhg = clique_tree.root_hg.get_subhg(
458
+ node_list, [each_edge], ident_node_dict)
459
+ if clique_tree.root_hg.nodes == subhg.nodes:
460
+ return clique_tree, False
461
+ clique_tree.update(subhg)
462
+ '''
463
+ in_cycle_dict = {each_node: org_hg.node_attr(each_node)['is_in_ring'] for each_node in node_list}
464
+ if not all(in_cycle_dict.values()):
465
+ node_not_in_cycle = [each_node for each_node in in_cycle_dict.keys() if not in_cycle_dict[each_node]][0]
466
+ node_list = [node_not_in_cycle]
467
+ node_list.extend(clique_tree.root_hg.adj_nodes(node_not_in_cycle))
468
+ edge_list = clique_tree.root_hg.adj_edges(node_not_in_cycle)
469
+ import pdb; pdb.set_trace()
470
+ subhg = clique_tree.root_hg.get_subhg(
471
+ node_list, edge_list, ident_node_dict)
472
+
473
+ clique_tree.update(subhg)
474
+ '''
475
+ return clique_tree, True
476
+ return clique_tree, False
477
+
478
+ def _shrink_cycle(clique_tree):
479
+ ''' shrink a cycle
480
+
481
+ Parameters
482
+ ----------
483
+ clique_tree : CliqueTree
484
+
485
+ Returns
486
+ -------
487
+ CliqueTree, bool
488
+ bool represents whether this operation succeeds or not.
489
+ '''
490
+ def filter_subhg(subhg, hg, key_node):
491
+ num_nodes_cycle = 0
492
+ nodes_in_cycle_list = []
493
+ for each_node in subhg.nodes:
494
+ if hg.in_cycle(each_node):
495
+ num_nodes_cycle += 1
496
+ if each_node != key_node:
497
+ nodes_in_cycle_list.append(each_node)
498
+ if num_nodes_cycle > 3:
499
+ break
500
+ if num_nodes_cycle != 3:
501
+ return False
502
+ else:
503
+ for each_edge in hg.edges:
504
+ if set(nodes_in_cycle_list).issubset(hg.nodes_in_edge(each_edge)):
505
+ return False
506
+ return True
507
+
508
+ #ident_node_dict = hg.get_identical_node_dict()
509
+ ident_node_dict = clique_tree.ident_node_dict
510
+ for each_node in clique_tree.root_hg.nodes:
511
+ if clique_tree.root_hg.in_cycle(each_node)\
512
+ and filter_subhg(clique_tree.root_hg.adj_subhg(each_node, ident_node_dict),
513
+ clique_tree.root_hg,
514
+ each_node):
515
+ target_node = each_node
516
+ target_subhg = clique_tree.root_hg.adj_subhg(target_node, ident_node_dict)
517
+ if clique_tree.root_hg.nodes == target_subhg.nodes:
518
+ return clique_tree, False
519
+ clique_tree.update(target_subhg)
520
+ return clique_tree, True
521
+ return clique_tree, False
522
+
523
+ def _contract_cycles(clique_tree):
524
+ '''
525
+ remove a subhypergraph that looks like a cycle on a leaf.
526
+
527
+ Parameters
528
+ ----------
529
+ clique_tree : CliqueTree
530
+
531
+ Returns
532
+ -------
533
+ CliqueTree, bool
534
+ bool represents whether this operation succeeds or not.
535
+ '''
536
+ def _divide_hg(hg):
537
+ ''' divide a hypergraph into subhypergraphs such that
538
+ each subhypergraph is connected to each other in a tree-like way.
539
+
540
+ Parameters
541
+ ----------
542
+ hg : Hypergraph
543
+
544
+ Returns
545
+ -------
546
+ list of Hypergraphs
547
+ each element corresponds to a subhypergraph of `hg`
548
+ '''
549
+ for each_node in hg.nodes:
550
+ if hg.is_dividable(each_node):
551
+ adj_edges_dict = {each_edge: hg.in_cycle(each_edge) for each_edge in hg.adj_edges(each_node)}
552
+ '''
553
+ if any(adj_edges_dict.values()):
554
+ import pdb; pdb.set_trace()
555
+ edge_in_cycle = [each_key for each_key, each_val in adj_edges_dict.items() if each_val][0]
556
+ subhg1, subhg2, subhg3 = hg.divide(each_node, edge_in_cycle)
557
+ return _divide_hg(subhg1) + _divide_hg(subhg2) + _divide_hg(subhg3)
558
+ else:
559
+ '''
560
+ subhg1, subhg2 = hg.divide(each_node)
561
+ return _divide_hg(subhg1) + _divide_hg(subhg2)
562
+ return [hg]
563
+
564
+ def _is_leaf(hg, divided_subhg) -> bool:
565
+ ''' judge whether subhg is a leaf-like in the original hypergraph
566
+
567
+ Parameters
568
+ ----------
569
+ hg : Hypergraph
570
+ divided_subhg : Hypergraph
571
+ `divided_subhg` is a subhypergraph of `hg`
572
+
573
+ Returns
574
+ -------
575
+ bool
576
+ '''
577
+ '''
578
+ adj_edges_set = set([])
579
+ for each_node in divided_subhg.nodes:
580
+ adj_edges_set.update(set(hg.adj_edges(each_node)))
581
+
582
+
583
+ _hg = deepcopy(hg)
584
+ _hg.remove_subhg(divided_subhg)
585
+ if nx.is_connected(_hg.hg) != (len(adj_edges_set - divided_subhg.edges) == 1):
586
+ import pdb; pdb.set_trace()
587
+ return len(adj_edges_set - divided_subhg.edges) == 1
588
+ '''
589
+ _hg = deepcopy(hg)
590
+ _hg.remove_subhg(divided_subhg)
591
+ return nx.is_connected(_hg.hg)
592
+
593
+ subhg_list = _divide_hg(clique_tree.root_hg)
594
+ if len(subhg_list) == 1:
595
+ return clique_tree, False
596
+ else:
597
+ while len(subhg_list) > 1:
598
+ max_leaf_subhg = None
599
+ for each_subhg in subhg_list:
600
+ if _is_leaf(clique_tree.root_hg, each_subhg):
601
+ if max_leaf_subhg is None:
602
+ max_leaf_subhg = each_subhg
603
+ elif max_leaf_subhg.num_nodes < each_subhg.num_nodes:
604
+ max_leaf_subhg = each_subhg
605
+ clique_tree.update(max_leaf_subhg)
606
+ subhg_list.remove(max_leaf_subhg)
607
+ return clique_tree, True
608
+
609
+ org_hg = hg.copy()
610
+ clique_tree = CliqueTree(org_hg)
611
+ clique_tree.add_node(0, subhg=org_hg)
612
+
613
+ success = True
614
+ while success:
615
+ '''
616
+ clique_tree, success = _rip_labels_from_cycles(clique_tree, hg)
617
+ if not success:
618
+ clique_tree, success = _contract_cycles(clique_tree)
619
+ '''
620
+ clique_tree, success = _contract_tree(clique_tree)
621
+ if not success:
622
+ if rip_labels:
623
+ clique_tree, success = _rip_labels_from_cycles(clique_tree, hg)
624
+ if not success:
625
+ if shrink_cycle:
626
+ clique_tree, success = _shrink_cycle(clique_tree)
627
+ if not success:
628
+ if contract_cycles:
629
+ clique_tree, success = _contract_cycles(clique_tree)
630
+ clique_tree.root_hg.remove_edges_with_attr({'tmp' : True})
631
+ if irredundant:
632
+ clique_tree.to_irredundant()
633
+ return clique_tree
634
+
635
+ def molecular_tree_decomposition(hg, irredundant=True):
636
+ """ compute a tree decomposition of the input molecular hypergraph
637
+
638
+ Parameters
639
+ ----------
640
+ hg : Hypergraph
641
+ molecular hypergraph to be decomposed
642
+ irredundant : bool
643
+ if True, irredundant tree decomposition will be computed.
644
+
645
+ Returns
646
+ -------
647
+ clique_tree : CliqueTree
648
+ each node contains a subhypergraph of `hg`
649
+ """
650
+ def _divide_hg(hg):
651
+ ''' divide a hypergraph into subhypergraphs such that
652
+ each subhypergraph is connected to each other in a tree-like way.
653
+
654
+ Parameters
655
+ ----------
656
+ hg : Hypergraph
657
+
658
+ Returns
659
+ -------
660
+ list of Hypergraphs
661
+ each element corresponds to a subhypergraph of `hg`
662
+ '''
663
+ is_ring = False
664
+ for each_node in hg.nodes:
665
+ if hg.node_attr(each_node)['is_in_ring']:
666
+ is_ring = True
667
+ if not hg.node_attr(each_node)['is_in_ring'] \
668
+ and hg.degree(each_node) == 2:
669
+ subhg1, subhg2 = hg.divide(each_node)
670
+ return _divide_hg(subhg1) + _divide_hg(subhg2)
671
+
672
+ if is_ring:
673
+ subhg_list = []
674
+ remove_edge_list = []
675
+ remove_node_list = []
676
+ for each_edge in hg.edges:
677
+ node_list = hg.nodes_in_edge(each_edge)
678
+ subhg = hg.get_subhg(node_list, [each_edge], hg.get_identical_node_dict())
679
+ subhg_list.append(subhg)
680
+ remove_edge_list.append(each_edge)
681
+ for each_node in node_list:
682
+ if not hg.node_attr(each_node)['is_in_ring']:
683
+ remove_node_list.append(each_node)
684
+ hg.remove_edges(remove_edge_list)
685
+ hg.remove_nodes(remove_node_list, False)
686
+ return subhg_list + [hg]
687
+ else:
688
+ return [hg]
689
+
690
+ org_hg = hg.copy()
691
+ clique_tree = CliqueTree(org_hg)
692
+ clique_tree.add_node(0, subhg=org_hg)
693
+
694
+ subhg_list = _divide_hg(deepcopy(clique_tree.root_hg))
695
+ #_subhg_list = deepcopy(subhg_list)
696
+ if len(subhg_list) == 1:
697
+ pass
698
+ else:
699
+ while len(subhg_list) > 1:
700
+ max_leaf_subhg = None
701
+ for each_subhg in subhg_list:
702
+ if _is_leaf(clique_tree.root_hg, each_subhg) and not _is_ring(each_subhg):
703
+ if max_leaf_subhg is None:
704
+ max_leaf_subhg = each_subhg
705
+ elif max_leaf_subhg.num_nodes < each_subhg.num_nodes:
706
+ max_leaf_subhg = each_subhg
707
+
708
+ if max_leaf_subhg is None:
709
+ for each_subhg in subhg_list:
710
+ if _is_ring_label(clique_tree.root_hg, each_subhg):
711
+ if max_leaf_subhg is None:
712
+ max_leaf_subhg = each_subhg
713
+ elif max_leaf_subhg.num_nodes < each_subhg.num_nodes:
714
+ max_leaf_subhg = each_subhg
715
+ if max_leaf_subhg is not None:
716
+ clique_tree.update(max_leaf_subhg)
717
+ subhg_list.remove(max_leaf_subhg)
718
+ else:
719
+ for each_subhg in subhg_list:
720
+ if _is_leaf(clique_tree.root_hg, each_subhg):
721
+ if max_leaf_subhg is None:
722
+ max_leaf_subhg = each_subhg
723
+ elif max_leaf_subhg.num_nodes < each_subhg.num_nodes:
724
+ max_leaf_subhg = each_subhg
725
+ if max_leaf_subhg is not None:
726
+ clique_tree.update(max_leaf_subhg, True)
727
+ subhg_list.remove(max_leaf_subhg)
728
+ else:
729
+ break
730
+ if len(subhg_list) > 1:
731
+ '''
732
+ for each_idx, each_subhg in enumerate(subhg_list):
733
+ each_subhg.draw(f'{each_idx}', True)
734
+ clique_tree.root_hg.draw('root', True)
735
+ import pickle
736
+ with open('buggy_hg.pkl', 'wb') as f:
737
+ pickle.dump(hg, f)
738
+ return clique_tree, subhg_list, _subhg_list
739
+ '''
740
+ raise RuntimeError('bug in tree decomposition algorithm')
741
+ clique_tree.root_hg.remove_edges_with_attr({'tmp' : True})
742
+
743
+ '''
744
+ for each_tree_node in clique_tree.adj[0]:
745
+ subhg = clique_tree.nodes[each_tree_node]['subhg']
746
+ for each_edge in subhg.edges:
747
+ if set(subhg.nodes_in_edge(each_edge)).issubset(clique_tree.root_hg.nodes):
748
+ clique_tree.root_hg.add_edge(set(subhg.nodes_in_edge(each_edge)), attr_dict=dict(tmp=True))
749
+ '''
750
+ if irredundant:
751
+ clique_tree.to_irredundant()
752
+ return clique_tree #, _subhg_list
753
+
754
+ def _is_leaf(hg, subhg) -> bool:
755
+ ''' judge whether subhg is a leaf-like in the original hypergraph
756
+
757
+ Parameters
758
+ ----------
759
+ hg : Hypergraph
760
+ subhg : Hypergraph
761
+ `subhg` is a subhypergraph of `hg`
762
+
763
+ Returns
764
+ -------
765
+ bool
766
+ '''
767
+ if len(subhg.edges) == 0:
768
+ adj_edge_set = set([])
769
+ subhg_edge_set = set([])
770
+ for each_edge in hg.edges:
771
+ if set(hg.nodes_in_edge(each_edge)).issubset(subhg.nodes) and hg.edge_attr(each_edge).get('tmp', False):
772
+ subhg_edge_set.add(each_edge)
773
+ for each_node in subhg.nodes:
774
+ adj_edge_set.update(set(hg.adj_edges(each_node)))
775
+ if subhg_edge_set.issubset(adj_edge_set) and len(adj_edge_set.difference(subhg_edge_set)) == 1:
776
+ return True
777
+ else:
778
+ return False
779
+ elif len(subhg.edges) == 1:
780
+ adj_edge_set = set([])
781
+ subhg_edge_set = subhg.edges
782
+ for each_node in subhg.nodes:
783
+ for each_adj_edge in hg.adj_edges(each_node):
784
+ adj_edge_set.add(each_adj_edge)
785
+ if subhg_edge_set.issubset(adj_edge_set) and len(adj_edge_set.difference(subhg_edge_set)) == 1:
786
+ return True
787
+ else:
788
+ return False
789
+ else:
790
+ raise ValueError('subhg should be nodes only or one-edge hypergraph.')
791
+
792
+ def _is_ring_label(hg, subhg):
793
+ if len(subhg.edges) != 1:
794
+ return False
795
+ edge_name = list(subhg.edges)[0]
796
+ #assert edge_name in hg.edges, f'{edge_name}'
797
+ is_in_ring = False
798
+ for each_node in subhg.nodes:
799
+ if subhg.node_attr(each_node)['is_in_ring']:
800
+ is_in_ring = True
801
+ else:
802
+ adj_edge_list = list(hg.adj_edges(each_node))
803
+ adj_edge_list.remove(edge_name)
804
+ if len(adj_edge_list) == 1:
805
+ if not hg.edge_attr(adj_edge_list[0]).get('tmp', False):
806
+ return False
807
+ elif len(adj_edge_list) == 0:
808
+ pass
809
+ else:
810
+ raise ValueError
811
+ if is_in_ring:
812
+ return True
813
+ else:
814
+ return False
815
+
816
+ def _is_ring(hg):
817
+ for each_node in hg.nodes:
818
+ if not hg.node_attr(each_node)['is_in_ring']:
819
+ return False
820
+ return True
821
+
graph_grammar/graph_grammar/__init__.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # Rhizome
4
+ # Version beta 0.0, August 2023
5
+ # Property of IBM Research, Accelerated Discovery
6
+ #
7
+
8
+ """
9
+ PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
10
+ OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
11
+ THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
12
+ """
13
+
14
+ """ Title """
15
+
16
+ __author__ = "Hiroshi Kajino <KAJINO@jp.ibm.com>"
17
+ __copyright__ = "(c) Copyright IBM Corp. 2018"
18
+ __version__ = "0.1"
19
+ __date__ = "Jan 1 2018"
20
+
graph_grammar/graph_grammar/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (680 Bytes). View file
 
graph_grammar/graph_grammar/__pycache__/base.cpython-310.pyc ADDED
Binary file (1.17 kB). View file
 
graph_grammar/graph_grammar/__pycache__/corpus.cpython-310.pyc ADDED
Binary file (4.71 kB). View file
 
graph_grammar/graph_grammar/__pycache__/hrg.cpython-310.pyc ADDED
Binary file (29.1 kB). View file
 
graph_grammar/graph_grammar/__pycache__/symbols.cpython-310.pyc ADDED
Binary file (5.38 kB). View file
 
graph_grammar/graph_grammar/__pycache__/utils.cpython-310.pyc ADDED
Binary file (3.63 kB). View file
 
graph_grammar/graph_grammar/base.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # Rhizome
4
+ # Version beta 0.0, August 2023
5
+ # Property of IBM Research, Accelerated Discovery
6
+ #
7
+
8
+ """
9
+ PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
10
+ OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
11
+ THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
12
+ """
13
+
14
+ """ Title """
15
+
16
+ __author__ = "Hiroshi Kajino <KAJINO@jp.ibm.com>"
17
+ __copyright__ = "(c) Copyright IBM Corp. 2017"
18
+ __version__ = "0.1"
19
+ __date__ = "Dec 11 2017"
20
+
21
+ from abc import ABCMeta, abstractmethod
22
+
23
+ class GraphGrammarBase(metaclass=ABCMeta):
24
+ @abstractmethod
25
+ def learn(self):
26
+ pass
27
+
28
+ @abstractmethod
29
+ def sample(self):
30
+ pass
graph_grammar/graph_grammar/corpus.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # Rhizome
4
+ # Version beta 0.0, August 2023
5
+ # Property of IBM Research, Accelerated Discovery
6
+ #
7
+
8
+ """
9
+ PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
10
+ OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
11
+ THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
12
+ """
13
+
14
+ """ Title """
15
+
16
+ __author__ = "Hiroshi Kajino <KAJINO@jp.ibm.com>"
17
+ __copyright__ = "(c) Copyright IBM Corp. 2018"
18
+ __version__ = "0.1"
19
+ __date__ = "Jun 4 2018"
20
+
21
+ from collections import Counter
22
+ from functools import partial
23
+ from .utils import _easy_node_match, _edge_match, _node_match, common_node_list, _node_match_prod_rule
24
+ from networkx.algorithms.isomorphism import GraphMatcher
25
+ import os
26
+
27
+
28
+ class CliqueTreeCorpus(object):
29
+
30
+ ''' clique tree corpus
31
+
32
+ Attributes
33
+ ----------
34
+ clique_tree_list : list of CliqueTree
35
+ subhg_list : list of Hypergraph
36
+ '''
37
+
38
+ def __init__(self):
39
+ self.clique_tree_list = []
40
+ self.subhg_list = []
41
+
42
+ @property
43
+ def size(self):
44
+ return len(self.subhg_list)
45
+
46
+ def add_clique_tree(self, clique_tree):
47
+ for each_node in clique_tree.nodes:
48
+ subhg = clique_tree.nodes[each_node]['subhg']
49
+ subhg_idx = self.add_subhg(subhg)
50
+ clique_tree.nodes[each_node]['subhg_idx'] = subhg_idx
51
+ self.clique_tree_list.append(clique_tree)
52
+
53
+ def add_to_subhg_list(self, clique_tree, root_node):
54
+ parent_node_dict = {}
55
+ current_node = None
56
+ parent_node_dict[root_node] = None
57
+ stack = [root_node]
58
+ while stack:
59
+ current_node = stack.pop()
60
+ current_subhg = clique_tree.nodes[current_node]['subhg']
61
+ for each_child in clique_tree.adj[current_node]:
62
+ if each_child != parent_node_dict[current_node]:
63
+ stack.append(each_child)
64
+ parent_node_dict[each_child] = current_node
65
+ if parent_node_dict[current_node] is not None:
66
+ parent_subhg = clique_tree.nodes[parent_node_dict[current_node]]['subhg']
67
+ common, _ = common_node_list(parent_subhg, current_subhg)
68
+ parent_subhg.add_edge(set(common), attr_dict={'tmp': True})
69
+
70
+ parent_node_dict = {}
71
+ current_node = None
72
+ parent_node_dict[root_node] = None
73
+ stack = [root_node]
74
+ while stack:
75
+ current_node = stack.pop()
76
+ current_subhg = clique_tree.nodes[current_node]['subhg']
77
+ for each_child in clique_tree.adj[current_node]:
78
+ if each_child != parent_node_dict[current_node]:
79
+ stack.append(each_child)
80
+ parent_node_dict[each_child] = current_node
81
+ if parent_node_dict[current_node] is not None:
82
+ parent_subhg = clique_tree.nodes[parent_node_dict[current_node]]['subhg']
83
+ common, _ = common_node_list(parent_subhg, current_subhg)
84
+ for each_idx, each_node in enumerate(common):
85
+ current_subhg.set_node_attr(each_node, {'ext_id': each_idx})
86
+
87
+ subhg_idx, is_new = self.add_subhg(current_subhg)
88
+ clique_tree.nodes[current_node]['subhg_idx'] = subhg_idx
89
+ return clique_tree
90
+
91
+ def add_subhg(self, subhg):
92
+ if len(self.subhg_list) == 0:
93
+ node_dict = {}
94
+ for each_node in subhg.nodes:
95
+ node_dict[each_node] = subhg.node_attr(each_node)['symbol'].__hash__()
96
+ node_list = []
97
+ for each_key, _ in sorted(node_dict.items(), key=lambda x:x[1]):
98
+ node_list.append(each_key)
99
+ for each_idx, each_node in enumerate(node_list):
100
+ subhg.node_attr(each_node)['order4hrg'] = each_idx
101
+ self.subhg_list.append(subhg)
102
+ return 0, True
103
+ else:
104
+ match = False
105
+ subhg_bond_symbol_counter \
106
+ = Counter([subhg.node_attr(each_node)['symbol'] \
107
+ for each_node in subhg.nodes])
108
+ subhg_atom_symbol_counter \
109
+ = Counter([subhg.edge_attr(each_edge).get('symbol', None) \
110
+ for each_edge in subhg.edges])
111
+ for each_idx, each_subhg in enumerate(self.subhg_list):
112
+ each_bond_symbol_counter \
113
+ = Counter([each_subhg.node_attr(each_node)['symbol'] \
114
+ for each_node in each_subhg.nodes])
115
+ each_atom_symbol_counter \
116
+ = Counter([each_subhg.edge_attr(each_edge).get('symbol', None) \
117
+ for each_edge in each_subhg.edges])
118
+ if not match \
119
+ and (subhg.num_nodes == each_subhg.num_nodes
120
+ and subhg.num_edges == each_subhg.num_edges
121
+ and subhg_bond_symbol_counter == each_bond_symbol_counter
122
+ and subhg_atom_symbol_counter == each_atom_symbol_counter):
123
+ gm = GraphMatcher(each_subhg.hg,
124
+ subhg.hg,
125
+ node_match=_easy_node_match,
126
+ edge_match=_edge_match)
127
+ try:
128
+ isomap = next(gm.isomorphisms_iter())
129
+ match = True
130
+ for each_node in each_subhg.nodes:
131
+ subhg.node_attr(isomap[each_node])['order4hrg'] \
132
+ = each_subhg.node_attr(each_node)['order4hrg']
133
+ if 'ext_id' in each_subhg.node_attr(each_node):
134
+ subhg.node_attr(isomap[each_node])['ext_id'] \
135
+ = each_subhg.node_attr(each_node)['ext_id']
136
+ return each_idx, False
137
+ except StopIteration:
138
+ match = False
139
+ if not match:
140
+ node_dict = {}
141
+ for each_node in subhg.nodes:
142
+ node_dict[each_node] = subhg.node_attr(each_node)['symbol'].__hash__()
143
+ node_list = []
144
+ for each_key, _ in sorted(node_dict.items(), key=lambda x:x[1]):
145
+ node_list.append(each_key)
146
+ for each_idx, each_node in enumerate(node_list):
147
+ subhg.node_attr(each_node)['order4hrg'] = each_idx
148
+
149
+ #for each_idx, each_node in enumerate(subhg.nodes):
150
+ # subhg.node_attr(each_node)['order4hrg'] = each_idx
151
+ self.subhg_list.append(subhg)
152
+ return len(self.subhg_list) - 1, True
graph_grammar/graph_grammar/hrg.py ADDED
@@ -0,0 +1,1065 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # Rhizome
4
+ # Version beta 0.0, August 2023
5
+ # Property of IBM Research, Accelerated Discovery
6
+ #
7
+
8
+ """
9
+ PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
10
+ OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
11
+ THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
12
+ """
13
+
14
+ """ Title """
15
+
16
+ __author__ = "Hiroshi Kajino <KAJINO@jp.ibm.com>"
17
+ __copyright__ = "(c) Copyright IBM Corp. 2017"
18
+ __version__ = "0.1"
19
+ __date__ = "Dec 11 2017"
20
+
21
+ from .corpus import CliqueTreeCorpus
22
+ from .base import GraphGrammarBase
23
+ from .symbols import TSymbol, NTSymbol, BondSymbol
24
+ from .utils import _node_match, _node_match_prod_rule, _edge_match, masked_softmax, common_node_list
25
+ from ..hypergraph import Hypergraph
26
+ from collections import Counter
27
+ from copy import deepcopy
28
+ from ..algo.tree_decomposition import (
29
+ tree_decomposition,
30
+ tree_decomposition_with_hrg,
31
+ tree_decomposition_from_leaf,
32
+ topological_tree_decomposition,
33
+ molecular_tree_decomposition)
34
+ from functools import partial
35
+ from networkx.algorithms.isomorphism import GraphMatcher
36
+ from typing import List, Dict, Tuple
37
+ import networkx as nx
38
+ import numpy as np
39
+ import torch
40
+ import os
41
+ import random
42
+
43
+ DEBUG = False
44
+
45
+
46
+ class ProductionRule(object):
47
+ """ A class of a production rule
48
+
49
+ Attributes
50
+ ----------
51
+ lhs : Hypergraph or None
52
+ the left hand side of the production rule.
53
+ if None, the rule is a starting rule.
54
+ rhs : Hypergraph
55
+ the right hand side of the production rule.
56
+ """
57
+ def __init__(self, lhs, rhs):
58
+ self.lhs = lhs
59
+ self.rhs = rhs
60
+
61
+ @property
62
+ def is_start_rule(self) -> bool:
63
+ return self.lhs.num_nodes == 0
64
+
65
+ @property
66
+ def ext_node(self) -> Dict[int, str]:
67
+ """ return a dict of external nodes
68
+ """
69
+ if self.is_start_rule:
70
+ return {}
71
+ else:
72
+ ext_node_dict = {}
73
+ for each_node in self.lhs.nodes:
74
+ ext_node_dict[self.lhs.node_attr(each_node)["ext_id"]] = each_node
75
+ return ext_node_dict
76
+
77
+ @property
78
+ def lhs_nt_symbol(self) -> NTSymbol:
79
+ if self.is_start_rule:
80
+ return NTSymbol(degree=0, is_aromatic=False, bond_symbol_list=[])
81
+ else:
82
+ return self.lhs.edge_attr(list(self.lhs.edges)[0])['symbol']
83
+
84
+ def rhs_adj_mat(self, node_edge_list):
85
+ ''' return the adjacency matrix of rhs of the production rule
86
+ '''
87
+ return nx.adjacency_matrix(self.rhs.hg, node_edge_list)
88
+
89
+ def draw(self, file_path=None):
90
+ return self.rhs.draw(file_path)
91
+
92
+ def is_same(self, prod_rule, ignore_order=False):
93
+ """ judge whether this production rule is
94
+ the same as the input one, `prod_rule`
95
+
96
+ Parameters
97
+ ----------
98
+ prod_rule : ProductionRule
99
+ production rule to be compared
100
+
101
+ Returns
102
+ -------
103
+ is_same : bool
104
+ isomap : dict
105
+ isomorphism of nodes and hyperedges.
106
+ ex) {'bond_42': 'bond_37', 'bond_2': 'bond_1',
107
+ 'e36': 'e11', 'e16': 'e12', 'e25': 'e18',
108
+ 'bond_40': 'bond_38', 'e26': 'e21', 'bond_41': 'bond_39'}.
109
+ key comes from `prod_rule`, value comes from `self`.
110
+ """
111
+ if self.is_start_rule:
112
+ if not prod_rule.is_start_rule:
113
+ return False, {}
114
+ else:
115
+ if prod_rule.is_start_rule:
116
+ return False, {}
117
+ else:
118
+ if prod_rule.lhs.num_nodes != self.lhs.num_nodes:
119
+ return False, {}
120
+
121
+ if prod_rule.rhs.num_nodes != self.rhs.num_nodes:
122
+ return False, {}
123
+ if prod_rule.rhs.num_edges != self.rhs.num_edges:
124
+ return False, {}
125
+
126
+ subhg_bond_symbol_counter \
127
+ = Counter([prod_rule.rhs.node_attr(each_node)['symbol'] \
128
+ for each_node in prod_rule.rhs.nodes])
129
+ each_bond_symbol_counter \
130
+ = Counter([self.rhs.node_attr(each_node)['symbol'] \
131
+ for each_node in self.rhs.nodes])
132
+ if subhg_bond_symbol_counter != each_bond_symbol_counter:
133
+ return False, {}
134
+
135
+ subhg_atom_symbol_counter \
136
+ = Counter([prod_rule.rhs.edge_attr(each_edge)['symbol'] \
137
+ for each_edge in prod_rule.rhs.edges])
138
+ each_atom_symbol_counter \
139
+ = Counter([self.rhs.edge_attr(each_edge)['symbol'] \
140
+ for each_edge in self.rhs.edges])
141
+ if subhg_atom_symbol_counter != each_atom_symbol_counter:
142
+ return False, {}
143
+
144
+ gm = GraphMatcher(prod_rule.rhs.hg,
145
+ self.rhs.hg,
146
+ partial(_node_match_prod_rule,
147
+ ignore_order=ignore_order),
148
+ partial(_edge_match,
149
+ ignore_order=ignore_order))
150
+ try:
151
+ return True, next(gm.isomorphisms_iter())
152
+ except StopIteration:
153
+ return False, {}
154
+
155
+ def applied_to(self,
156
+ hg: Hypergraph,
157
+ edge: str) -> Tuple[Hypergraph, List[str]]:
158
+ """ augment `hg` by replacing `edge` with `self.rhs`.
159
+
160
+ Parameters
161
+ ----------
162
+ hg : Hypergraph
163
+ edge : str
164
+ `edge` must belong to `hg`
165
+
166
+ Returns
167
+ -------
168
+ hg : Hypergraph
169
+ resultant hypergraph
170
+ nt_edge_list : list
171
+ list of non-terminal edges
172
+ """
173
+ nt_edge_dict = {}
174
+ if self.is_start_rule:
175
+ if (edge is not None) or (hg is not None):
176
+ ValueError("edge and hg must be None for this prod rule.")
177
+ hg = Hypergraph()
178
+ node_map_rhs = {} # node id in rhs -> node id in hg, where rhs is augmented.
179
+ for num_idx, each_node in enumerate(self.rhs.nodes):
180
+ hg.add_node(f"bond_{num_idx}",
181
+ #attr_dict=deepcopy(self.rhs.node_attr(each_node)))
182
+ attr_dict=self.rhs.node_attr(each_node))
183
+ node_map_rhs[each_node] = f"bond_{num_idx}"
184
+ for each_edge in self.rhs.edges:
185
+ node_list = []
186
+ for each_node in self.rhs.nodes_in_edge(each_edge):
187
+ node_list.append(node_map_rhs[each_node])
188
+ if isinstance(self.rhs.nodes_in_edge(each_edge), set):
189
+ node_list = set(node_list)
190
+ edge_id = hg.add_edge(
191
+ node_list,
192
+ #attr_dict=deepcopy(self.rhs.edge_attr(each_edge)))
193
+ attr_dict=self.rhs.edge_attr(each_edge))
194
+ if "nt_idx" in hg.edge_attr(edge_id):
195
+ nt_edge_dict[hg.edge_attr(edge_id)["nt_idx"]] = edge_id
196
+ nt_edge_list = [nt_edge_dict[key] for key in range(len(nt_edge_dict))]
197
+ return hg, nt_edge_list
198
+ else:
199
+ if edge not in hg.edges:
200
+ raise ValueError("the input hyperedge does not exist.")
201
+ if hg.edge_attr(edge)["terminal"]:
202
+ raise ValueError("the input hyperedge is terminal.")
203
+ if hg.edge_attr(edge)['symbol'] != self.lhs_nt_symbol:
204
+ print(hg.edge_attr(edge)['symbol'], self.lhs_nt_symbol)
205
+ raise ValueError("the input hyperedge and lhs have inconsistent number of nodes.")
206
+ if DEBUG:
207
+ for node_idx, each_node in enumerate(hg.nodes_in_edge(edge)):
208
+ other_node = self.lhs.nodes_in_edge(list(self.lhs.edges)[0])[node_idx]
209
+ attr = deepcopy(self.lhs.node_attr(other_node))
210
+ attr.pop('ext_id')
211
+ if hg.node_attr(each_node) != attr:
212
+ raise ValueError('node attributes are inconsistent.')
213
+
214
+ # order of nodes that belong to the non-terminal edge in hg
215
+ nt_order_dict = {} # hg_node -> order ("bond_17" : 1)
216
+ nt_order_dict_inv = {} # order -> hg_node
217
+ for each_idx, each_node in enumerate(hg.nodes_in_edge(edge)):
218
+ nt_order_dict[each_node] = each_idx
219
+ nt_order_dict_inv[each_idx] = each_node
220
+
221
+ # construct a node_map_rhs: rhs -> new hg
222
+ node_map_rhs = {} # node id in rhs -> node id in hg, where rhs is augmented.
223
+ node_idx = hg.num_nodes
224
+ for each_node in self.rhs.nodes:
225
+ if "ext_id" in self.rhs.node_attr(each_node):
226
+ node_map_rhs[each_node] \
227
+ = nt_order_dict_inv[
228
+ self.rhs.node_attr(each_node)["ext_id"]]
229
+ else:
230
+ node_map_rhs[each_node] = f"bond_{node_idx}"
231
+ node_idx += 1
232
+
233
+ # delete non-terminal
234
+ hg.remove_edge(edge)
235
+
236
+ # add nodes to hg
237
+ for each_node in self.rhs.nodes:
238
+ hg.add_node(node_map_rhs[each_node],
239
+ attr_dict=self.rhs.node_attr(each_node))
240
+
241
+ # add hyperedges to hg
242
+ for each_edge in self.rhs.edges:
243
+ node_list_hg = []
244
+ for each_node in self.rhs.nodes_in_edge(each_edge):
245
+ node_list_hg.append(node_map_rhs[each_node])
246
+ edge_id = hg.add_edge(
247
+ node_list_hg,
248
+ attr_dict=self.rhs.edge_attr(each_edge))#deepcopy(self.rhs.edge_attr(each_edge)))
249
+ if "nt_idx" in hg.edge_attr(edge_id):
250
+ nt_edge_dict[hg.edge_attr(edge_id)["nt_idx"]] = edge_id
251
+ nt_edge_list = [nt_edge_dict[key] for key in range(len(nt_edge_dict))]
252
+ return hg, nt_edge_list
253
+
254
+ def revert(self, hg: Hypergraph, return_subhg=False):
255
+ ''' revert applying this production rule.
256
+ i.e., if there exists a subhypergraph that matches the r.h.s. of this production rule,
257
+ this method replaces the subhypergraph with a non-terminal hyperedge.
258
+
259
+ Parameters
260
+ ----------
261
+ hg : Hypergraph
262
+ hypergraph to be reverted
263
+ return_subhg : bool
264
+ if True, the removed subhypergraph will be returned.
265
+
266
+ Returns
267
+ -------
268
+ hg : Hypergraph
269
+ the resultant hypergraph. if it cannot be reverted, the original one is returned without any replacement.
270
+ success : bool
271
+ this indicates whether reverting is successed or not.
272
+ '''
273
+ gm = GraphMatcher(hg.hg, self.rhs.hg, node_match=_node_match_prod_rule,
274
+ edge_match=_edge_match)
275
+ try:
276
+ # in case when the matched subhg is connected to the other part via external nodes and more.
277
+ not_iso = True
278
+ while not_iso:
279
+ isomap = next(gm.subgraph_isomorphisms_iter())
280
+ adj_node_set = set([]) # reachable nodes from the internal nodes
281
+ subhg_node_set = set(isomap.keys()) # nodes in subhg
282
+ for each_node in subhg_node_set:
283
+ adj_node_set.add(each_node)
284
+ if isomap[each_node] not in self.ext_node.values():
285
+ adj_node_set.update(hg.hg.adj[each_node])
286
+ if adj_node_set == subhg_node_set:
287
+ not_iso = False
288
+ else:
289
+ if return_subhg:
290
+ return hg, False, Hypergraph()
291
+ else:
292
+ return hg, False
293
+ inv_isomap = {v: k for k, v in isomap.items()}
294
+ '''
295
+ isomap = {'e35': 'e8', 'bond_13': 'bond_18', 'bond_14': 'bond_19',
296
+ 'bond_15': 'bond_17', 'e29': 'e23', 'bond_12': 'bond_20'}
297
+ where keys come from `hg` and values come from `self.rhs`
298
+ '''
299
+ except StopIteration:
300
+ if return_subhg:
301
+ return hg, False, Hypergraph()
302
+ else:
303
+ return hg, False
304
+
305
+ if return_subhg:
306
+ subhg = Hypergraph()
307
+ for each_node in hg.nodes:
308
+ if each_node in isomap:
309
+ subhg.add_node(each_node, attr_dict=hg.node_attr(each_node))
310
+ for each_edge in hg.edges:
311
+ if each_edge in isomap:
312
+ subhg.add_edge(hg.nodes_in_edge(each_edge),
313
+ attr_dict=hg.edge_attr(each_edge),
314
+ edge_name=each_edge)
315
+ subhg.edge_idx = hg.edge_idx
316
+
317
+ # remove subhg except for the externael nodes
318
+ for each_key, each_val in isomap.items():
319
+ if each_key.startswith('e'):
320
+ hg.remove_edge(each_key)
321
+ for each_key, each_val in isomap.items():
322
+ if each_key.startswith('bond_'):
323
+ if each_val not in self.ext_node.values():
324
+ hg.remove_node(each_key)
325
+
326
+ # add non-terminal hyperedge
327
+ nt_node_list = []
328
+ for each_ext_id in self.ext_node.keys():
329
+ nt_node_list.append(inv_isomap[self.ext_node[each_ext_id]])
330
+
331
+ hg.add_edge(nt_node_list,
332
+ attr_dict=dict(
333
+ terminal=False,
334
+ symbol=self.lhs_nt_symbol))
335
+ if return_subhg:
336
+ return hg, True, subhg
337
+ else:
338
+ return hg, True
339
+
340
+
341
+ class ProductionRuleCorpus(object):
342
+
343
+ '''
344
+ A corpus of production rules.
345
+ This class maintains
346
+ (i) list of unique production rules,
347
+ (ii) list of unique edge symbols (both terminal and non-terminal), and
348
+ (iii) list of unique node symbols.
349
+
350
+ Attributes
351
+ ----------
352
+ prod_rule_list : list
353
+ list of unique production rules
354
+ edge_symbol_list : list
355
+ list of unique symbols (including both terminal and non-terminal)
356
+ node_symbol_list : list
357
+ list of node symbols
358
+ nt_symbol_list : list
359
+ list of unique lhs symbols
360
+ ext_id_list : list
361
+ list of ext_ids
362
+ lhs_in_prod_rule : array
363
+ a matrix of lhs vs prod_rule (= lhs_in_prod_rule)
364
+ '''
365
+
366
+ def __init__(self):
367
+ self.prod_rule_list = []
368
+ self.edge_symbol_list = []
369
+ self.edge_symbol_dict = {}
370
+ self.node_symbol_list = []
371
+ self.node_symbol_dict = {}
372
+ self.nt_symbol_list = []
373
+ self.ext_id_list = []
374
+ self._lhs_in_prod_rule = None
375
+ self.lhs_in_prod_rule_row_list = []
376
+ self.lhs_in_prod_rule_col_list = []
377
+
378
+ @property
379
+ def lhs_in_prod_rule(self):
380
+ if self._lhs_in_prod_rule is None:
381
+ self._lhs_in_prod_rule = torch.sparse.FloatTensor(
382
+ torch.LongTensor(list(zip(self.lhs_in_prod_rule_row_list, self.lhs_in_prod_rule_col_list))).t(),
383
+ torch.FloatTensor([1.0]*len(self.lhs_in_prod_rule_col_list)),
384
+ torch.Size([len(self.nt_symbol_list), len(self.prod_rule_list)])
385
+ ).to_dense()
386
+ return self._lhs_in_prod_rule
387
+
388
+ @property
389
+ def num_prod_rule(self):
390
+ ''' return the number of production rules
391
+
392
+ Returns
393
+ -------
394
+ int : the number of unique production rules
395
+ '''
396
+ return len(self.prod_rule_list)
397
+
398
+ @property
399
+ def start_rule_list(self):
400
+ ''' return a list of start rules
401
+
402
+ Returns
403
+ -------
404
+ list : list of start rules
405
+ '''
406
+ start_rule_list = []
407
+ for each_prod_rule in self.prod_rule_list:
408
+ if each_prod_rule.is_start_rule:
409
+ start_rule_list.append(each_prod_rule)
410
+ return start_rule_list
411
+
412
+ @property
413
+ def num_edge_symbol(self):
414
+ return len(self.edge_symbol_list)
415
+
416
+ @property
417
+ def num_node_symbol(self):
418
+ return len(self.node_symbol_list)
419
+
420
+ @property
421
+ def num_ext_id(self):
422
+ return len(self.ext_id_list)
423
+
424
+ def construct_feature_vectors(self):
425
+ ''' this method constructs feature vectors for the production rules collected so far.
426
+ currently, NTSymbol and TSymbol are treated in the same manner.
427
+ '''
428
+ feature_id_dict = {}
429
+ feature_id_dict['TSymbol'] = 0
430
+ feature_id_dict['NTSymbol'] = 1
431
+ feature_id_dict['BondSymbol'] = 2
432
+ for each_edge_symbol in self.edge_symbol_list:
433
+ for each_attr in each_edge_symbol.__dict__.keys():
434
+ each_val = each_edge_symbol.__dict__[each_attr]
435
+ if isinstance(each_val, list):
436
+ each_val = tuple(each_val)
437
+ if (each_attr, each_val) not in feature_id_dict:
438
+ feature_id_dict[(each_attr, each_val)] = len(feature_id_dict)
439
+
440
+ for each_node_symbol in self.node_symbol_list:
441
+ for each_attr in each_node_symbol.__dict__.keys():
442
+ each_val = each_node_symbol.__dict__[each_attr]
443
+ if isinstance(each_val, list):
444
+ each_val = tuple(each_val)
445
+ if (each_attr, each_val) not in feature_id_dict:
446
+ feature_id_dict[(each_attr, each_val)] = len(feature_id_dict)
447
+ for each_ext_id in self.ext_id_list:
448
+ feature_id_dict[('ext_id', each_ext_id)] = len(feature_id_dict)
449
+ dim = len(feature_id_dict)
450
+
451
+ feature_dict = {}
452
+ for each_edge_symbol in self.edge_symbol_list:
453
+ idx_list = []
454
+ idx_list.append(feature_id_dict[each_edge_symbol.__class__.__name__])
455
+ for each_attr in each_edge_symbol.__dict__.keys():
456
+ each_val = each_edge_symbol.__dict__[each_attr]
457
+ if isinstance(each_val, list):
458
+ each_val = tuple(each_val)
459
+ idx_list.append(feature_id_dict[(each_attr, each_val)])
460
+ feature = torch.sparse.LongTensor(
461
+ torch.LongTensor([idx_list]),
462
+ torch.ones(len(idx_list)),
463
+ torch.Size([len(feature_id_dict)])
464
+ )
465
+ feature_dict[each_edge_symbol] = feature
466
+
467
+ for each_node_symbol in self.node_symbol_list:
468
+ idx_list = []
469
+ idx_list.append(feature_id_dict[each_node_symbol.__class__.__name__])
470
+ for each_attr in each_node_symbol.__dict__.keys():
471
+ each_val = each_node_symbol.__dict__[each_attr]
472
+ if isinstance(each_val, list):
473
+ each_val = tuple(each_val)
474
+ idx_list.append(feature_id_dict[(each_attr, each_val)])
475
+ feature = torch.sparse.LongTensor(
476
+ torch.LongTensor([idx_list]),
477
+ torch.ones(len(idx_list)),
478
+ torch.Size([len(feature_id_dict)])
479
+ )
480
+ feature_dict[each_node_symbol] = feature
481
+ for each_ext_id in self.ext_id_list:
482
+ idx_list = [feature_id_dict[('ext_id', each_ext_id)]]
483
+ feature_dict[('ext_id', each_ext_id)] \
484
+ = torch.sparse.LongTensor(
485
+ torch.LongTensor([idx_list]),
486
+ torch.ones(len(idx_list)),
487
+ torch.Size([len(feature_id_dict)])
488
+ )
489
+ return feature_dict, dim
490
+
491
+ def edge_symbol_idx(self, symbol):
492
+ return self.edge_symbol_dict[symbol]
493
+
494
+ def node_symbol_idx(self, symbol):
495
+ return self.node_symbol_dict[symbol]
496
+
497
+ def append(self, prod_rule: ProductionRule) -> Tuple[int, ProductionRule]:
498
+ """ return whether the input production rule is new or not, and its production rule id.
499
+ Production rules are regarded as the same if
500
+ i) there exists a one-to-one mapping of nodes and edges, and
501
+ ii) all the attributes associated with nodes and hyperedges are the same.
502
+
503
+ Parameters
504
+ ----------
505
+ prod_rule : ProductionRule
506
+
507
+ Returns
508
+ -------
509
+ prod_rule_id : int
510
+ production rule index. if new, a new index will be assigned.
511
+ prod_rule : ProductionRule
512
+ """
513
+ num_lhs = len(self.nt_symbol_list)
514
+ for each_idx, each_prod_rule in enumerate(self.prod_rule_list):
515
+ is_same, isomap = prod_rule.is_same(each_prod_rule)
516
+ if is_same:
517
+ # we do not care about edge and node names, but care about the order of non-terminal edges.
518
+ for key, val in isomap.items(): # key : edges & nodes in each_prod_rule.rhs , val : those in prod_rule.rhs
519
+ if key.startswith("bond_"):
520
+ continue
521
+
522
+ # rewrite `nt_idx` in `prod_rule` for further processing
523
+ if "nt_idx" in prod_rule.rhs.edge_attr(val).keys():
524
+ if "nt_idx" not in each_prod_rule.rhs.edge_attr(key).keys():
525
+ raise ValueError
526
+ prod_rule.rhs.set_edge_attr(
527
+ val,
528
+ {'nt_idx': each_prod_rule.rhs.edge_attr(key)["nt_idx"]})
529
+ return each_idx, prod_rule
530
+ self.prod_rule_list.append(prod_rule)
531
+ self._update_edge_symbol_list(prod_rule)
532
+ self._update_node_symbol_list(prod_rule)
533
+ self._update_ext_id_list(prod_rule)
534
+
535
+ lhs_idx = self.nt_symbol_list.index(prod_rule.lhs_nt_symbol)
536
+ self.lhs_in_prod_rule_row_list.append(lhs_idx)
537
+ self.lhs_in_prod_rule_col_list.append(len(self.prod_rule_list)-1)
538
+ self._lhs_in_prod_rule = None
539
+ return len(self.prod_rule_list)-1, prod_rule
540
+
541
+ def get_prod_rule(self, prod_rule_idx: int) -> ProductionRule:
542
+ return self.prod_rule_list[prod_rule_idx]
543
+
544
+ def sample(self, unmasked_logit_array, nt_symbol, deterministic=False):
545
+ ''' sample a production rule whose lhs is `nt_symbol`, followihng `unmasked_logit_array`.
546
+
547
+ Parameters
548
+ ----------
549
+ unmasked_logit_array : array-like, length `num_prod_rule`
550
+ nt_symbol : NTSymbol
551
+ '''
552
+ if not isinstance(unmasked_logit_array, np.ndarray):
553
+ unmasked_logit_array = unmasked_logit_array.numpy().astype(np.float64)
554
+ if deterministic:
555
+ prob = masked_softmax(unmasked_logit_array,
556
+ self.lhs_in_prod_rule[self.nt_symbol_list.index(nt_symbol)].numpy().astype(np.float64))
557
+ return self.prod_rule_list[np.argmax(prob)]
558
+ else:
559
+ return np.random.choice(
560
+ self.prod_rule_list, 1,
561
+ p=masked_softmax(unmasked_logit_array,
562
+ self.lhs_in_prod_rule[self.nt_symbol_list.index(nt_symbol)].numpy().astype(np.float64)))[0]
563
+
564
+ def masked_logprob(self, unmasked_logit_array, nt_symbol):
565
+ if not isinstance(unmasked_logit_array, np.ndarray):
566
+ unmasked_logit_array = unmasked_logit_array.numpy().astype(np.float64)
567
+ prob = masked_softmax(unmasked_logit_array,
568
+ self.lhs_in_prod_rule[self.nt_symbol_list.index(nt_symbol)].numpy().astype(np.float64))
569
+ return np.log(prob)
570
+
571
+ def _update_edge_symbol_list(self, prod_rule: ProductionRule):
572
+ ''' update edge symbol list
573
+
574
+ Parameters
575
+ ----------
576
+ prod_rule : ProductionRule
577
+ '''
578
+ if prod_rule.lhs_nt_symbol not in self.nt_symbol_list:
579
+ self.nt_symbol_list.append(prod_rule.lhs_nt_symbol)
580
+
581
+ for each_edge in prod_rule.rhs.edges:
582
+ if prod_rule.rhs.edge_attr(each_edge)['symbol'] not in self.edge_symbol_dict:
583
+ edge_symbol_idx = len(self.edge_symbol_list)
584
+ self.edge_symbol_list.append(prod_rule.rhs.edge_attr(each_edge)['symbol'])
585
+ self.edge_symbol_dict[prod_rule.rhs.edge_attr(each_edge)['symbol']] = edge_symbol_idx
586
+ else:
587
+ edge_symbol_idx = self.edge_symbol_dict[prod_rule.rhs.edge_attr(each_edge)['symbol']]
588
+ prod_rule.rhs.edge_attr(each_edge)['symbol_idx'] = edge_symbol_idx
589
+ pass
590
+
591
+ def _update_node_symbol_list(self, prod_rule: ProductionRule):
592
+ ''' update node symbol list
593
+
594
+ Parameters
595
+ ----------
596
+ prod_rule : ProductionRule
597
+ '''
598
+ for each_node in prod_rule.rhs.nodes:
599
+ if prod_rule.rhs.node_attr(each_node)['symbol'] not in self.node_symbol_dict:
600
+ node_symbol_idx = len(self.node_symbol_list)
601
+ self.node_symbol_list.append(prod_rule.rhs.node_attr(each_node)['symbol'])
602
+ self.node_symbol_dict[prod_rule.rhs.node_attr(each_node)['symbol']] = node_symbol_idx
603
+ else:
604
+ node_symbol_idx = self.node_symbol_dict[prod_rule.rhs.node_attr(each_node)['symbol']]
605
+ prod_rule.rhs.node_attr(each_node)['symbol_idx'] = node_symbol_idx
606
+
607
+ def _update_ext_id_list(self, prod_rule: ProductionRule):
608
+ for each_node in prod_rule.rhs.nodes:
609
+ if 'ext_id' in prod_rule.rhs.node_attr(each_node):
610
+ if prod_rule.rhs.node_attr(each_node)['ext_id'] not in self.ext_id_list:
611
+ self.ext_id_list.append(prod_rule.rhs.node_attr(each_node)['ext_id'])
612
+
613
+
614
+ class HyperedgeReplacementGrammar(GraphGrammarBase):
615
+ """
616
+ Learn a hyperedge replacement grammar from a set of hypergraphs.
617
+
618
+ Attributes
619
+ ----------
620
+ prod_rule_list : list of ProductionRule
621
+ production rules learned from the input hypergraphs
622
+ """
623
+ def __init__(self,
624
+ tree_decomposition=molecular_tree_decomposition,
625
+ ignore_order=False, **kwargs):
626
+ from functools import partial
627
+ self.prod_rule_corpus = ProductionRuleCorpus()
628
+ self.clique_tree_corpus = CliqueTreeCorpus()
629
+ self.ignore_order = ignore_order
630
+ self.tree_decomposition = partial(tree_decomposition, **kwargs)
631
+
632
+ @property
633
+ def num_prod_rule(self):
634
+ ''' return the number of production rules
635
+
636
+ Returns
637
+ -------
638
+ int : the number of unique production rules
639
+ '''
640
+ return self.prod_rule_corpus.num_prod_rule
641
+
642
+ @property
643
+ def start_rule_list(self):
644
+ ''' return a list of start rules
645
+
646
+ Returns
647
+ -------
648
+ list : list of start rules
649
+ '''
650
+ return self.prod_rule_corpus.start_rule_list
651
+
652
+ @property
653
+ def prod_rule_list(self):
654
+ return self.prod_rule_corpus.prod_rule_list
655
+
656
+ def learn(self, hg_list, logger=print, max_mol=np.inf, print_freq=500):
657
+ """ learn from a list of hypergraphs
658
+
659
+ Parameters
660
+ ----------
661
+ hg_list : list of Hypergraph
662
+
663
+ Returns
664
+ -------
665
+ prod_rule_seq_list : list of integers
666
+ each element corresponds to a sequence of production rules to generate each hypergraph.
667
+ """
668
+ prod_rule_seq_list = []
669
+ idx = 0
670
+ for each_idx, each_hg in enumerate(hg_list):
671
+ clique_tree = self.tree_decomposition(each_hg)
672
+
673
+ # get a pair of myself and children
674
+ root_node = _find_root(clique_tree)
675
+ clique_tree = self.clique_tree_corpus.add_to_subhg_list(clique_tree, root_node)
676
+ prod_rule_seq = []
677
+ stack = []
678
+
679
+ children = sorted(list(clique_tree[root_node].keys()))
680
+
681
+ # extract a temporary production rule
682
+ prod_rule = extract_prod_rule(
683
+ None,
684
+ clique_tree.nodes[root_node]["subhg"],
685
+ [clique_tree.nodes[each_child]["subhg"]
686
+ for each_child in children],
687
+ clique_tree.nodes[root_node].get('subhg_idx', None))
688
+
689
+ # update the production rule list
690
+ prod_rule_id, prod_rule = self.update_prod_rule_list(prod_rule)
691
+ children = reorder_children(root_node,
692
+ children,
693
+ prod_rule,
694
+ clique_tree)
695
+ stack.extend([(root_node, each_child) for each_child in children[::-1]])
696
+ prod_rule_seq.append(prod_rule_id)
697
+
698
+ while len(stack) != 0:
699
+ # get a triple of parent, myself, and children
700
+ parent, myself = stack.pop()
701
+ children = sorted(list(dict(clique_tree[myself]).keys()))
702
+ children.remove(parent)
703
+
704
+ # extract a temp prod rule
705
+ prod_rule = extract_prod_rule(
706
+ clique_tree.nodes[parent]["subhg"],
707
+ clique_tree.nodes[myself]["subhg"],
708
+ [clique_tree.nodes[each_child]["subhg"]
709
+ for each_child in children],
710
+ clique_tree.nodes[myself].get('subhg_idx', None))
711
+
712
+ # update the prod rule list
713
+ prod_rule_id, prod_rule = self.update_prod_rule_list(prod_rule)
714
+ children = reorder_children(myself,
715
+ children,
716
+ prod_rule,
717
+ clique_tree)
718
+ stack.extend([(myself, each_child)
719
+ for each_child in children[::-1]])
720
+ prod_rule_seq.append(prod_rule_id)
721
+ prod_rule_seq_list.append(prod_rule_seq)
722
+ if (each_idx+1) % print_freq == 0:
723
+ msg = f'#(molecules processed)={each_idx+1}\t'\
724
+ f'#(production rules)={self.prod_rule_corpus.num_prod_rule}\t#(subhg in corpus)={self.clique_tree_corpus.size}'
725
+ logger(msg)
726
+ if each_idx > max_mol:
727
+ break
728
+
729
+ print(f'corpus_size = {self.clique_tree_corpus.size}')
730
+ return prod_rule_seq_list
731
+
732
+ def sample(self, z, deterministic=False):
733
+ """ sample a new hypergraph from HRG.
734
+
735
+ Parameters
736
+ ----------
737
+ z : array-like, shape (len, num_prod_rule)
738
+ logit
739
+ deterministic : bool
740
+ if True, deterministic sampling
741
+
742
+ Returns
743
+ -------
744
+ Hypergraph
745
+ """
746
+ seq_idx = 0
747
+ stack = []
748
+ z = z[:, :-1]
749
+ init_prod_rule = self.prod_rule_corpus.sample(z[0], NTSymbol(degree=0,
750
+ is_aromatic=False,
751
+ bond_symbol_list=[]),
752
+ deterministic=deterministic)
753
+ hg, nt_edge_list = init_prod_rule.applied_to(None, None)
754
+ stack = deepcopy(nt_edge_list[::-1])
755
+ while len(stack) != 0 and seq_idx < z.shape[0]-1:
756
+ seq_idx += 1
757
+ nt_edge = stack.pop()
758
+ nt_symbol = hg.edge_attr(nt_edge)['symbol']
759
+ prod_rule = self.prod_rule_corpus.sample(z[seq_idx], nt_symbol, deterministic=deterministic)
760
+ hg, nt_edge_list = prod_rule.applied_to(hg, nt_edge)
761
+ stack.extend(nt_edge_list[::-1])
762
+ if len(stack) != 0:
763
+ raise RuntimeError(f'{len(stack)} non-terminals are left.')
764
+ return hg
765
+
766
+ def construct(self, prod_rule_seq):
767
+ """ construct a hypergraph following `prod_rule_seq`
768
+
769
+ Parameters
770
+ ----------
771
+ prod_rule_seq : list of integers
772
+ a sequence of production rules.
773
+
774
+ Returns
775
+ -------
776
+ UndirectedHypergraph
777
+ """
778
+ seq_idx = 0
779
+ init_prod_rule = self.prod_rule_corpus.get_prod_rule(prod_rule_seq[seq_idx])
780
+ hg, nt_edge_list = init_prod_rule.applied_to(None, None)
781
+ stack = deepcopy(nt_edge_list[::-1])
782
+ while len(stack) != 0:
783
+ seq_idx += 1
784
+ nt_edge = stack.pop()
785
+ hg, nt_edge_list = self.prod_rule_corpus.get_prod_rule(prod_rule_seq[seq_idx]).applied_to(hg, nt_edge)
786
+ stack.extend(nt_edge_list[::-1])
787
+ return hg
788
+
789
+ def update_prod_rule_list(self, prod_rule):
790
+ """ return whether the input production rule is new or not, and its production rule id.
791
+ Production rules are regarded as the same if
792
+ i) there exists a one-to-one mapping of nodes and edges, and
793
+ ii) all the attributes associated with nodes and hyperedges are the same.
794
+
795
+ Parameters
796
+ ----------
797
+ prod_rule : ProductionRule
798
+
799
+ Returns
800
+ -------
801
+ is_new : bool
802
+ if True, this production rule is new
803
+ prod_rule_id : int
804
+ production rule index. if new, a new index will be assigned.
805
+ """
806
+ return self.prod_rule_corpus.append(prod_rule)
807
+
808
+
809
+ class IncrementalHyperedgeReplacementGrammar(HyperedgeReplacementGrammar):
810
+ '''
811
+ This class learns HRG incrementally leveraging the previously obtained production rules.
812
+ '''
813
+ def __init__(self, tree_decomposition=tree_decomposition_with_hrg, ignore_order=False):
814
+ self.prod_rule_list = []
815
+ self.tree_decomposition = tree_decomposition
816
+ self.ignore_order = ignore_order
817
+
818
+ def learn(self, hg_list):
819
+ """ learn from a list of hypergraphs
820
+
821
+ Parameters
822
+ ----------
823
+ hg_list : list of UndirectedHypergraph
824
+
825
+ Returns
826
+ -------
827
+ prod_rule_seq_list : list of integers
828
+ each element corresponds to a sequence of production rules to generate each hypergraph.
829
+ """
830
+ prod_rule_seq_list = []
831
+ for each_hg in hg_list:
832
+ clique_tree, root_node = tree_decomposition_with_hrg(each_hg, self, return_root=True)
833
+
834
+ prod_rule_seq = []
835
+ stack = []
836
+
837
+ # get a pair of myself and children
838
+ children = sorted(list(clique_tree[root_node].keys()))
839
+
840
+ # extract a temporary production rule
841
+ prod_rule = extract_prod_rule(None, clique_tree.nodes[root_node]["subhg"],
842
+ [clique_tree.nodes[each_child]["subhg"] for each_child in children])
843
+
844
+ # update the production rule list
845
+ prod_rule_id, prod_rule = self.update_prod_rule_list(prod_rule)
846
+ children = reorder_children(root_node, children, prod_rule, clique_tree)
847
+ stack.extend([(root_node, each_child) for each_child in children[::-1]])
848
+ prod_rule_seq.append(prod_rule_id)
849
+
850
+ while len(stack) != 0:
851
+ # get a triple of parent, myself, and children
852
+ parent, myself = stack.pop()
853
+ children = sorted(list(dict(clique_tree[myself]).keys()))
854
+ children.remove(parent)
855
+
856
+ # extract a temp prod rule
857
+ prod_rule = extract_prod_rule(
858
+ clique_tree.nodes[parent]["subhg"], clique_tree.nodes[myself]["subhg"],
859
+ [clique_tree.nodes[each_child]["subhg"] for each_child in children])
860
+
861
+ # update the prod rule list
862
+ prod_rule_id, prod_rule = self.update_prod_rule_list(prod_rule)
863
+ children = reorder_children(myself, children, prod_rule, clique_tree)
864
+ stack.extend([(myself, each_child) for each_child in children[::-1]])
865
+ prod_rule_seq.append(prod_rule_id)
866
+ prod_rule_seq_list.append(prod_rule_seq)
867
+ self._compute_stats()
868
+ return prod_rule_seq_list
869
+
870
+
871
+ def reorder_children(myself, children, prod_rule, clique_tree):
872
+ """ reorder children so that they match the order in `prod_rule`.
873
+
874
+ Parameters
875
+ ----------
876
+ myself : int
877
+ children : list of int
878
+ prod_rule : ProductionRule
879
+ clique_tree : nx.Graph
880
+
881
+ Returns
882
+ -------
883
+ new_children : list of str
884
+ reordered children
885
+ """
886
+ perm = {} # key : `nt_idx`, val : child
887
+ for each_edge in prod_rule.rhs.edges:
888
+ if "nt_idx" in prod_rule.rhs.edge_attr(each_edge).keys():
889
+ for each_child in children:
890
+ common_node_set = set(
891
+ common_node_list(clique_tree.nodes[myself]["subhg"],
892
+ clique_tree.nodes[each_child]["subhg"])[0])
893
+ if set(prod_rule.rhs.nodes_in_edge(each_edge)) == common_node_set:
894
+ assert prod_rule.rhs.edge_attr(each_edge)["nt_idx"] not in perm
895
+ perm[prod_rule.rhs.edge_attr(each_edge)["nt_idx"]] = each_child
896
+ new_children = []
897
+ assert len(perm) == len(children)
898
+ for i in range(len(perm)):
899
+ new_children.append(perm[i])
900
+ return new_children
901
+
902
+
903
+ def extract_prod_rule(parent_hg, myself_hg, children_hg_list, subhg_idx=None):
904
+ """ extract a production rule from a triple of `parent_hg`, `myself_hg`, and `children_hg_list`.
905
+
906
+ Parameters
907
+ ----------
908
+ parent_hg : Hypergraph
909
+ myself_hg : Hypergraph
910
+ children_hg_list : list of Hypergraph
911
+
912
+ Returns
913
+ -------
914
+ ProductionRule, consisting of
915
+ lhs : Hypergraph or None
916
+ rhs : Hypergraph
917
+ """
918
+ def _add_ext_node(hg, ext_nodes):
919
+ """ mark nodes to be external (ordered ids are assigned)
920
+
921
+ Parameters
922
+ ----------
923
+ hg : UndirectedHypergraph
924
+ ext_nodes : list of str
925
+ list of external nodes
926
+
927
+ Returns
928
+ -------
929
+ hg : Hypergraph
930
+ nodes in `ext_nodes` are marked to be external
931
+ """
932
+ ext_id = 0
933
+ ext_id_exists = []
934
+ for each_node in ext_nodes:
935
+ ext_id_exists.append('ext_id' in hg.node_attr(each_node))
936
+ if ext_id_exists and any(ext_id_exists) != all(ext_id_exists):
937
+ raise ValueError
938
+ if not all(ext_id_exists):
939
+ for each_node in ext_nodes:
940
+ hg.node_attr(each_node)['ext_id'] = ext_id
941
+ ext_id += 1
942
+ return hg
943
+
944
+ def _check_aromatic(hg, node_list):
945
+ is_aromatic = False
946
+ node_aromatic_list = []
947
+ for each_node in node_list:
948
+ if hg.node_attr(each_node)['symbol'].is_aromatic:
949
+ is_aromatic = True
950
+ node_aromatic_list.append(True)
951
+ else:
952
+ node_aromatic_list.append(False)
953
+ return is_aromatic, node_aromatic_list
954
+
955
+ def _check_ring(hg):
956
+ for each_edge in hg.edges:
957
+ if not ('tmp' in hg.edge_attr(each_edge) or (not hg.edge_attr(each_edge)['terminal'])):
958
+ return False
959
+ return True
960
+
961
+ if parent_hg is None:
962
+ lhs = Hypergraph()
963
+ node_list = []
964
+ else:
965
+ lhs = Hypergraph()
966
+ node_list, edge_exists = common_node_list(parent_hg, myself_hg)
967
+ for each_node in node_list:
968
+ lhs.add_node(each_node,
969
+ deepcopy(myself_hg.node_attr(each_node)))
970
+ is_aromatic, _ = _check_aromatic(parent_hg, node_list)
971
+ for_ring = _check_ring(myself_hg)
972
+ bond_symbol_list = []
973
+ for each_node in node_list:
974
+ bond_symbol_list.append(parent_hg.node_attr(each_node)['symbol'])
975
+ lhs.add_edge(
976
+ node_list,
977
+ attr_dict=dict(
978
+ terminal=False,
979
+ edge_exists=edge_exists,
980
+ symbol=NTSymbol(
981
+ degree=len(node_list),
982
+ is_aromatic=is_aromatic,
983
+ bond_symbol_list=bond_symbol_list,
984
+ for_ring=for_ring)))
985
+ try:
986
+ lhs = _add_ext_node(lhs, node_list)
987
+ except ValueError:
988
+ import pdb; pdb.set_trace()
989
+
990
+ rhs = remove_tmp_edge(deepcopy(myself_hg))
991
+ #rhs = remove_ext_node(rhs)
992
+ #rhs = remove_nt_edge(rhs)
993
+ try:
994
+ rhs = _add_ext_node(rhs, node_list)
995
+ except ValueError:
996
+ import pdb; pdb.set_trace()
997
+
998
+ nt_idx = 0
999
+ if children_hg_list is not None:
1000
+ for each_child_hg in children_hg_list:
1001
+ node_list, edge_exists = common_node_list(myself_hg, each_child_hg)
1002
+ is_aromatic, _ = _check_aromatic(myself_hg, node_list)
1003
+ for_ring = _check_ring(each_child_hg)
1004
+ bond_symbol_list = []
1005
+ for each_node in node_list:
1006
+ bond_symbol_list.append(myself_hg.node_attr(each_node)['symbol'])
1007
+ rhs.add_edge(
1008
+ node_list,
1009
+ attr_dict=dict(
1010
+ terminal=False,
1011
+ nt_idx=nt_idx,
1012
+ edge_exists=edge_exists,
1013
+ symbol=NTSymbol(degree=len(node_list),
1014
+ is_aromatic=is_aromatic,
1015
+ bond_symbol_list=bond_symbol_list,
1016
+ for_ring=for_ring)))
1017
+ nt_idx += 1
1018
+ prod_rule = ProductionRule(lhs, rhs)
1019
+ prod_rule.subhg_idx = subhg_idx
1020
+ if DEBUG:
1021
+ if sorted(list(prod_rule.ext_node.keys())) \
1022
+ != list(np.arange(len(prod_rule.ext_node))):
1023
+ raise RuntimeError('ext_id is not continuous')
1024
+ return prod_rule
1025
+
1026
+
1027
+ def _find_root(clique_tree):
1028
+ max_node = None
1029
+ num_nodes_max = -np.inf
1030
+ for each_node in clique_tree.nodes:
1031
+ if clique_tree.nodes[each_node]['subhg'].num_nodes > num_nodes_max:
1032
+ max_node = each_node
1033
+ num_nodes_max = clique_tree.nodes[each_node]['subhg'].num_nodes
1034
+ '''
1035
+ children = sorted(list(clique_tree[each_node].keys()))
1036
+ prod_rule = extract_prod_rule(None,
1037
+ clique_tree.nodes[each_node]["subhg"],
1038
+ [clique_tree.nodes[each_child]["subhg"]
1039
+ for each_child in children])
1040
+ for each_start_rule in start_rule_list:
1041
+ if prod_rule.is_same(each_start_rule):
1042
+ return each_node
1043
+ '''
1044
+ return max_node
1045
+
1046
+ def remove_ext_node(hg):
1047
+ for each_node in hg.nodes:
1048
+ hg.node_attr(each_node).pop('ext_id', None)
1049
+ return hg
1050
+
1051
+ def remove_nt_edge(hg):
1052
+ remove_edge_list = []
1053
+ for each_edge in hg.edges:
1054
+ if not hg.edge_attr(each_edge)['terminal']:
1055
+ remove_edge_list.append(each_edge)
1056
+ hg.remove_edges(remove_edge_list)
1057
+ return hg
1058
+
1059
+ def remove_tmp_edge(hg):
1060
+ remove_edge_list = []
1061
+ for each_edge in hg.edges:
1062
+ if hg.edge_attr(each_edge).get('tmp', False):
1063
+ remove_edge_list.append(each_edge)
1064
+ hg.remove_edges(remove_edge_list)
1065
+ return hg
graph_grammar/graph_grammar/symbols.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # Rhizome
4
+ # Version beta 0.0, August 2023
5
+ # Property of IBM Research, Accelerated Discovery
6
+ #
7
+
8
+ """
9
+ PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
10
+ OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
11
+ THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
12
+ """
13
+
14
+
15
+ """ Title """
16
+
17
+ __author__ = "Hiroshi Kajino <KAJINO@jp.ibm.com>"
18
+ __copyright__ = "(c) Copyright IBM Corp. 2018"
19
+ __version__ = "0.1"
20
+ __date__ = "Jan 1 2018"
21
+
22
+ from typing import List
23
+
24
+ class TSymbol(object):
25
+
26
+ ''' terminal symbol
27
+
28
+ Attributes
29
+ ----------
30
+ degree : int
31
+ the number of nodes in a hyperedge
32
+ is_aromatic : bool
33
+ whether or not the hyperedge is in an aromatic ring
34
+ symbol : str
35
+ atomic symbol
36
+ num_explicit_Hs : int
37
+ the number of hydrogens associated to this hyperedge
38
+ formal_charge : int
39
+ charge
40
+ chirality : int
41
+ chirality
42
+ '''
43
+
44
+ def __init__(self, degree, is_aromatic,
45
+ symbol, num_explicit_Hs, formal_charge, chirality):
46
+ self.degree = degree
47
+ self.is_aromatic = is_aromatic
48
+ self.symbol = symbol
49
+ self.num_explicit_Hs = num_explicit_Hs
50
+ self.formal_charge = formal_charge
51
+ self.chirality = chirality
52
+
53
+ @property
54
+ def terminal(self):
55
+ return True
56
+
57
+ def __eq__(self, other):
58
+ if not isinstance(other, TSymbol):
59
+ return False
60
+ if self.degree != other.degree:
61
+ return False
62
+ if self.is_aromatic != other.is_aromatic:
63
+ return False
64
+ if self.symbol != other.symbol:
65
+ return False
66
+ if self.num_explicit_Hs != other.num_explicit_Hs:
67
+ return False
68
+ if self.formal_charge != other.formal_charge:
69
+ return False
70
+ if self.chirality != other.chirality:
71
+ return False
72
+ return True
73
+
74
+ def __hash__(self):
75
+ return self.__str__().__hash__()
76
+
77
+ def __str__(self):
78
+ return f'degree={self.degree}, is_aromatic={self.is_aromatic}, '\
79
+ f'symbol={self.symbol}, '\
80
+ f'num_explicit_Hs={self.num_explicit_Hs}, '\
81
+ f'formal_charge={self.formal_charge}, chirality={self.chirality}'
82
+
83
+
84
+ class NTSymbol(object):
85
+
86
+ ''' non-terminal symbol
87
+
88
+ Attributes
89
+ ----------
90
+ degree : int
91
+ degree of the hyperedge
92
+ is_aromatic : bool
93
+ if True, at least one of the associated bonds must be aromatic.
94
+ node_aromatic_list : list of bool
95
+ indicate whether each of the nodes is aromatic or not.
96
+ bond_type_list : list of int
97
+ bond type of each node"
98
+ '''
99
+
100
+ def __init__(self, degree: int, is_aromatic: bool,
101
+ bond_symbol_list: list,
102
+ for_ring=False):
103
+ self.degree = degree
104
+ self.is_aromatic = is_aromatic
105
+ self.for_ring = for_ring
106
+ self.bond_symbol_list = bond_symbol_list
107
+
108
+ @property
109
+ def terminal(self) -> bool:
110
+ return False
111
+
112
+ @property
113
+ def symbol(self):
114
+ return f'NT{self.degree}'
115
+
116
+ def __eq__(self, other) -> bool:
117
+ if not isinstance(other, NTSymbol):
118
+ return False
119
+
120
+ if self.degree != other.degree:
121
+ return False
122
+ if self.is_aromatic != other.is_aromatic:
123
+ return False
124
+ if self.for_ring != other.for_ring:
125
+ return False
126
+ if len(self.bond_symbol_list) != len(other.bond_symbol_list):
127
+ return False
128
+ for each_idx in range(len(self.bond_symbol_list)):
129
+ if self.bond_symbol_list[each_idx] != other.bond_symbol_list[each_idx]:
130
+ return False
131
+ return True
132
+
133
+ def __hash__(self):
134
+ return self.__str__().__hash__()
135
+
136
+ def __str__(self) -> str:
137
+ return f'degree={self.degree}, is_aromatic={self.is_aromatic}, '\
138
+ f'bond_symbol_list={[str(each_symbol) for each_symbol in self.bond_symbol_list]}'\
139
+ f'for_ring={self.for_ring}'
140
+
141
+
142
+ class BondSymbol(object):
143
+
144
+
145
+ ''' Bond symbol
146
+
147
+ Attributes
148
+ ----------
149
+ is_aromatic : bool
150
+ if True, at least one of the associated bonds must be aromatic.
151
+ bond_type : int
152
+ bond type of each node"
153
+ '''
154
+
155
+ def __init__(self, is_aromatic: bool,
156
+ bond_type: int,
157
+ stereo: int):
158
+ self.is_aromatic = is_aromatic
159
+ self.bond_type = bond_type
160
+ self.stereo = stereo
161
+
162
+ def __eq__(self, other) -> bool:
163
+ if not isinstance(other, BondSymbol):
164
+ return False
165
+
166
+ if self.is_aromatic != other.is_aromatic:
167
+ return False
168
+ if self.bond_type != other.bond_type:
169
+ return False
170
+ if self.stereo != other.stereo:
171
+ return False
172
+ return True
173
+
174
+ def __hash__(self):
175
+ return self.__str__().__hash__()
176
+
177
+ def __str__(self) -> str:
178
+ return f'is_aromatic={self.is_aromatic}, '\
179
+ f'bond_type={self.bond_type}, '\
180
+ f'stereo={self.stereo}, '
graph_grammar/graph_grammar/utils.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # Rhizome
4
+ # Version beta 0.0, August 2023
5
+ # Property of IBM Research, Accelerated Discovery
6
+ #
7
+
8
+ """
9
+ PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
10
+ OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
11
+ THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
12
+ """
13
+
14
+ """ Title """
15
+
16
+ __author__ = "Hiroshi Kajino <KAJINO@jp.ibm.com>"
17
+ __copyright__ = "(c) Copyright IBM Corp. 2018"
18
+ __version__ = "0.1"
19
+ __date__ = "Jun 4 2018"
20
+
21
+ from ..hypergraph import Hypergraph
22
+ from copy import deepcopy
23
+ from typing import List
24
+ import numpy as np
25
+
26
+
27
+ def common_node_list(hg1: Hypergraph, hg2: Hypergraph) -> List[str]:
28
+ """ return a list of common nodes
29
+
30
+ Parameters
31
+ ----------
32
+ hg1, hg2 : Hypergraph
33
+
34
+ Returns
35
+ -------
36
+ list of str
37
+ list of common nodes
38
+ """
39
+ if hg1 is None or hg2 is None:
40
+ return [], False
41
+ else:
42
+ node_set = hg1.nodes.intersection(hg2.nodes)
43
+ node_dict = {}
44
+ if 'order4hrg' in hg1.node_attr(list(hg1.nodes)[0]):
45
+ for each_node in node_set:
46
+ node_dict[each_node] = hg1.node_attr(each_node)['order4hrg']
47
+ else:
48
+ for each_node in node_set:
49
+ node_dict[each_node] = hg1.node_attr(each_node)['symbol'].__hash__()
50
+ node_list = []
51
+ for each_key, _ in sorted(node_dict.items(), key=lambda x:x[1]):
52
+ node_list.append(each_key)
53
+ edge_name = hg1.has_edge(node_list, ignore_order=True)
54
+ if edge_name:
55
+ if not hg1.edge_attr(edge_name).get('terminal', True):
56
+ node_list = hg1.nodes_in_edge(edge_name)
57
+ return node_list, True
58
+ else:
59
+ return node_list, False
60
+
61
+
62
+ def _node_match(node1, node2):
63
+ # if the nodes are hyperedges, `atom_attr` determines the match
64
+ if node1['bipartite'] == 'edge' and node2['bipartite'] == 'edge':
65
+ return node1["attr_dict"]['symbol'] == node2["attr_dict"]['symbol']
66
+ elif node1['bipartite'] == 'node' and node2['bipartite'] == 'node':
67
+ # bond_symbol
68
+ return node1['attr_dict']['symbol'] == node2['attr_dict']['symbol']
69
+ else:
70
+ return False
71
+
72
+ def _easy_node_match(node1, node2):
73
+ # if the nodes are hyperedges, `atom_attr` determines the match
74
+ if node1['bipartite'] == 'edge' and node2['bipartite'] == 'edge':
75
+ return node1["attr_dict"].get('symbol', None) == node2["attr_dict"].get('symbol', None)
76
+ elif node1['bipartite'] == 'node' and node2['bipartite'] == 'node':
77
+ # bond_symbol
78
+ return node1['attr_dict'].get('ext_id', -1) == node2['attr_dict'].get('ext_id', -1)\
79
+ and node1['attr_dict']['symbol'] == node2['attr_dict']['symbol']
80
+ else:
81
+ return False
82
+
83
+
84
+ def _node_match_prod_rule(node1, node2, ignore_order=False):
85
+ # if the nodes are hyperedges, `atom_attr` determines the match
86
+ if node1['bipartite'] == 'edge' and node2['bipartite'] == 'edge':
87
+ return node1["attr_dict"]['symbol'] == node2["attr_dict"]['symbol']
88
+ elif node1['bipartite'] == 'node' and node2['bipartite'] == 'node':
89
+ # ext_id, order4hrg, bond_symbol
90
+ if ignore_order:
91
+ return node1['attr_dict']['symbol'] == node2['attr_dict']['symbol']
92
+ else:
93
+ return node1['attr_dict']['symbol'] == node2['attr_dict']['symbol']\
94
+ and node1['attr_dict'].get('ext_id', -1) == node2['attr_dict'].get('ext_id', -1)
95
+ else:
96
+ return False
97
+
98
+
99
+ def _edge_match(edge1, edge2, ignore_order=False):
100
+ #return True
101
+ if ignore_order:
102
+ return True
103
+ else:
104
+ return edge1["order"] == edge2["order"]
105
+
106
+ def masked_softmax(logit, mask):
107
+ ''' compute a probability distribution from logit
108
+
109
+ Parameters
110
+ ----------
111
+ logit : array-like, length D
112
+ each element indicates how each dimension is likely to be chosen
113
+ (the larger, the more likely)
114
+ mask : array-like, length D
115
+ each element is either 0 or 1.
116
+ if 0, the dimension is ignored
117
+ when computing the probability distribution.
118
+
119
+ Returns
120
+ -------
121
+ prob_dist : array, length D
122
+ probability distribution computed from logit.
123
+ if `mask[d] = 0`, `prob_dist[d] = 0`.
124
+ '''
125
+ if logit.shape != mask.shape:
126
+ raise ValueError('logit and mask must have the same shape')
127
+ c = np.max(logit)
128
+ exp_logit = np.exp(logit - c) * mask
129
+ sum_exp_logit = exp_logit @ mask
130
+ return exp_logit / sum_exp_logit
graph_grammar/hypergraph.py ADDED
@@ -0,0 +1,544 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # Rhizome
4
+ # Version beta 0.0, August 2023
5
+ # Property of IBM Research, Accelerated Discovery
6
+ #
7
+
8
+ """
9
+ PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
10
+ OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
11
+ THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
12
+ """
13
+
14
+ """ Title """
15
+
16
+ __author__ = "Hiroshi Kajino <KAJINO@jp.ibm.com>"
17
+ __copyright__ = "(c) Copyright IBM Corp. 2018"
18
+ __version__ = "0.1"
19
+ __date__ = "Jan 31 2018"
20
+
21
+ from copy import deepcopy
22
+ from typing import List, Dict, Tuple
23
+ import networkx as nx
24
+ import numpy as np
25
+ import os
26
+
27
+
28
+ class Hypergraph(object):
29
+ '''
30
+ A class of a hypergraph.
31
+ Each hyperedge can be ordered. For the ordered case,
32
+ edges adjacent to the hyperedge node are labeled by their orders.
33
+
34
+ Attributes
35
+ ----------
36
+ hg : nx.Graph
37
+ a bipartite graph representation of a hypergraph
38
+ edge_idx : int
39
+ total number of hyperedges that exist so far
40
+ '''
41
+ def __init__(self):
42
+ self.hg = nx.Graph()
43
+ self.edge_idx = 0
44
+ self.nodes = set([])
45
+ self.num_nodes = 0
46
+ self.edges = set([])
47
+ self.num_edges = 0
48
+ self.nodes_in_edge_dict = {}
49
+
50
+ def add_node(self, node: str, attr_dict=None):
51
+ ''' add a node to hypergraph
52
+
53
+ Parameters
54
+ ----------
55
+ node : str
56
+ node name
57
+ attr_dict : dict
58
+ dictionary of node attributes
59
+ '''
60
+ self.hg.add_node(node, bipartite='node', attr_dict=attr_dict)
61
+ if node not in self.nodes:
62
+ self.num_nodes += 1
63
+ self.nodes.add(node)
64
+
65
+ def add_edge(self, node_list: List[str], attr_dict=None, edge_name=None):
66
+ ''' add an edge consisting of nodes `node_list`
67
+
68
+ Parameters
69
+ ----------
70
+ node_list : list
71
+ ordered list of nodes that consist the edge
72
+ attr_dict : dict
73
+ dictionary of edge attributes
74
+ '''
75
+ if edge_name is None:
76
+ edge = 'e{}'.format(self.edge_idx)
77
+ else:
78
+ assert edge_name not in self.edges
79
+ edge = edge_name
80
+ self.hg.add_node(edge, bipartite='edge', attr_dict=attr_dict)
81
+ if edge not in self.edges:
82
+ self.num_edges += 1
83
+ self.edges.add(edge)
84
+ self.nodes_in_edge_dict[edge] = node_list
85
+ if type(node_list) == list:
86
+ for node_idx, each_node in enumerate(node_list):
87
+ self.hg.add_edge(edge, each_node, order=node_idx)
88
+ if each_node not in self.nodes:
89
+ self.num_nodes += 1
90
+ self.nodes.add(each_node)
91
+
92
+ elif type(node_list) == set:
93
+ for each_node in node_list:
94
+ self.hg.add_edge(edge, each_node, order=-1)
95
+ if each_node not in self.nodes:
96
+ self.num_nodes += 1
97
+ self.nodes.add(each_node)
98
+ else:
99
+ raise ValueError
100
+ self.edge_idx += 1
101
+ return edge
102
+
103
+ def remove_node(self, node: str, remove_connected_edges=True):
104
+ ''' remove a node
105
+
106
+ Parameters
107
+ ----------
108
+ node : str
109
+ node name
110
+ remove_connected_edges : bool
111
+ if True, remove edges that are adjacent to the node
112
+ '''
113
+ if remove_connected_edges:
114
+ connected_edges = deepcopy(self.adj_edges(node))
115
+ for each_edge in connected_edges:
116
+ self.remove_edge(each_edge)
117
+ self.hg.remove_node(node)
118
+ self.num_nodes -= 1
119
+ self.nodes.remove(node)
120
+
121
+ def remove_nodes(self, node_iter, remove_connected_edges=True):
122
+ ''' remove a set of nodes
123
+
124
+ Parameters
125
+ ----------
126
+ node_iter : iterator of strings
127
+ nodes to be removed
128
+ remove_connected_edges : bool
129
+ if True, remove edges that are adjacent to the node
130
+ '''
131
+ for each_node in node_iter:
132
+ self.remove_node(each_node, remove_connected_edges)
133
+
134
+ def remove_edge(self, edge: str):
135
+ ''' remove an edge
136
+
137
+ Parameters
138
+ ----------
139
+ edge : str
140
+ edge to be removed
141
+ '''
142
+ self.hg.remove_node(edge)
143
+ self.edges.remove(edge)
144
+ self.num_edges -= 1
145
+ self.nodes_in_edge_dict.pop(edge)
146
+
147
+ def remove_edges(self, edge_iter):
148
+ ''' remove a set of edges
149
+
150
+ Parameters
151
+ ----------
152
+ edge_iter : iterator of strings
153
+ edges to be removed
154
+ '''
155
+ for each_edge in edge_iter:
156
+ self.remove_edge(each_edge)
157
+
158
+ def remove_edges_with_attr(self, edge_attr_dict):
159
+ remove_edge_list = []
160
+ for each_edge in self.edges:
161
+ satisfy = True
162
+ for each_key, each_val in edge_attr_dict.items():
163
+ if not satisfy:
164
+ break
165
+ try:
166
+ if self.edge_attr(each_edge)[each_key] != each_val:
167
+ satisfy = False
168
+ except KeyError:
169
+ satisfy = False
170
+ if satisfy:
171
+ remove_edge_list.append(each_edge)
172
+ self.remove_edges(remove_edge_list)
173
+
174
+ def remove_subhg(self, subhg):
175
+ ''' remove subhypergraph.
176
+ all of the hyperedges are removed.
177
+ each node of subhg is removed if its degree becomes 0 after removing hyperedges.
178
+
179
+ Parameters
180
+ ----------
181
+ subhg : Hypergraph
182
+ '''
183
+ for each_edge in subhg.edges:
184
+ self.remove_edge(each_edge)
185
+ for each_node in subhg.nodes:
186
+ if self.degree(each_node) == 0:
187
+ self.remove_node(each_node)
188
+
189
+ def nodes_in_edge(self, edge):
190
+ ''' return an ordered list of nodes in a given edge.
191
+
192
+ Parameters
193
+ ----------
194
+ edge : str
195
+ edge whose nodes are returned
196
+
197
+ Returns
198
+ -------
199
+ list or set
200
+ ordered list or set of nodes that belong to the edge
201
+ '''
202
+ if edge.startswith('e'):
203
+ return self.nodes_in_edge_dict[edge]
204
+ else:
205
+ adj_node_list = self.hg.adj[edge]
206
+ adj_node_order_list = []
207
+ adj_node_name_list = []
208
+ for each_node in adj_node_list:
209
+ adj_node_order_list.append(adj_node_list[each_node]['order'])
210
+ adj_node_name_list.append(each_node)
211
+ if adj_node_order_list == [-1] * len(adj_node_order_list):
212
+ return set(adj_node_name_list)
213
+ else:
214
+ return [adj_node_name_list[each_idx] for each_idx
215
+ in np.argsort(adj_node_order_list)]
216
+
217
+ def adj_edges(self, node):
218
+ ''' return a dict of adjacent hyperedges
219
+
220
+ Parameters
221
+ ----------
222
+ node : str
223
+
224
+ Returns
225
+ -------
226
+ set
227
+ set of edges that are adjacent to `node`
228
+ '''
229
+ return self.hg.adj[node]
230
+
231
+ def adj_nodes(self, node):
232
+ ''' return a set of adjacent nodes
233
+
234
+ Parameters
235
+ ----------
236
+ node : str
237
+
238
+ Returns
239
+ -------
240
+ set
241
+ set of nodes that are adjacent to `node`
242
+ '''
243
+ node_set = set([])
244
+ for each_adj_edge in self.adj_edges(node):
245
+ node_set.update(set(self.nodes_in_edge(each_adj_edge)))
246
+ node_set.discard(node)
247
+ return node_set
248
+
249
+ def has_edge(self, node_list, ignore_order=False):
250
+ for each_edge in self.edges:
251
+ if ignore_order:
252
+ if set(self.nodes_in_edge(each_edge)) == set(node_list):
253
+ return each_edge
254
+ else:
255
+ if self.nodes_in_edge(each_edge) == node_list:
256
+ return each_edge
257
+ return False
258
+
259
+ def degree(self, node):
260
+ return len(self.hg.adj[node])
261
+
262
+ def degrees(self):
263
+ return {each_node: self.degree(each_node) for each_node in self.nodes}
264
+
265
+ def edge_degree(self, edge):
266
+ return len(self.nodes_in_edge(edge))
267
+
268
+ def edge_degrees(self):
269
+ return {each_edge: self.edge_degree(each_edge) for each_edge in self.edges}
270
+
271
+ def is_adj(self, node1, node2):
272
+ return node1 in self.adj_nodes(node2)
273
+
274
+ def adj_subhg(self, node, ident_node_dict=None):
275
+ """ return a subhypergraph consisting of a set of nodes and hyperedges adjacent to `node`.
276
+ if an adjacent node has a self-loop hyperedge, it will be also added to the subhypergraph.
277
+
278
+ Parameters
279
+ ----------
280
+ node : str
281
+ ident_node_dict : dict
282
+ dict containing identical nodes. see `get_identical_node_dict` for more details
283
+
284
+ Returns
285
+ -------
286
+ subhg : Hypergraph
287
+ """
288
+ if ident_node_dict is None:
289
+ ident_node_dict = self.get_identical_node_dict()
290
+ adj_node_set = set(ident_node_dict[node])
291
+ adj_edge_set = set([])
292
+ for each_node in ident_node_dict[node]:
293
+ adj_edge_set.update(set(self.adj_edges(each_node)))
294
+ fixed_adj_edge_set = deepcopy(adj_edge_set)
295
+ for each_edge in fixed_adj_edge_set:
296
+ other_nodes = self.nodes_in_edge(each_edge)
297
+ adj_node_set.update(other_nodes)
298
+
299
+ # if the adjacent node has self-loop edge, it will be appended to adj_edge_list.
300
+ for each_node in other_nodes:
301
+ for other_edge in set(self.adj_edges(each_node)) - set([each_edge]):
302
+ if len(set(self.nodes_in_edge(other_edge)) \
303
+ - set(self.nodes_in_edge(each_edge))) == 0:
304
+ adj_edge_set.update(set([other_edge]))
305
+ subhg = Hypergraph()
306
+ for each_node in adj_node_set:
307
+ subhg.add_node(each_node, attr_dict=self.node_attr(each_node))
308
+ for each_edge in adj_edge_set:
309
+ subhg.add_edge(self.nodes_in_edge(each_edge),
310
+ attr_dict=self.edge_attr(each_edge),
311
+ edge_name=each_edge)
312
+ subhg.edge_idx = self.edge_idx
313
+ return subhg
314
+
315
+ def get_subhg(self, node_list, edge_list, ident_node_dict=None):
316
+ """ return a subhypergraph consisting of a set of nodes and hyperedges adjacent to `node`.
317
+ if an adjacent node has a self-loop hyperedge, it will be also added to the subhypergraph.
318
+
319
+ Parameters
320
+ ----------
321
+ node : str
322
+ ident_node_dict : dict
323
+ dict containing identical nodes. see `get_identical_node_dict` for more details
324
+
325
+ Returns
326
+ -------
327
+ subhg : Hypergraph
328
+ """
329
+ if ident_node_dict is None:
330
+ ident_node_dict = self.get_identical_node_dict()
331
+ adj_node_set = set([])
332
+ for each_node in node_list:
333
+ adj_node_set.update(set(ident_node_dict[each_node]))
334
+ adj_edge_set = set(edge_list)
335
+
336
+ subhg = Hypergraph()
337
+ for each_node in adj_node_set:
338
+ subhg.add_node(each_node,
339
+ attr_dict=deepcopy(self.node_attr(each_node)))
340
+ for each_edge in adj_edge_set:
341
+ subhg.add_edge(self.nodes_in_edge(each_edge),
342
+ attr_dict=deepcopy(self.edge_attr(each_edge)),
343
+ edge_name=each_edge)
344
+ subhg.edge_idx = self.edge_idx
345
+ return subhg
346
+
347
+ def copy(self):
348
+ ''' return a copy of the object
349
+
350
+ Returns
351
+ -------
352
+ Hypergraph
353
+ '''
354
+ return deepcopy(self)
355
+
356
+ def node_attr(self, node):
357
+ return self.hg.nodes[node]['attr_dict']
358
+
359
+ def edge_attr(self, edge):
360
+ return self.hg.nodes[edge]['attr_dict']
361
+
362
+ def set_node_attr(self, node, attr_dict):
363
+ for each_key, each_val in attr_dict.items():
364
+ self.hg.nodes[node]['attr_dict'][each_key] = each_val
365
+
366
+ def set_edge_attr(self, edge, attr_dict):
367
+ for each_key, each_val in attr_dict.items():
368
+ self.hg.nodes[edge]['attr_dict'][each_key] = each_val
369
+
370
+ def get_identical_node_dict(self):
371
+ ''' get identical nodes
372
+ nodes are identical if they share the same set of adjacent edges.
373
+
374
+ Returns
375
+ -------
376
+ ident_node_dict : dict
377
+ ident_node_dict[node] returns a list of nodes that are identical to `node`.
378
+ '''
379
+ ident_node_dict = {}
380
+ for each_node in self.nodes:
381
+ ident_node_list = []
382
+ for each_other_node in self.nodes:
383
+ if each_other_node == each_node:
384
+ ident_node_list.append(each_other_node)
385
+ elif self.adj_edges(each_node) == self.adj_edges(each_other_node) \
386
+ and len(self.adj_edges(each_node)) != 0:
387
+ ident_node_list.append(each_other_node)
388
+ ident_node_dict[each_node] = ident_node_list
389
+ return ident_node_dict
390
+ '''
391
+ ident_node_dict = {}
392
+ for each_node in self.nodes:
393
+ ident_node_dict[each_node] = [each_node]
394
+ return ident_node_dict
395
+ '''
396
+
397
+ def get_leaf_edge(self):
398
+ ''' get an edge that is incident only to one edge
399
+
400
+ Returns
401
+ -------
402
+ if exists, return a leaf edge. otherwise, return None.
403
+ '''
404
+ for each_edge in self.edges:
405
+ if len(self.adj_nodes(each_edge)) == 1:
406
+ if 'tmp' not in self.edge_attr(each_edge):
407
+ return each_edge
408
+ return None
409
+
410
+ def get_nontmp_edge(self):
411
+ for each_edge in self.edges:
412
+ if 'tmp' not in self.edge_attr(each_edge):
413
+ return each_edge
414
+ return None
415
+
416
+ def is_subhg(self, hg):
417
+ ''' return whether this hypergraph is a subhypergraph of `hg`
418
+
419
+ Returns
420
+ -------
421
+ True if self \in hg,
422
+ False otherwise.
423
+ '''
424
+ for each_node in self.nodes:
425
+ if each_node not in hg.nodes:
426
+ return False
427
+ for each_edge in self.edges:
428
+ if each_edge not in hg.edges:
429
+ return False
430
+ return True
431
+
432
+ def in_cycle(self, node, visited=None, parent='', root_node='') -> bool:
433
+ ''' if `node` is in a cycle, then return True. otherwise, False.
434
+
435
+ Parameters
436
+ ----------
437
+ node : str
438
+ node in a hypergraph
439
+ visited : list
440
+ list of visited nodes, used for recursion
441
+ parent : str
442
+ parent node, used to eliminate a cycle consisting of two nodes and one edge.
443
+
444
+ Returns
445
+ -------
446
+ bool
447
+ '''
448
+ if visited is None:
449
+ visited = []
450
+ if parent == '':
451
+ visited = []
452
+ if root_node == '':
453
+ root_node = node
454
+ visited.append(node)
455
+ for each_adj_node in self.adj_nodes(node):
456
+ if each_adj_node not in visited:
457
+ if self.in_cycle(each_adj_node, visited, node, root_node):
458
+ return True
459
+ elif each_adj_node != parent and each_adj_node == root_node:
460
+ return True
461
+ return False
462
+
463
+
464
+ def draw(self, file_path=None, with_node=False, with_edge_name=False):
465
+ ''' draw hypergraph
466
+ '''
467
+ import graphviz
468
+ G = graphviz.Graph(format='png')
469
+ for each_node in self.nodes:
470
+ if 'ext_id' in self.node_attr(each_node):
471
+ G.node(each_node, label='',
472
+ shape='circle', width='0.1', height='0.1', style='filled',
473
+ fillcolor='black')
474
+ else:
475
+ if with_node:
476
+ G.node(each_node, label='',
477
+ shape='circle', width='0.1', height='0.1', style='filled',
478
+ fillcolor='gray')
479
+ edge_list = []
480
+ for each_edge in self.edges:
481
+ if self.edge_attr(each_edge).get('terminal', False):
482
+ G.node(each_edge,
483
+ label=self.edge_attr(each_edge)['symbol'].symbol if not with_edge_name \
484
+ else self.edge_attr(each_edge)['symbol'].symbol + ', ' + each_edge,
485
+ fontcolor='black', shape='square')
486
+ elif self.edge_attr(each_edge).get('tmp', False):
487
+ G.node(each_edge, label='tmp' if not with_edge_name else 'tmp, ' + each_edge,
488
+ fontcolor='black', shape='square')
489
+ else:
490
+ G.node(each_edge,
491
+ label=self.edge_attr(each_edge)['symbol'].symbol if not with_edge_name \
492
+ else self.edge_attr(each_edge)['symbol'].symbol + ', ' + each_edge,
493
+ fontcolor='black', shape='square', style='filled')
494
+ if with_node:
495
+ for each_node in self.nodes_in_edge(each_edge):
496
+ G.edge(each_edge, each_node)
497
+ else:
498
+ for each_node in self.nodes_in_edge(each_edge):
499
+ if 'ext_id' in self.node_attr(each_node)\
500
+ and set([each_node, each_edge]) not in edge_list:
501
+ G.edge(each_edge, each_node)
502
+ edge_list.append(set([each_node, each_edge]))
503
+ for each_other_edge in self.adj_nodes(each_edge):
504
+ if set([each_edge, each_other_edge]) not in edge_list:
505
+ num_bond = 0
506
+ common_node_set = set(self.nodes_in_edge(each_edge))\
507
+ .intersection(set(self.nodes_in_edge(each_other_edge)))
508
+ for each_node in common_node_set:
509
+ if self.node_attr(each_node)['symbol'].bond_type in [1, 2, 3]:
510
+ num_bond += self.node_attr(each_node)['symbol'].bond_type
511
+ elif self.node_attr(each_node)['symbol'].bond_type in [12]:
512
+ num_bond += 1
513
+ else:
514
+ raise NotImplementedError('unsupported bond type')
515
+ for _ in range(num_bond):
516
+ G.edge(each_edge, each_other_edge)
517
+ edge_list.append(set([each_edge, each_other_edge]))
518
+ if file_path is not None:
519
+ G.render(file_path, cleanup=True)
520
+ #os.remove(file_path)
521
+ return G
522
+
523
+ def is_dividable(self, node):
524
+ _hg = deepcopy(self.hg)
525
+ _hg.remove_node(node)
526
+ return (not nx.is_connected(_hg))
527
+
528
+ def divide(self, node):
529
+ subhg_list = []
530
+
531
+ hg_wo_node = deepcopy(self)
532
+ hg_wo_node.remove_node(node, remove_connected_edges=False)
533
+ connected_components = nx.connected_components(hg_wo_node.hg)
534
+ for each_component in connected_components:
535
+ node_list = [node]
536
+ edge_list = []
537
+ node_list.extend([each_node for each_node in each_component
538
+ if each_node.startswith('bond_')])
539
+ edge_list.extend([each_edge for each_edge in each_component
540
+ if each_edge.startswith('e')])
541
+ subhg_list.append(self.get_subhg(node_list, edge_list))
542
+ #subhg_list[-1].set_node_attr(node, {'divided': True})
543
+ return subhg_list
544
+
graph_grammar/io/__init__.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # Rhizome
4
+ # Version beta 0.0, August 2023
5
+ # Property of IBM Research, Accelerated Discovery
6
+ #
7
+
8
+ """
9
+ PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
10
+ OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
11
+ THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
12
+ """
13
+
14
+ """ Title """
15
+
16
+ __author__ = "Hiroshi Kajino <KAJINO@jp.ibm.com>"
17
+ __copyright__ = "(c) Copyright IBM Corp. 2018"
18
+ __version__ = "0.1"
19
+ __date__ = "Jan 1 2018"
20
+
graph_grammar/io/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (669 Bytes). View file
 
graph_grammar/io/__pycache__/smi.cpython-310.pyc ADDED
Binary file (12.9 kB). View file
 
graph_grammar/io/smi.py ADDED
@@ -0,0 +1,559 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # Rhizome
4
+ # Version beta 0.0, August 2023
5
+ # Property of IBM Research, Accelerated Discovery
6
+ #
7
+
8
+ """
9
+ PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
10
+ OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
11
+ THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
12
+ """
13
+
14
+ """ Title """
15
+
16
+ __author__ = "Hiroshi Kajino <KAJINO@jp.ibm.com>"
17
+ __copyright__ = "(c) Copyright IBM Corp. 2018"
18
+ __version__ = "0.1"
19
+ __date__ = "Jan 12 2018"
20
+
21
+ from copy import deepcopy
22
+ from rdkit import Chem
23
+ from rdkit import RDLogger
24
+ import networkx as nx
25
+ import numpy as np
26
+ from ..hypergraph import Hypergraph
27
+ from ..graph_grammar.symbols import TSymbol, BondSymbol
28
+
29
+ # supress warnings
30
+ lg = RDLogger.logger()
31
+ lg.setLevel(RDLogger.CRITICAL)
32
+
33
+
34
+ class HGGen(object):
35
+ """
36
+ load .smi file and yield a hypergraph.
37
+
38
+ Attributes
39
+ ----------
40
+ path_to_file : str
41
+ path to .smi file
42
+ kekulize : bool
43
+ kekulize or not
44
+ add_Hs : bool
45
+ add implicit hydrogens to the molecule or not.
46
+ all_single : bool
47
+ if True, all multiple bonds are summarized into a single bond with some attributes
48
+
49
+ Yields
50
+ ------
51
+ Hypergraph
52
+ """
53
+ def __init__(self, path_to_file, kekulize=True, add_Hs=False, all_single=True):
54
+ self.num_line = 1
55
+ self.mol_gen = Chem.SmilesMolSupplier(path_to_file, titleLine=False)
56
+ self.kekulize = kekulize
57
+ self.add_Hs = add_Hs
58
+ self.all_single = all_single
59
+
60
+ def __iter__(self):
61
+ return self
62
+
63
+ def __next__(self):
64
+ '''
65
+ each_mol = None
66
+ while each_mol is None:
67
+ each_mol = next(self.mol_gen)
68
+ '''
69
+ # not ignoring parse errors
70
+ each_mol = next(self.mol_gen)
71
+ if each_mol is None:
72
+ raise ValueError(f'incorrect smiles in line {self.num_line}')
73
+ else:
74
+ self.num_line += 1
75
+ return mol_to_hg(each_mol, self.kekulize, self.add_Hs)
76
+
77
+
78
+ def mol_to_bipartite(mol, kekulize):
79
+ """
80
+ get a bipartite representation of a molecule.
81
+
82
+ Parameters
83
+ ----------
84
+ mol : rdkit.Chem.rdchem.Mol
85
+ molecule object
86
+
87
+ Returns
88
+ -------
89
+ nx.Graph
90
+ a bipartite graph representing which bond is connected to which atoms.
91
+ """
92
+ try:
93
+ mol = standardize_stereo(mol)
94
+ except KeyError:
95
+ print(Chem.MolToSmiles(mol))
96
+ raise KeyError
97
+
98
+ if kekulize:
99
+ Chem.Kekulize(mol)
100
+
101
+ bipartite_g = nx.Graph()
102
+ for each_atom in mol.GetAtoms():
103
+ bipartite_g.add_node(f"atom_{each_atom.GetIdx()}",
104
+ atom_attr=atom_attr(each_atom, kekulize))
105
+
106
+ for each_bond in mol.GetBonds():
107
+ bond_idx = each_bond.GetIdx()
108
+ bipartite_g.add_node(
109
+ f"bond_{bond_idx}",
110
+ bond_attr=bond_attr(each_bond, kekulize))
111
+ bipartite_g.add_edge(
112
+ f"atom_{each_bond.GetBeginAtomIdx()}",
113
+ f"bond_{bond_idx}")
114
+ bipartite_g.add_edge(
115
+ f"atom_{each_bond.GetEndAtomIdx()}",
116
+ f"bond_{bond_idx}")
117
+ return bipartite_g
118
+
119
+
120
+ def mol_to_hg(mol, kekulize, add_Hs):
121
+ """
122
+ get a bipartite representation of a molecule.
123
+
124
+ Parameters
125
+ ----------
126
+ mol : rdkit.Chem.rdchem.Mol
127
+ molecule object
128
+ kekulize : bool
129
+ kekulize or not
130
+ add_Hs : bool
131
+ add implicit hydrogens to the molecule or not.
132
+
133
+ Returns
134
+ -------
135
+ Hypergraph
136
+ """
137
+ if add_Hs:
138
+ mol = Chem.AddHs(mol)
139
+
140
+ if kekulize:
141
+ Chem.Kekulize(mol)
142
+
143
+ bipartite_g = mol_to_bipartite(mol, kekulize)
144
+ hg = Hypergraph()
145
+ for each_atom in [each_node for each_node in bipartite_g.nodes()
146
+ if each_node.startswith('atom_')]:
147
+ node_set = set([])
148
+ for each_bond in bipartite_g.adj[each_atom]:
149
+ hg.add_node(each_bond,
150
+ attr_dict=bipartite_g.nodes[each_bond]['bond_attr'])
151
+ node_set.add(each_bond)
152
+ hg.add_edge(node_set,
153
+ attr_dict=bipartite_g.nodes[each_atom]['atom_attr'])
154
+ return hg
155
+
156
+
157
+ def hg_to_mol(hg, verbose=False):
158
+ """ convert a hypergraph into Mol object
159
+
160
+ Parameters
161
+ ----------
162
+ hg : Hypergraph
163
+
164
+ Returns
165
+ -------
166
+ mol : Chem.RWMol
167
+ """
168
+ mol = Chem.RWMol()
169
+ atom_dict = {}
170
+ bond_set = set([])
171
+ for each_edge in hg.edges:
172
+ atom = Chem.Atom(hg.edge_attr(each_edge)['symbol'].symbol)
173
+ atom.SetNumExplicitHs(hg.edge_attr(each_edge)['symbol'].num_explicit_Hs)
174
+ atom.SetFormalCharge(hg.edge_attr(each_edge)['symbol'].formal_charge)
175
+ atom.SetChiralTag(
176
+ Chem.rdchem.ChiralType.values[
177
+ hg.edge_attr(each_edge)['symbol'].chirality])
178
+ atom_idx = mol.AddAtom(atom)
179
+ atom_dict[each_edge] = atom_idx
180
+
181
+ for each_node in hg.nodes:
182
+ edge_1, edge_2 = hg.adj_edges(each_node)
183
+ if edge_1+edge_2 not in bond_set:
184
+ if hg.node_attr(each_node)['symbol'].bond_type <= 3:
185
+ num_bond = hg.node_attr(each_node)['symbol'].bond_type
186
+ elif hg.node_attr(each_node)['symbol'].bond_type == 12:
187
+ num_bond = 1
188
+ else:
189
+ raise ValueError(f'too many bonds; {hg.node_attr(each_node)["bond_symbol"].bond_type}')
190
+ _ = mol.AddBond(atom_dict[edge_1],
191
+ atom_dict[edge_2],
192
+ order=Chem.rdchem.BondType.values[num_bond])
193
+ bond_idx = mol.GetBondBetweenAtoms(atom_dict[edge_1], atom_dict[edge_2]).GetIdx()
194
+
195
+ # stereo
196
+ mol.GetBondWithIdx(bond_idx).SetStereo(
197
+ Chem.rdchem.BondStereo.values[hg.node_attr(each_node)['symbol'].stereo])
198
+ bond_set.update([edge_1+edge_2])
199
+ bond_set.update([edge_2+edge_1])
200
+ mol.UpdatePropertyCache()
201
+ mol = mol.GetMol()
202
+ not_stereo_mol = deepcopy(mol)
203
+ if Chem.MolFromSmiles(Chem.MolToSmiles(not_stereo_mol)) is None:
204
+ raise RuntimeError('no valid molecule was obtained.')
205
+ try:
206
+ mol = set_stereo(mol)
207
+ is_stereo = True
208
+ except:
209
+ import traceback
210
+ traceback.print_exc()
211
+ is_stereo = False
212
+ mol_tmp = deepcopy(mol)
213
+ Chem.SetAromaticity(mol_tmp)
214
+ if Chem.MolFromSmiles(Chem.MolToSmiles(mol_tmp)) is not None:
215
+ mol = mol_tmp
216
+ else:
217
+ if Chem.MolFromSmiles(Chem.MolToSmiles(mol)) is None:
218
+ mol = not_stereo_mol
219
+ mol.UpdatePropertyCache()
220
+ Chem.GetSymmSSSR(mol)
221
+ mol = Chem.MolFromSmiles(Chem.MolToSmiles(mol))
222
+ if verbose:
223
+ return mol, is_stereo
224
+ else:
225
+ return mol
226
+
227
+ def hgs_to_mols(hg_list, ignore_error=False):
228
+ if ignore_error:
229
+ mol_list = []
230
+ for each_hg in hg_list:
231
+ try:
232
+ mol = hg_to_mol(each_hg)
233
+ except:
234
+ mol = None
235
+ mol_list.append(mol)
236
+ else:
237
+ mol_list = [hg_to_mol(each_hg) for each_hg in hg_list]
238
+ return mol_list
239
+
240
+ def hgs_to_smiles(hg_list, ignore_error=False):
241
+ mol_list = hgs_to_mols(hg_list, ignore_error)
242
+ smiles_list = []
243
+ for each_mol in mol_list:
244
+ try:
245
+ smiles_list.append(
246
+ Chem.MolToSmiles(
247
+ Chem.MolFromSmiles(
248
+ Chem.MolToSmiles(
249
+ each_mol))))
250
+ except:
251
+ smiles_list.append(None)
252
+ return smiles_list
253
+
254
+ def atom_attr(atom, kekulize):
255
+ """
256
+ get atom's attributes
257
+
258
+ Parameters
259
+ ----------
260
+ atom : rdkit.Chem.rdchem.Atom
261
+ kekulize : bool
262
+ kekulize or not
263
+
264
+ Returns
265
+ -------
266
+ atom_attr : dict
267
+ "is_aromatic" : bool
268
+ the atom is aromatic or not.
269
+ "smarts" : str
270
+ SMARTS representation of the atom.
271
+ """
272
+ if kekulize:
273
+ return {'terminal': True,
274
+ 'is_in_ring': atom.IsInRing(),
275
+ 'symbol': TSymbol(degree=0,
276
+ #degree=atom.GetTotalDegree(),
277
+ is_aromatic=False,
278
+ symbol=atom.GetSymbol(),
279
+ num_explicit_Hs=atom.GetNumExplicitHs(),
280
+ formal_charge=atom.GetFormalCharge(),
281
+ chirality=atom.GetChiralTag().real
282
+ )}
283
+ else:
284
+ return {'terminal': True,
285
+ 'is_in_ring': atom.IsInRing(),
286
+ 'symbol': TSymbol(degree=0,
287
+ #degree=atom.GetTotalDegree(),
288
+ is_aromatic=atom.GetIsAromatic(),
289
+ symbol=atom.GetSymbol(),
290
+ num_explicit_Hs=atom.GetNumExplicitHs(),
291
+ formal_charge=atom.GetFormalCharge(),
292
+ chirality=atom.GetChiralTag().real
293
+ )}
294
+
295
+ def bond_attr(bond, kekulize):
296
+ """
297
+ get atom's attributes
298
+
299
+ Parameters
300
+ ----------
301
+ bond : rdkit.Chem.rdchem.Bond
302
+ kekulize : bool
303
+ kekulize or not
304
+
305
+ Returns
306
+ -------
307
+ bond_attr : dict
308
+ "bond_type" : int
309
+ {0: rdkit.Chem.rdchem.BondType.UNSPECIFIED,
310
+ 1: rdkit.Chem.rdchem.BondType.SINGLE,
311
+ 2: rdkit.Chem.rdchem.BondType.DOUBLE,
312
+ 3: rdkit.Chem.rdchem.BondType.TRIPLE,
313
+ 4: rdkit.Chem.rdchem.BondType.QUADRUPLE,
314
+ 5: rdkit.Chem.rdchem.BondType.QUINTUPLE,
315
+ 6: rdkit.Chem.rdchem.BondType.HEXTUPLE,
316
+ 7: rdkit.Chem.rdchem.BondType.ONEANDAHALF,
317
+ 8: rdkit.Chem.rdchem.BondType.TWOANDAHALF,
318
+ 9: rdkit.Chem.rdchem.BondType.THREEANDAHALF,
319
+ 10: rdkit.Chem.rdchem.BondType.FOURANDAHALF,
320
+ 11: rdkit.Chem.rdchem.BondType.FIVEANDAHALF,
321
+ 12: rdkit.Chem.rdchem.BondType.AROMATIC,
322
+ 13: rdkit.Chem.rdchem.BondType.IONIC,
323
+ 14: rdkit.Chem.rdchem.BondType.HYDROGEN,
324
+ 15: rdkit.Chem.rdchem.BondType.THREECENTER,
325
+ 16: rdkit.Chem.rdchem.BondType.DATIVEONE,
326
+ 17: rdkit.Chem.rdchem.BondType.DATIVE,
327
+ 18: rdkit.Chem.rdchem.BondType.DATIVEL,
328
+ 19: rdkit.Chem.rdchem.BondType.DATIVER,
329
+ 20: rdkit.Chem.rdchem.BondType.OTHER,
330
+ 21: rdkit.Chem.rdchem.BondType.ZERO}
331
+ """
332
+ if kekulize:
333
+ is_aromatic = False
334
+ if bond.GetBondType().real == 12:
335
+ bond_type = 1
336
+ else:
337
+ bond_type = bond.GetBondType().real
338
+ else:
339
+ is_aromatic = bond.GetIsAromatic()
340
+ bond_type = bond.GetBondType().real
341
+ return {'symbol': BondSymbol(is_aromatic=is_aromatic,
342
+ bond_type=bond_type,
343
+ stereo=int(bond.GetStereo())),
344
+ 'is_in_ring': bond.IsInRing()}
345
+
346
+
347
+ def standardize_stereo(mol):
348
+ '''
349
+ 0: rdkit.Chem.rdchem.BondDir.NONE,
350
+ 1: rdkit.Chem.rdchem.BondDir.BEGINWEDGE,
351
+ 2: rdkit.Chem.rdchem.BondDir.BEGINDASH,
352
+ 3: rdkit.Chem.rdchem.BondDir.ENDDOWNRIGHT,
353
+ 4: rdkit.Chem.rdchem.BondDir.ENDUPRIGHT,
354
+
355
+ '''
356
+ # mol = Chem.AddHs(mol) # this removes CIPRank !!!
357
+ for each_bond in mol.GetBonds():
358
+ if int(each_bond.GetStereo()) in [2, 3]: #2=Z (same side), 3=E
359
+ begin_stereo_atom_idx = each_bond.GetBeginAtomIdx()
360
+ end_stereo_atom_idx = each_bond.GetEndAtomIdx()
361
+ atom_idx_1 = each_bond.GetStereoAtoms()[0]
362
+ atom_idx_2 = each_bond.GetStereoAtoms()[1]
363
+ if mol.GetBondBetweenAtoms(atom_idx_1, begin_stereo_atom_idx):
364
+ begin_atom_idx = atom_idx_1
365
+ end_atom_idx = atom_idx_2
366
+ else:
367
+ begin_atom_idx = atom_idx_2
368
+ end_atom_idx = atom_idx_1
369
+
370
+ begin_another_atom_idx = None
371
+ assert len(mol.GetAtomWithIdx(begin_stereo_atom_idx).GetNeighbors()) <= 3
372
+ for each_neighbor in mol.GetAtomWithIdx(begin_stereo_atom_idx).GetNeighbors():
373
+ each_neighbor_idx = each_neighbor.GetIdx()
374
+ if each_neighbor_idx not in [end_stereo_atom_idx, begin_atom_idx]:
375
+ begin_another_atom_idx = each_neighbor_idx
376
+
377
+ end_another_atom_idx = None
378
+ assert len(mol.GetAtomWithIdx(end_stereo_atom_idx).GetNeighbors()) <= 3
379
+ for each_neighbor in mol.GetAtomWithIdx(end_stereo_atom_idx).GetNeighbors():
380
+ each_neighbor_idx = each_neighbor.GetIdx()
381
+ if each_neighbor_idx not in [begin_stereo_atom_idx, end_atom_idx]:
382
+ end_another_atom_idx = each_neighbor_idx
383
+
384
+ '''
385
+ relationship between begin_atom_idx and end_atom_idx is encoded in GetStereo
386
+ '''
387
+ begin_atom_rank = int(mol.GetAtomWithIdx(begin_atom_idx).GetProp('_CIPRank'))
388
+ end_atom_rank = int(mol.GetAtomWithIdx(end_atom_idx).GetProp('_CIPRank'))
389
+ try:
390
+ begin_another_atom_rank = int(mol.GetAtomWithIdx(begin_another_atom_idx).GetProp('_CIPRank'))
391
+ except:
392
+ begin_another_atom_rank = np.inf
393
+ try:
394
+ end_another_atom_rank = int(mol.GetAtomWithIdx(end_another_atom_idx).GetProp('_CIPRank'))
395
+ except:
396
+ end_another_atom_rank = np.inf
397
+ if begin_atom_rank < begin_another_atom_rank\
398
+ and end_atom_rank < end_another_atom_rank:
399
+ pass
400
+ elif begin_atom_rank < begin_another_atom_rank\
401
+ and end_atom_rank > end_another_atom_rank:
402
+ # (begin_atom_idx +) end_another_atom_idx should be in StereoAtoms
403
+ if each_bond.GetStereo() == 2:
404
+ # set stereo
405
+ each_bond.SetStereo(Chem.rdchem.BondStereo.values[3])
406
+ # set bond dir
407
+ mol = safe_set_bond_dir(mol, begin_atom_idx, begin_stereo_atom_idx, 3)
408
+ mol = safe_set_bond_dir(mol, begin_another_atom_idx, begin_stereo_atom_idx, 0)
409
+ mol = safe_set_bond_dir(mol, end_atom_idx, end_stereo_atom_idx, 0)
410
+ mol = safe_set_bond_dir(mol, end_another_atom_idx, end_stereo_atom_idx, 3)
411
+ elif each_bond.GetStereo() == 3:
412
+ # set stereo
413
+ each_bond.SetStereo(Chem.rdchem.BondStereo.values[2])
414
+ # set bond dir
415
+ mol = safe_set_bond_dir(mol, begin_atom_idx, begin_stereo_atom_idx, 3)
416
+ mol = safe_set_bond_dir(mol, begin_another_atom_idx, begin_stereo_atom_idx, 0)
417
+ mol = safe_set_bond_dir(mol, end_atom_idx, end_stereo_atom_idx, 0)
418
+ mol = safe_set_bond_dir(mol, end_another_atom_idx, end_stereo_atom_idx, 4)
419
+ else:
420
+ raise ValueError
421
+ each_bond.SetStereoAtoms(begin_atom_idx, end_another_atom_idx)
422
+ elif begin_atom_rank > begin_another_atom_rank\
423
+ and end_atom_rank < end_another_atom_rank:
424
+ # (end_atom_idx +) begin_another_atom_idx should be in StereoAtoms
425
+ if each_bond.GetStereo() == 2:
426
+ # set stereo
427
+ each_bond.SetStereo(Chem.rdchem.BondStereo.values[3])
428
+ # set bond dir
429
+ mol = safe_set_bond_dir(mol, begin_atom_idx, begin_stereo_atom_idx, 0)
430
+ mol = safe_set_bond_dir(mol, begin_another_atom_idx, begin_stereo_atom_idx, 4)
431
+ mol = safe_set_bond_dir(mol, end_atom_idx, end_stereo_atom_idx, 4)
432
+ mol = safe_set_bond_dir(mol, end_another_atom_idx, end_stereo_atom_idx, 0)
433
+ elif each_bond.GetStereo() == 3:
434
+ # set stereo
435
+ each_bond.SetStereo(Chem.rdchem.BondStereo.values[2])
436
+ # set bond dir
437
+ mol = safe_set_bond_dir(mol, begin_atom_idx, begin_stereo_atom_idx, 0)
438
+ mol = safe_set_bond_dir(mol, begin_another_atom_idx, begin_stereo_atom_idx, 4)
439
+ mol = safe_set_bond_dir(mol, end_atom_idx, end_stereo_atom_idx, 3)
440
+ mol = safe_set_bond_dir(mol, end_another_atom_idx, end_stereo_atom_idx, 0)
441
+ else:
442
+ raise ValueError
443
+ each_bond.SetStereoAtoms(begin_another_atom_idx, end_atom_idx)
444
+ elif begin_atom_rank > begin_another_atom_rank\
445
+ and end_atom_rank > end_another_atom_rank:
446
+ # begin_another_atom_idx + end_another_atom_idx should be in StereoAtoms
447
+ if each_bond.GetStereo() == 2:
448
+ # set bond dir
449
+ mol = safe_set_bond_dir(mol, begin_atom_idx, begin_stereo_atom_idx, 0)
450
+ mol = safe_set_bond_dir(mol, begin_another_atom_idx, begin_stereo_atom_idx, 4)
451
+ mol = safe_set_bond_dir(mol, end_atom_idx, end_stereo_atom_idx, 0)
452
+ mol = safe_set_bond_dir(mol, end_another_atom_idx, end_stereo_atom_idx, 3)
453
+ elif each_bond.GetStereo() == 3:
454
+ # set bond dir
455
+ mol = safe_set_bond_dir(mol, begin_atom_idx, begin_stereo_atom_idx, 0)
456
+ mol = safe_set_bond_dir(mol, begin_another_atom_idx, begin_stereo_atom_idx, 4)
457
+ mol = safe_set_bond_dir(mol, end_atom_idx, end_stereo_atom_idx, 0)
458
+ mol = safe_set_bond_dir(mol, end_another_atom_idx, end_stereo_atom_idx, 4)
459
+ else:
460
+ raise ValueError
461
+ each_bond.SetStereoAtoms(begin_another_atom_idx, end_another_atom_idx)
462
+ else:
463
+ raise RuntimeError
464
+ return mol
465
+
466
+
467
+ def set_stereo(mol):
468
+ '''
469
+ 0: rdkit.Chem.rdchem.BondDir.NONE,
470
+ 1: rdkit.Chem.rdchem.BondDir.BEGINWEDGE,
471
+ 2: rdkit.Chem.rdchem.BondDir.BEGINDASH,
472
+ 3: rdkit.Chem.rdchem.BondDir.ENDDOWNRIGHT,
473
+ 4: rdkit.Chem.rdchem.BondDir.ENDUPRIGHT,
474
+ '''
475
+ _mol = Chem.MolFromSmiles(Chem.MolToSmiles(mol))
476
+ Chem.Kekulize(_mol, True)
477
+ substruct_match = mol.GetSubstructMatch(_mol)
478
+ if not substruct_match:
479
+ ''' mol and _mol are kekulized.
480
+ sometimes, the order of '=' and '-' changes, which causes mol and _mol not matched.
481
+ '''
482
+ Chem.SetAromaticity(mol)
483
+ Chem.SetAromaticity(_mol)
484
+ substruct_match = mol.GetSubstructMatch(_mol)
485
+ try:
486
+ atom_match = {substruct_match[_mol_atom_idx]: _mol_atom_idx for _mol_atom_idx in range(_mol.GetNumAtoms())} # mol to _mol
487
+ except:
488
+ raise ValueError('two molecules obtained from the same data do not match.')
489
+
490
+ for each_bond in mol.GetBonds():
491
+ begin_atom_idx = each_bond.GetBeginAtomIdx()
492
+ end_atom_idx = each_bond.GetEndAtomIdx()
493
+ _bond = _mol.GetBondBetweenAtoms(atom_match[begin_atom_idx], atom_match[end_atom_idx])
494
+ _bond.SetStereo(each_bond.GetStereo())
495
+
496
+ mol = _mol
497
+ for each_bond in mol.GetBonds():
498
+ if int(each_bond.GetStereo()) in [2, 3]: #2=Z (same side), 3=E
499
+ begin_stereo_atom_idx = each_bond.GetBeginAtomIdx()
500
+ end_stereo_atom_idx = each_bond.GetEndAtomIdx()
501
+ begin_atom_idx_set = set([each_neighbor.GetIdx()
502
+ for each_neighbor
503
+ in mol.GetAtomWithIdx(begin_stereo_atom_idx).GetNeighbors()
504
+ if each_neighbor.GetIdx() != end_stereo_atom_idx])
505
+ end_atom_idx_set = set([each_neighbor.GetIdx()
506
+ for each_neighbor
507
+ in mol.GetAtomWithIdx(end_stereo_atom_idx).GetNeighbors()
508
+ if each_neighbor.GetIdx() != begin_stereo_atom_idx])
509
+ if not begin_atom_idx_set:
510
+ each_bond.SetStereo(Chem.rdchem.BondStereo(0))
511
+ continue
512
+ if not end_atom_idx_set:
513
+ each_bond.SetStereo(Chem.rdchem.BondStereo(0))
514
+ continue
515
+ if len(begin_atom_idx_set) == 1:
516
+ begin_atom_idx = begin_atom_idx_set.pop()
517
+ begin_another_atom_idx = None
518
+ if len(end_atom_idx_set) == 1:
519
+ end_atom_idx = end_atom_idx_set.pop()
520
+ end_another_atom_idx = None
521
+ if len(begin_atom_idx_set) == 2:
522
+ atom_idx_1 = begin_atom_idx_set.pop()
523
+ atom_idx_2 = begin_atom_idx_set.pop()
524
+ if int(mol.GetAtomWithIdx(atom_idx_1).GetProp('_CIPRank')) < int(mol.GetAtomWithIdx(atom_idx_2).GetProp('_CIPRank')):
525
+ begin_atom_idx = atom_idx_1
526
+ begin_another_atom_idx = atom_idx_2
527
+ else:
528
+ begin_atom_idx = atom_idx_2
529
+ begin_another_atom_idx = atom_idx_1
530
+ if len(end_atom_idx_set) == 2:
531
+ atom_idx_1 = end_atom_idx_set.pop()
532
+ atom_idx_2 = end_atom_idx_set.pop()
533
+ if int(mol.GetAtomWithIdx(atom_idx_1).GetProp('_CIPRank')) < int(mol.GetAtomWithIdx(atom_idx_2).GetProp('_CIPRank')):
534
+ end_atom_idx = atom_idx_1
535
+ end_another_atom_idx = atom_idx_2
536
+ else:
537
+ end_atom_idx = atom_idx_2
538
+ end_another_atom_idx = atom_idx_1
539
+
540
+ if each_bond.GetStereo() == 2: # same side
541
+ mol = safe_set_bond_dir(mol, begin_atom_idx, begin_stereo_atom_idx, 3)
542
+ mol = safe_set_bond_dir(mol, end_atom_idx, end_stereo_atom_idx, 4)
543
+ each_bond.SetStereoAtoms(begin_atom_idx, end_atom_idx)
544
+ elif each_bond.GetStereo() == 3: # opposite side
545
+ mol = safe_set_bond_dir(mol, begin_atom_idx, begin_stereo_atom_idx, 3)
546
+ mol = safe_set_bond_dir(mol, end_atom_idx, end_stereo_atom_idx, 3)
547
+ each_bond.SetStereoAtoms(begin_atom_idx, end_atom_idx)
548
+ else:
549
+ raise ValueError
550
+ return mol
551
+
552
+
553
+ def safe_set_bond_dir(mol, atom_idx_1, atom_idx_2, bond_dir_val):
554
+ if atom_idx_1 is None or atom_idx_2 is None:
555
+ return mol
556
+ else:
557
+ mol.GetBondBetweenAtoms(atom_idx_1, atom_idx_2).SetBondDir(Chem.rdchem.BondDir.values[bond_dir_val])
558
+ return mol
559
+
graph_grammar/nn/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding:utf-8 -*-
2
+ # Rhizome
3
+ # Version beta 0.0, August 2023
4
+ # Property of IBM Research, Accelerated Discovery
5
+ #
6
+
7
+ """
8
+ PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
9
+ OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
10
+ THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
11
+ """
graph_grammar/nn/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (508 Bytes). View file
 
graph_grammar/nn/__pycache__/decoder.cpython-310.pyc ADDED
Binary file (3.98 kB). View file
 
graph_grammar/nn/__pycache__/encoder.cpython-310.pyc ADDED
Binary file (5.38 kB). View file
 
graph_grammar/nn/dataset.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # Rhizome
4
+ # Version beta 0.0, August 2023
5
+ # Property of IBM Research, Accelerated Discovery
6
+ #
7
+
8
+ """
9
+ PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
10
+ OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
11
+ THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
12
+ """
13
+
14
+ """ Title """
15
+
16
+ __author__ = "Hiroshi Kajino <KAJINO@jp.ibm.com>"
17
+ __copyright__ = "(c) Copyright IBM Corp. 2018"
18
+ __version__ = "0.1"
19
+ __date__ = "Apr 18 2018"
20
+
21
+ from torch.utils.data import Dataset, DataLoader
22
+ import torch
23
+ import numpy as np
24
+
25
+
26
+ def left_padding(sentence_list, max_len, pad_idx=-1, inverse=False):
27
+ ''' pad left
28
+
29
+ Parameters
30
+ ----------
31
+ sentence_list : list of sequences of integers
32
+ max_len : int
33
+ maximum length of sentences.
34
+ if a sentence is shorter than `max_len`, its left part is padded.
35
+ pad_idx : int
36
+ integer for padding
37
+ inverse : bool
38
+ if True, the sequence is inversed.
39
+
40
+ Returns
41
+ -------
42
+ List of torch.LongTensor
43
+ each sentence is left-padded.
44
+ '''
45
+ max_in_list = max([len(each_sen) for each_sen in sentence_list])
46
+
47
+ if max_in_list > max_len:
48
+ raise ValueError('`max_len` should be larger than the maximum length of input sequences, {}.'.format(max_in_list))
49
+
50
+ if inverse:
51
+ return [torch.LongTensor([pad_idx] * (max_len - len(each_sen)) + each_sen[::-1]) for each_sen in sentence_list]
52
+ else:
53
+ return [torch.LongTensor([pad_idx] * (max_len - len(each_sen)) + each_sen) for each_sen in sentence_list]
54
+
55
+
56
+ def right_padding(sentence_list, max_len, pad_idx=-1):
57
+ ''' pad right
58
+
59
+ Parameters
60
+ ----------
61
+ sentence_list : list of sequences of integers
62
+ max_len : int
63
+ maximum length of sentences.
64
+ if a sentence is shorter than `max_len`, its right part is padded.
65
+ pad_idx : int
66
+ integer for padding
67
+
68
+ Returns
69
+ -------
70
+ List of torch.LongTensor
71
+ each sentence is right-padded.
72
+ '''
73
+ max_in_list = max([len(each_sen) for each_sen in sentence_list])
74
+ if max_in_list > max_len:
75
+ raise ValueError('`max_len` should be larger than the maximum length of input sequences, {}.'.format(max_in_list))
76
+
77
+ return [torch.LongTensor(each_sen + [pad_idx] * (max_len - len(each_sen))) for each_sen in sentence_list]
78
+
79
+
80
+ class HRGDataset(Dataset):
81
+
82
+ '''
83
+ A class of HRG data
84
+ '''
85
+
86
+ def __init__(self, hrg, prod_rule_seq_list, max_len, target_val_list=None, inversed_input=False):
87
+ self.hrg = hrg
88
+ self.left_prod_rule_seq_list = left_padding(prod_rule_seq_list,
89
+ max_len,
90
+ inverse=inversed_input)
91
+
92
+ self.right_prod_rule_seq_list = right_padding(prod_rule_seq_list, max_len)
93
+ self.inserved_input = inversed_input
94
+ self.target_val_list = target_val_list
95
+ if target_val_list is not None:
96
+ if len(prod_rule_seq_list) != len(target_val_list):
97
+ raise ValueError(f'prod_rule_seq_list and target_val_list have inconsistent lengths: {len(prod_rule_seq_list)}, {len(target_val_list)}')
98
+
99
+ def __len__(self):
100
+ return len(self.left_prod_rule_seq_list)
101
+
102
+ def __getitem__(self, idx):
103
+ if self.target_val_list is not None:
104
+ return self.left_prod_rule_seq_list[idx], self.right_prod_rule_seq_list[idx], np.float32(self.target_val_list[idx])
105
+ else:
106
+ return self.left_prod_rule_seq_list[idx], self.right_prod_rule_seq_list[idx]
107
+
108
+ @property
109
+ def vocab_size(self):
110
+ return self.hrg.num_prod_rule
111
+
112
+ def batch_padding(each_batch, batch_size, padding_idx):
113
+ num_pad = batch_size - len(each_batch[0])
114
+ if num_pad:
115
+ each_batch[0] = torch.cat([each_batch[0],
116
+ padding_idx * torch.ones((batch_size - len(each_batch[0]),
117
+ len(each_batch[0][0])), dtype=torch.int64)], dim=0)
118
+ each_batch[1] = torch.cat([each_batch[1],
119
+ padding_idx * torch.ones((batch_size - len(each_batch[1]),
120
+ len(each_batch[1][0])), dtype=torch.int64)], dim=0)
121
+ return each_batch, num_pad
graph_grammar/nn/decoder.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # Rhizome
4
+ # Version beta 0.0, August 2023
5
+ # Property of IBM Research, Accelerated Discovery
6
+ #
7
+
8
+ """
9
+ PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
10
+ OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
11
+ THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
12
+ """
13
+
14
+ """ Title """
15
+
16
+ __author__ = "Hiroshi Kajino <KAJINO@jp.ibm.com>"
17
+ __copyright__ = "(c) Copyright IBM Corp. 2018"
18
+ __version__ = "0.1"
19
+ __date__ = "Aug 9 2018"
20
+
21
+
22
+ import abc
23
+ import numpy as np
24
+ import torch
25
+ from torch import nn
26
+
27
+
28
+ class DecoderBase(nn.Module):
29
+
30
+ def __init__(self):
31
+ super().__init__()
32
+ self.hidden_dict = {}
33
+
34
+ @abc.abstractmethod
35
+ def forward_one_step(self, tgt_emb_in):
36
+ ''' one-step forward model
37
+
38
+ Parameters
39
+ ----------
40
+ tgt_emb_in : Tensor, shape (batch_size, input_dim)
41
+
42
+ Returns
43
+ -------
44
+ Tensor, shape (batch_size, hidden_dim)
45
+ '''
46
+ tgt_emb_out = None
47
+ return tgt_emb_out
48
+
49
+ @abc.abstractmethod
50
+ def init_hidden(self):
51
+ ''' initialize the hidden states
52
+ '''
53
+ pass
54
+
55
+ @abc.abstractmethod
56
+ def feed_hidden(self, hidden_dict_0):
57
+ for each_hidden in self.hidden_dict.keys():
58
+ self.hidden_dict[each_hidden][0] = hidden_dict_0[each_hidden]
59
+
60
+
61
+ class GRUDecoder(DecoderBase):
62
+
63
+ def __init__(self, input_dim: int, hidden_dim: int, num_layers: int,
64
+ dropout: float, batch_size: int, use_gpu: bool,
65
+ no_dropout=False):
66
+ super().__init__()
67
+ self.input_dim = input_dim
68
+ self.hidden_dim = hidden_dim
69
+ self.num_layers = num_layers
70
+ self.dropout = dropout
71
+ self.batch_size = batch_size
72
+ self.use_gpu = use_gpu
73
+ self.model = nn.GRU(input_size=self.input_dim,
74
+ hidden_size=self.hidden_dim,
75
+ num_layers=self.num_layers,
76
+ batch_first=True,
77
+ bidirectional=False,
78
+ dropout=self.dropout if not no_dropout else 0
79
+ )
80
+ if self.use_gpu:
81
+ self.model.cuda()
82
+ self.init_hidden()
83
+
84
+ def init_hidden(self):
85
+ self.hidden_dict['h'] = torch.zeros((self.num_layers,
86
+ self.batch_size,
87
+ self.hidden_dim),
88
+ requires_grad=False)
89
+ if self.use_gpu:
90
+ self.hidden_dict['h'] = self.hidden_dict['h'].cuda()
91
+
92
+ def forward_one_step(self, tgt_emb_in):
93
+ ''' one-step forward model
94
+
95
+ Parameters
96
+ ----------
97
+ tgt_emb_in : Tensor, shape (batch_size, input_dim)
98
+
99
+ Returns
100
+ -------
101
+ Tensor, shape (batch_size, hidden_dim)
102
+ '''
103
+ tgt_emb_out, self.hidden_dict['h'] \
104
+ = self.model(tgt_emb_in.view(self.batch_size, 1, -1),
105
+ self.hidden_dict['h'])
106
+ return tgt_emb_out
107
+
108
+
109
+ class LSTMDecoder(DecoderBase):
110
+
111
+ def __init__(self, input_dim: int, hidden_dim: int, num_layers: int,
112
+ dropout: float, batch_size: int, use_gpu: bool,
113
+ no_dropout=False):
114
+ super().__init__()
115
+ self.input_dim = input_dim
116
+ self.hidden_dim = hidden_dim
117
+ self.num_layers = num_layers
118
+ self.dropout = dropout
119
+ self.batch_size = batch_size
120
+ self.use_gpu = use_gpu
121
+ self.model = nn.LSTM(input_size=self.input_dim,
122
+ hidden_size=self.hidden_dim,
123
+ num_layers=self.num_layers,
124
+ batch_first=True,
125
+ bidirectional=False,
126
+ dropout=self.dropout if not no_dropout else 0)
127
+ if self.use_gpu:
128
+ self.model.cuda()
129
+ self.init_hidden()
130
+
131
+ def init_hidden(self):
132
+ self.hidden_dict['h'] = torch.zeros((self.num_layers,
133
+ self.batch_size,
134
+ self.hidden_dim),
135
+ requires_grad=False)
136
+ self.hidden_dict['c'] = torch.zeros((self.num_layers,
137
+ self.batch_size,
138
+ self.hidden_dim),
139
+ requires_grad=False)
140
+ if self.use_gpu:
141
+ for each_hidden in self.hidden_dict.keys():
142
+ self.hidden_dict[each_hidden] = self.hidden_dict[each_hidden].cuda()
143
+
144
+ def forward_one_step(self, tgt_emb_in):
145
+ ''' one-step forward model
146
+
147
+ Parameters
148
+ ----------
149
+ tgt_emb_in : Tensor, shape (batch_size, input_dim)
150
+
151
+ Returns
152
+ -------
153
+ Tensor, shape (batch_size, hidden_dim)
154
+ '''
155
+ tgt_hidden_out, self.hidden_dict['h'], self.hidden_dict['c'] \
156
+ = self.model(tgt_emb_in.view(self.batch_size, 1, -1),
157
+ self.hidden_dict['h'], self.hidden_dict['c'])
158
+ return tgt_hidden_out
graph_grammar/nn/encoder.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # Rhizome
4
+ # Version beta 0.0, August 2023
5
+ # Property of IBM Research, Accelerated Discovery
6
+ #
7
+
8
+ """
9
+ PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
10
+ OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
11
+ THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
12
+ """
13
+
14
+ """ Title """
15
+
16
+ __author__ = "Hiroshi Kajino <KAJINO@jp.ibm.com>"
17
+ __copyright__ = "(c) Copyright IBM Corp. 2018"
18
+ __version__ = "0.1"
19
+ __date__ = "Aug 9 2018"
20
+
21
+
22
+ import abc
23
+ import numpy as np
24
+ import torch
25
+ import torch.nn.functional as F
26
+ from torch import nn
27
+ from typing import List
28
+
29
+
30
+ class EncoderBase(nn.Module):
31
+
32
+ def __init__(self):
33
+ super().__init__()
34
+
35
+ @abc.abstractmethod
36
+ def forward(self, in_seq):
37
+ ''' forward model
38
+
39
+ Parameters
40
+ ----------
41
+ in_seq_emb : Variable, shape (batch_size, max_len, input_dim)
42
+
43
+ Returns
44
+ -------
45
+ hidden_seq_emb : Tensor, shape (batch_size, max_len, 1 + bidirectional, hidden_dim)
46
+ '''
47
+ pass
48
+
49
+ @abc.abstractmethod
50
+ def init_hidden(self):
51
+ ''' initialize the hidden states
52
+ '''
53
+ pass
54
+
55
+
56
+ class GRUEncoder(EncoderBase):
57
+
58
+ def __init__(self, input_dim: int, hidden_dim: int, num_layers: int,
59
+ bidirectional: bool, dropout: float, batch_size: int, use_gpu: bool,
60
+ no_dropout=False):
61
+ super().__init__()
62
+ self.input_dim = input_dim
63
+ self.hidden_dim = hidden_dim
64
+ self.num_layers = num_layers
65
+ self.bidirectional = bidirectional
66
+ self.dropout = dropout
67
+ self.batch_size = batch_size
68
+ self.use_gpu = use_gpu
69
+ self.model = nn.GRU(input_size=self.input_dim,
70
+ hidden_size=self.hidden_dim,
71
+ num_layers=self.num_layers,
72
+ batch_first=True,
73
+ bidirectional=self.bidirectional,
74
+ dropout=self.dropout if not no_dropout else 0)
75
+ if self.use_gpu:
76
+ self.model.cuda()
77
+ self.init_hidden()
78
+
79
+
80
+ def init_hidden(self):
81
+ self.h0 = torch.zeros(((self.bidirectional + 1) * self.num_layers,
82
+ self.batch_size,
83
+ self.hidden_dim),
84
+ requires_grad=False)
85
+ if self.use_gpu:
86
+ self.h0 = self.h0.cuda()
87
+
88
+ def forward(self, in_seq_emb):
89
+ ''' forward model
90
+
91
+ Parameters
92
+ ----------
93
+ in_seq_emb : Tensor, shape (batch_size, max_len, input_dim)
94
+
95
+ Returns
96
+ -------
97
+ hidden_seq_emb : Tensor, shape (batch_size, max_len, 1 + bidirectional, hidden_dim)
98
+ '''
99
+ max_len = in_seq_emb.size(1)
100
+ hidden_seq_emb, self.h0 = self.model(
101
+ in_seq_emb, self.h0)
102
+ hidden_seq_emb = hidden_seq_emb.view(self.batch_size,
103
+ max_len,
104
+ 1 + self.bidirectional,
105
+ self.hidden_dim)
106
+ return hidden_seq_emb
107
+
108
+
109
+ class LSTMEncoder(EncoderBase):
110
+
111
+ def __init__(self, input_dim: int, hidden_dim: int, num_layers: int,
112
+ bidirectional: bool, dropout: float, batch_size: int, use_gpu: bool,
113
+ no_dropout=False):
114
+ super().__init__()
115
+ self.input_dim = input_dim
116
+ self.hidden_dim = hidden_dim
117
+ self.num_layers = num_layers
118
+ self.bidirectional = bidirectional
119
+ self.dropout = dropout
120
+ self.batch_size = batch_size
121
+ self.use_gpu = use_gpu
122
+ self.model = nn.LSTM(input_size=self.input_dim,
123
+ hidden_size=self.hidden_dim,
124
+ num_layers=self.num_layers,
125
+ batch_first=True,
126
+ bidirectional=self.bidirectional,
127
+ dropout=self.dropout if not no_dropout else 0)
128
+ if self.use_gpu:
129
+ self.model.cuda()
130
+ self.init_hidden()
131
+
132
+ def init_hidden(self):
133
+ self.h0 = torch.zeros(((self.bidirectional + 1) * self.num_layers,
134
+ self.batch_size,
135
+ self.hidden_dim),
136
+ requires_grad=False)
137
+ self.c0 = torch.zeros(((self.bidirectional + 1) * self.num_layers,
138
+ self.batch_size,
139
+ self.hidden_dim),
140
+ requires_grad=False)
141
+ if self.use_gpu:
142
+ self.h0 = self.h0.cuda()
143
+ self.c0 = self.c0.cuda()
144
+
145
+ def forward(self, in_seq_emb):
146
+ ''' forward model
147
+
148
+ Parameters
149
+ ----------
150
+ in_seq_emb : Tensor, shape (batch_size, max_len, input_dim)
151
+
152
+ Returns
153
+ -------
154
+ hidden_seq_emb : Tensor, shape (batch_size, max_len, 1 + bidirectional, hidden_dim)
155
+ '''
156
+ max_len = in_seq_emb.size(1)
157
+ hidden_seq_emb, (self.h0, self.c0) = self.model(
158
+ in_seq_emb, (self.h0, self.c0))
159
+ hidden_seq_emb = hidden_seq_emb.view(self.batch_size,
160
+ max_len,
161
+ 1 + self.bidirectional,
162
+ self.hidden_dim)
163
+ return hidden_seq_emb
164
+
165
+
166
+ class FullConnectedEncoder(EncoderBase):
167
+
168
+ def __init__(self, input_dim: int, hidden_dim: int, max_len: int, hidden_dim_list: List[int],
169
+ batch_size: int, use_gpu: bool):
170
+ super().__init__()
171
+ self.input_dim = input_dim
172
+ self.hidden_dim = hidden_dim
173
+ self.max_len = max_len
174
+ self.hidden_dim_list = hidden_dim_list
175
+ self.use_gpu = use_gpu
176
+ in_out_dim_list = [input_dim * max_len] + list(hidden_dim_list) + [hidden_dim]
177
+ self.linear_list = nn.ModuleList(
178
+ [nn.Linear(in_out_dim_list[each_idx], in_out_dim_list[each_idx + 1])\
179
+ for each_idx in range(len(in_out_dim_list) - 1)])
180
+
181
+ def forward(self, in_seq_emb):
182
+ ''' forward model
183
+
184
+ Parameters
185
+ ----------
186
+ in_seq_emb : Tensor, shape (batch_size, max_len, input_dim)
187
+
188
+ Returns
189
+ -------
190
+ hidden_seq_emb : Tensor, shape (batch_size, max_len, 1 + bidirectional, hidden_dim)
191
+ '''
192
+ batch_size = in_seq_emb.size(0)
193
+ x = in_seq_emb.view(batch_size, -1)
194
+ for each_linear in self.linear_list:
195
+ x = F.relu(each_linear(x))
196
+ return x.view(batch_size, 1, -1)
197
+
198
+ def init_hidden(self):
199
+ pass
graph_grammar/nn/graph.py ADDED
@@ -0,0 +1,313 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # Rhizome
4
+ # Version beta 0.0, August 2023
5
+ # Property of IBM Research, Accelerated Discovery
6
+ #
7
+
8
+ """
9
+ PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
10
+ OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
11
+ THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
12
+ """
13
+
14
+ """ Title """
15
+
16
+ __author__ = "Hiroshi Kajino <KAJINO@jp.ibm.com>"
17
+ __copyright__ = "(c) Copyright IBM Corp. 2018"
18
+ __version__ = "0.1"
19
+ __date__ = "Jan 1 2018"
20
+
21
+ import numpy as np
22
+ import torch
23
+ import torch.nn.functional as F
24
+ from graph_grammar.graph_grammar.hrg import ProductionRuleCorpus
25
+ from torch import nn
26
+ from torch.autograd import Variable
27
+
28
+ class MolecularProdRuleEmbedding(nn.Module):
29
+
30
+ ''' molecular fingerprint layer
31
+ '''
32
+
33
+ def __init__(self, prod_rule_corpus, layer2layer_activation, layer2out_activation,
34
+ out_dim=32, element_embed_dim=32,
35
+ num_layers=3, padding_idx=None, use_gpu=False):
36
+ super().__init__()
37
+ if padding_idx is not None:
38
+ assert padding_idx == -1, 'padding_idx must be -1.'
39
+ self.prod_rule_corpus = prod_rule_corpus
40
+ self.layer2layer_activation = layer2layer_activation
41
+ self.layer2out_activation = layer2out_activation
42
+ self.out_dim = out_dim
43
+ self.element_embed_dim = element_embed_dim
44
+ self.num_layers = num_layers
45
+ self.padding_idx = padding_idx
46
+ self.use_gpu = use_gpu
47
+
48
+ self.layer2layer_list = []
49
+ self.layer2out_list = []
50
+
51
+ if self.use_gpu:
52
+ self.atom_embed = torch.randn(self.prod_rule_corpus.num_edge_symbol,
53
+ self.element_embed_dim, requires_grad=True).cuda()
54
+ self.bond_embed = torch.randn(self.prod_rule_corpus.num_node_symbol,
55
+ self.element_embed_dim, requires_grad=True).cuda()
56
+ self.ext_id_embed = torch.randn(self.prod_rule_corpus.num_ext_id,
57
+ self.element_embed_dim, requires_grad=True).cuda()
58
+ for _ in range(num_layers):
59
+ self.layer2layer_list.append(nn.Linear(self.element_embed_dim, self.element_embed_dim).cuda())
60
+ self.layer2out_list.append(nn.Linear(self.element_embed_dim, self.out_dim).cuda())
61
+ else:
62
+ self.atom_embed = torch.randn(self.prod_rule_corpus.num_edge_symbol,
63
+ self.element_embed_dim, requires_grad=True)
64
+ self.bond_embed = torch.randn(self.prod_rule_corpus.num_node_symbol,
65
+ self.element_embed_dim, requires_grad=True)
66
+ self.ext_id_embed = torch.randn(self.prod_rule_corpus.num_ext_id,
67
+ self.element_embed_dim, requires_grad=True)
68
+ for _ in range(num_layers):
69
+ self.layer2layer_list.append(nn.Linear(self.element_embed_dim, self.element_embed_dim))
70
+ self.layer2out_list.append(nn.Linear(self.element_embed_dim, self.out_dim))
71
+
72
+
73
+ def forward(self, prod_rule_idx_seq):
74
+ ''' forward model for mini-batch
75
+
76
+ Parameters
77
+ ----------
78
+ prod_rule_idx_seq : (batch_size, length)
79
+
80
+ Returns
81
+ -------
82
+ Variable, shape (batch_size, length, out_dim)
83
+ '''
84
+ batch_size, length = prod_rule_idx_seq.shape
85
+ if self.use_gpu:
86
+ out = Variable(torch.zeros((batch_size, length, self.out_dim))).cuda()
87
+ else:
88
+ out = Variable(torch.zeros((batch_size, length, self.out_dim)))
89
+ for each_batch_idx in range(batch_size):
90
+ for each_idx in range(length):
91
+ if int(prod_rule_idx_seq[each_batch_idx, each_idx]) == len(self.prod_rule_corpus.prod_rule_list):
92
+ continue
93
+ else:
94
+ each_prod_rule = self.prod_rule_corpus.prod_rule_list[int(prod_rule_idx_seq[each_batch_idx, each_idx])]
95
+ layer_wise_embed_dict = {each_edge: self.atom_embed[
96
+ each_prod_rule.rhs.edge_attr(each_edge)['symbol_idx']]
97
+ for each_edge in each_prod_rule.rhs.edges}
98
+ layer_wise_embed_dict.update({each_node: self.bond_embed[
99
+ each_prod_rule.rhs.node_attr(each_node)['symbol_idx']]
100
+ for each_node in each_prod_rule.rhs.nodes})
101
+ for each_node in each_prod_rule.rhs.nodes:
102
+ if 'ext_id' in each_prod_rule.rhs.node_attr(each_node):
103
+ layer_wise_embed_dict[each_node] \
104
+ = layer_wise_embed_dict[each_node] \
105
+ + self.ext_id_embed[each_prod_rule.rhs.node_attr(each_node)['ext_id']]
106
+
107
+ for each_layer in range(self.num_layers):
108
+ next_layer_embed_dict = {}
109
+ for each_edge in each_prod_rule.rhs.edges:
110
+ v = layer_wise_embed_dict[each_edge]
111
+ for each_node in each_prod_rule.rhs.nodes_in_edge(each_edge):
112
+ v = v + layer_wise_embed_dict[each_node]
113
+ next_layer_embed_dict[each_edge] = self.layer2layer_activation(self.layer2layer_list[each_layer](v))
114
+ out[each_batch_idx, each_idx, :] \
115
+ = out[each_batch_idx, each_idx, :] + self.layer2out_activation(self.layer2out_list[each_layer](v))
116
+ for each_node in each_prod_rule.rhs.nodes:
117
+ v = layer_wise_embed_dict[each_node]
118
+ for each_edge in each_prod_rule.rhs.adj_edges(each_node):
119
+ v = v + layer_wise_embed_dict[each_edge]
120
+ next_layer_embed_dict[each_node] = self.layer2layer_activation(self.layer2layer_list[each_layer](v))
121
+ out[each_batch_idx, each_idx, :]\
122
+ = out[each_batch_idx, each_idx, :] + self.layer2out_activation(self.layer2out_list[each_layer](v))
123
+ layer_wise_embed_dict = next_layer_embed_dict
124
+
125
+ return out
126
+
127
+
128
+ class MolecularProdRuleEmbeddingLastLayer(nn.Module):
129
+
130
+ ''' molecular fingerprint layer
131
+ '''
132
+
133
+ def __init__(self, prod_rule_corpus, layer2layer_activation, layer2out_activation,
134
+ out_dim=32, element_embed_dim=32,
135
+ num_layers=3, padding_idx=None, use_gpu=False):
136
+ super().__init__()
137
+ if padding_idx is not None:
138
+ assert padding_idx == -1, 'padding_idx must be -1.'
139
+ self.prod_rule_corpus = prod_rule_corpus
140
+ self.layer2layer_activation = layer2layer_activation
141
+ self.layer2out_activation = layer2out_activation
142
+ self.out_dim = out_dim
143
+ self.element_embed_dim = element_embed_dim
144
+ self.num_layers = num_layers
145
+ self.padding_idx = padding_idx
146
+ self.use_gpu = use_gpu
147
+
148
+ self.layer2layer_list = []
149
+ self.layer2out_list = []
150
+
151
+ if self.use_gpu:
152
+ self.atom_embed = nn.Embedding(self.prod_rule_corpus.num_edge_symbol, self.element_embed_dim).cuda()
153
+ self.bond_embed = nn.Embedding(self.prod_rule_corpus.num_node_symbol, self.element_embed_dim).cuda()
154
+ for _ in range(num_layers+1):
155
+ self.layer2layer_list.append(nn.Linear(self.element_embed_dim, self.element_embed_dim).cuda())
156
+ self.layer2out_list.append(nn.Linear(self.element_embed_dim, self.out_dim).cuda())
157
+ else:
158
+ self.atom_embed = nn.Embedding(self.prod_rule_corpus.num_edge_symbol, self.element_embed_dim)
159
+ self.bond_embed = nn.Embedding(self.prod_rule_corpus.num_node_symbol, self.element_embed_dim)
160
+ for _ in range(num_layers+1):
161
+ self.layer2layer_list.append(nn.Linear(self.element_embed_dim, self.element_embed_dim))
162
+ self.layer2out_list.append(nn.Linear(self.element_embed_dim, self.out_dim))
163
+
164
+
165
+ def forward(self, prod_rule_idx_seq):
166
+ ''' forward model for mini-batch
167
+
168
+ Parameters
169
+ ----------
170
+ prod_rule_idx_seq : (batch_size, length)
171
+
172
+ Returns
173
+ -------
174
+ Variable, shape (batch_size, length, out_dim)
175
+ '''
176
+ batch_size, length = prod_rule_idx_seq.shape
177
+ if self.use_gpu:
178
+ out = Variable(torch.zeros((batch_size, length, self.out_dim))).cuda()
179
+ else:
180
+ out = Variable(torch.zeros((batch_size, length, self.out_dim)))
181
+ for each_batch_idx in range(batch_size):
182
+ for each_idx in range(length):
183
+ if int(prod_rule_idx_seq[each_batch_idx, each_idx]) == len(self.prod_rule_corpus.prod_rule_list):
184
+ continue
185
+ else:
186
+ each_prod_rule = self.prod_rule_corpus.prod_rule_list[int(prod_rule_idx_seq[each_batch_idx, each_idx])]
187
+
188
+ if self.use_gpu:
189
+ layer_wise_embed_dict = {each_edge: self.atom_embed(
190
+ Variable(torch.LongTensor(
191
+ [each_prod_rule.rhs.edge_attr(each_edge)['symbol_idx']]
192
+ ), requires_grad=False).cuda())
193
+ for each_edge in each_prod_rule.rhs.edges}
194
+ layer_wise_embed_dict.update({each_node: self.bond_embed(
195
+ Variable(
196
+ torch.LongTensor([
197
+ each_prod_rule.rhs.node_attr(each_node)['symbol_idx']]),
198
+ requires_grad=False).cuda()
199
+ ) for each_node in each_prod_rule.rhs.nodes})
200
+ else:
201
+ layer_wise_embed_dict = {each_edge: self.atom_embed(
202
+ Variable(torch.LongTensor(
203
+ [each_prod_rule.rhs.edge_attr(each_edge)['symbol_idx']]
204
+ ), requires_grad=False))
205
+ for each_edge in each_prod_rule.rhs.edges}
206
+ layer_wise_embed_dict.update({each_node: self.bond_embed(
207
+ Variable(
208
+ torch.LongTensor([
209
+ each_prod_rule.rhs.node_attr(each_node)['symbol_idx']]),
210
+ requires_grad=False)
211
+ ) for each_node in each_prod_rule.rhs.nodes})
212
+
213
+ for each_layer in range(self.num_layers):
214
+ next_layer_embed_dict = {}
215
+ for each_edge in each_prod_rule.rhs.edges:
216
+ v = layer_wise_embed_dict[each_edge]
217
+ for each_node in each_prod_rule.rhs.nodes_in_edge(each_edge):
218
+ v += layer_wise_embed_dict[each_node]
219
+ next_layer_embed_dict[each_edge] = self.layer2layer_activation(self.layer2layer_list[each_layer](v))
220
+ for each_node in each_prod_rule.rhs.nodes:
221
+ v = layer_wise_embed_dict[each_node]
222
+ for each_edge in each_prod_rule.rhs.adj_edges(each_node):
223
+ v += layer_wise_embed_dict[each_edge]
224
+ next_layer_embed_dict[each_node] = self.layer2layer_activation(self.layer2layer_list[each_layer](v))
225
+ layer_wise_embed_dict = next_layer_embed_dict
226
+ for each_edge in each_prod_rule.rhs.edges:
227
+ out[each_batch_idx, each_idx, :] = self.layer2out_activation(self.layer2out_list[self.num_layers](v))
228
+ for each_edge in each_prod_rule.rhs.edges:
229
+ out[each_batch_idx, each_idx, :] = self.layer2out_activation(self.layer2out_list[self.num_layers](v))
230
+
231
+ return out
232
+
233
+
234
+ class MolecularProdRuleEmbeddingUsingFeatures(nn.Module):
235
+
236
+ ''' molecular fingerprint layer
237
+ '''
238
+
239
+ def __init__(self, prod_rule_corpus, layer2layer_activation, layer2out_activation,
240
+ out_dim=32, num_layers=3, padding_idx=None, use_gpu=False):
241
+ super().__init__()
242
+ if padding_idx is not None:
243
+ assert padding_idx == -1, 'padding_idx must be -1.'
244
+ self.feature_dict, self.feature_dim = prod_rule_corpus.construct_feature_vectors()
245
+ self.prod_rule_corpus = prod_rule_corpus
246
+ self.layer2layer_activation = layer2layer_activation
247
+ self.layer2out_activation = layer2out_activation
248
+ self.out_dim = out_dim
249
+ self.num_layers = num_layers
250
+ self.padding_idx = padding_idx
251
+ self.use_gpu = use_gpu
252
+
253
+ self.layer2layer_list = []
254
+ self.layer2out_list = []
255
+
256
+ if self.use_gpu:
257
+ for each_key in self.feature_dict:
258
+ self.feature_dict[each_key] = self.feature_dict[each_key].to_dense().cuda()
259
+ for _ in range(num_layers):
260
+ self.layer2layer_list.append(nn.Linear(self.feature_dim, self.feature_dim).cuda())
261
+ self.layer2out_list.append(nn.Linear(self.feature_dim, self.out_dim).cuda())
262
+ else:
263
+ for _ in range(num_layers):
264
+ self.layer2layer_list.append(nn.Linear(self.feature_dim, self.feature_dim))
265
+ self.layer2out_list.append(nn.Linear(self.feature_dim, self.out_dim))
266
+
267
+
268
+ def forward(self, prod_rule_idx_seq):
269
+ ''' forward model for mini-batch
270
+
271
+ Parameters
272
+ ----------
273
+ prod_rule_idx_seq : (batch_size, length)
274
+
275
+ Returns
276
+ -------
277
+ Variable, shape (batch_size, length, out_dim)
278
+ '''
279
+ batch_size, length = prod_rule_idx_seq.shape
280
+ if self.use_gpu:
281
+ out = Variable(torch.zeros((batch_size, length, self.out_dim))).cuda()
282
+ else:
283
+ out = Variable(torch.zeros((batch_size, length, self.out_dim)))
284
+ for each_batch_idx in range(batch_size):
285
+ for each_idx in range(length):
286
+ if int(prod_rule_idx_seq[each_batch_idx, each_idx]) == len(self.prod_rule_corpus.prod_rule_list):
287
+ continue
288
+ else:
289
+ each_prod_rule = self.prod_rule_corpus.prod_rule_list[int(prod_rule_idx_seq[each_batch_idx, each_idx])]
290
+ edge_list = sorted(list(each_prod_rule.rhs.edges))
291
+ node_list = sorted(list(each_prod_rule.rhs.nodes))
292
+ adj_mat = torch.FloatTensor(each_prod_rule.rhs_adj_mat(edge_list + node_list).todense() + np.identity(len(edge_list)+len(node_list)))
293
+ if self.use_gpu:
294
+ adj_mat = adj_mat.cuda()
295
+ layer_wise_embed = [
296
+ self.feature_dict[each_prod_rule.rhs.edge_attr(each_edge)['symbol']]
297
+ for each_edge in edge_list]\
298
+ + [self.feature_dict[each_prod_rule.rhs.node_attr(each_node)['symbol']]
299
+ for each_node in node_list]
300
+ for each_node in each_prod_rule.ext_node.values():
301
+ layer_wise_embed[each_prod_rule.rhs.num_edges + node_list.index(each_node)] \
302
+ = layer_wise_embed[each_prod_rule.rhs.num_edges + node_list.index(each_node)] \
303
+ + self.feature_dict[('ext_id', each_prod_rule.rhs.node_attr(each_node)['ext_id'])]
304
+ layer_wise_embed = torch.stack(layer_wise_embed)
305
+
306
+ for each_layer in range(self.num_layers):
307
+ message = adj_mat @ layer_wise_embed
308
+ next_layer_embed = self.layer2layer_activation(self.layer2layer_list[each_layer](message))
309
+ out[each_batch_idx, each_idx, :] \
310
+ = out[each_batch_idx, each_idx, :] \
311
+ + self.layer2out_activation(self.layer2out_list[each_layer](message)).sum(dim=0)
312
+ layer_wise_embed = next_layer_embed
313
+ return out
images/mhg_example.png ADDED
images/mhg_example1.png ADDED
images/mhg_example2.png ADDED
load.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding:utf-8 -*-
2
+ # Rhizome
3
+ # Version beta 0.0, August 2023
4
+ # Property of IBM Research, Accelerated Discovery
5
+ #
6
+
7
+ import os
8
+ import pickle
9
+ import sys
10
+
11
+ from rdkit import Chem
12
+ import torch
13
+ from torch_geometric.utils.smiles import from_smiles
14
+
15
+ from typing import Any, Dict, List, Optional, Union
16
+ from typing_extensions import Self
17
+
18
+ from .graph_grammar.io.smi import hg_to_mol
19
+ from .models.mhgvae import GrammarGINVAE
20
+
21
+
22
+ class PretrainedModelWrapper:
23
+ model: GrammarGINVAE
24
+
25
+ def __init__(self, model_dict: Dict[str, Any]) -> None:
26
+ json_params = model_dict['gnn_params']
27
+ encoder_params = json_params['encoder_params']
28
+ encoder_params['node_feature_size'] = model_dict['num_features']
29
+ encoder_params['edge_feature_size'] = model_dict['num_edge_features']
30
+ self.model = GrammarGINVAE(model_dict['hrg'], rank=-1, encoder_params=encoder_params,
31
+ decoder_params=json_params['decoder_params'],
32
+ prod_rule_embed_params=json_params["prod_rule_embed_params"],
33
+ batch_size=512, max_len=model_dict['max_length'])
34
+ self.model.load_state_dict(model_dict['model_state_dict'])
35
+
36
+ self.model.eval()
37
+
38
+ def to(self, device: Union[str, int, torch.device]) -> Self:
39
+ dev_type = type(device)
40
+ if dev_type != torch.device:
41
+ if dev_type == str or torch.cuda.is_available():
42
+ device = torch.device(device)
43
+ else:
44
+ device = torch.device("mps", device)
45
+
46
+ self.model = self.model.to(device)
47
+ return self
48
+
49
+ def encode(self, data: List[str]) -> List[torch.tensor]:
50
+ # Need to encode them into a graph nn
51
+ output = []
52
+ for d in data:
53
+ params = next(self.model.parameters())
54
+ g = from_smiles(d)
55
+ if (g.cpu() and params != 'cpu') or (not g.cpu() and params == 'cpu'):
56
+ g.to(params.device)
57
+ ltvec = self.model.graph_embed(g.x, g.edge_index, g.edge_attr, g.batch)
58
+ output.append(ltvec[0])
59
+ return output
60
+
61
+ def decode(self, data: List[torch.tensor]) -> List[str]:
62
+ output = []
63
+ for d in data:
64
+ mu, logvar = self.model.get_mean_var(d.unsqueeze(0))
65
+ z = self.model.reparameterize(mu, logvar)
66
+ flags, _, hgs = self.model.decode(z)
67
+ if flags[0]:
68
+ reconstructed_mol, _ = hg_to_mol(hgs[0], True)
69
+ output.append(Chem.MolToSmiles(reconstructed_mol))
70
+ else:
71
+ output.append(None)
72
+ return output
73
+
74
+
75
+ def load(model_name: str = "models/mhg_model/pickles/mhggnn_pretrained_model_0724_2023.pickle") -> Optional[
76
+ PretrainedModelWrapper]:
77
+ for p in sys.path:
78
+ file = p + "/" + model_name
79
+ if os.path.isfile(file):
80
+ with open(file, "rb") as f:
81
+ model_dict = pickle.load(f)
82
+ return PretrainedModelWrapper(model_dict)
83
+ return None
mhg_gnn.egg-info/PKG-INFO ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Metadata-Version: 2.1
2
+ Name: mhg-gnn
3
+ Version: 0.0
4
+ Summary: Package for mhg-gnn
5
+ Author: team
6
+ License: TBD
7
+ Classifier: Programming Language :: Python :: 3
8
+ Classifier: Programming Language :: Python :: 3.9
9
+ Description-Content-Type: text/markdown
10
+ Requires-Dist: networkx>=2.8
11
+ Requires-Dist: numpy<2.0.0,>=1.23.5
12
+ Requires-Dist: pandas>=1.5.3
13
+ Requires-Dist: rdkit-pypi<2023.9.6,>=2022.9.4
14
+ Requires-Dist: torch>=2.0.0
15
+ Requires-Dist: torchinfo>=1.8.0
16
+ Requires-Dist: torch-geometric>=2.3.1
17
+
18
+ # mhg-gnn
19
+
20
+ This repository provides PyTorch source code assosiated with our publication, "MHG-GNN: Combination of Molecular Hypergraph Grammar with Graph Neural Network"
21
+
22
+ **Paper:** [Arxiv Link](https://arxiv.org/pdf/2309.16374)
23
+
24
+ For more information contact: SEIJITKD@jp.ibm.com
25
+
26
+ ![mhg-gnn](images/mhg_example1.png)
27
+
28
+ ## Introduction
29
+
30
+ We present MHG-GNN, an autoencoder architecture
31
+ that has an encoder based on GNN and a decoder based on a sequential model with MHG.
32
+ Since the encoder is a GNN variant, MHG-GNN can accept any molecule as input, and
33
+ demonstrate high predictive performance on molecular graph data.
34
+ In addition, the decoder inherits the theoretical guarantee of MHG on always generating a structurally valid molecule as output.
35
+
36
+ ## Table of Contents
37
+
38
+ 1. [Getting Started](#getting-started)
39
+ 1. [Pretrained Models and Training Logs](#pretrained-models-and-training-logs)
40
+ 2. [Replicating Conda Environment](#replicating-conda-environment)
41
+ 2. [Feature Extraction](#feature-extraction)
42
+
43
+ ## Getting Started
44
+
45
+ **This code and environment have been tested on Intel E5-2667 CPUs at 3.30GHz and NVIDIA A100 Tensor Core GPUs.**
46
+
47
+ ### Pretrained Models and Training Logs
48
+
49
+ We provide checkpoints of the MHG-GNN model pre-trained on a dataset of ~1.34M molecules curated from PubChem. (later) For model weights: [HuggingFace Link]()
50
+
51
+ Add the MHG-GNN `pre-trained weights.pt` to the `models/` directory according to your needs.
52
+
53
+ ### Replacicating Conda Environment
54
+
55
+ Follow these steps to replicate our Conda environment and install the necessary libraries:
56
+
57
+ ```
58
+ conda create --name mhg-gnn-env python=3.8.18
59
+ conda activate mhg-gnn-env
60
+ ```
61
+
62
+ #### Install Packages with Conda
63
+
64
+ ```
65
+ conda install -c conda-forge networkx=2.8
66
+ conda install numpy=1.23.5
67
+ # conda install -c conda-forge rdkit=2022.9.4
68
+ conda install pytorch=2.0.0 torchvision torchaudio -c pytorch
69
+ conda install -c conda-forge torchinfo=1.8.0
70
+ conda install pyg -c pyg
71
+ ```
72
+
73
+ #### Install Packages with pip
74
+ ```
75
+ pip install rdkit torch-nl==0.3 torch-scatter torch-sparse
76
+ ```
77
+
78
+ ## Feature Extraction
79
+
80
+ The example notebook [mhg-gnn_encoder_decoder_example.ipynb](notebooks/mhg-gnn_encoder_decoder_example.ipynb) contains code to load checkpoint files and use the pre-trained model for encoder and decoder tasks.
81
+
82
+ To load mhg-gnn, you can simply use:
83
+
84
+ ```python
85
+ import torch
86
+ import load
87
+
88
+ model = load.load()
89
+ ```
90
+
91
+ To encode SMILES into embeddings, you can use:
92
+
93
+ ```python
94
+ with torch.no_grad():
95
+ repr = model.encode(["CCO", "O=C=O", "OC(=O)c1ccccc1C(=O)O"])
96
+ ```
97
+
98
+ For decoder, you can use the function, so you can return from embeddings to SMILES strings:
99
+
100
+ ```python
101
+ orig = model.decode(repr)
102
+ ```
mhg_gnn.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ README.md
2
+ setup.cfg
3
+ setup.py
4
+ ./graph_grammar/__init__.py
5
+ ./graph_grammar/hypergraph.py
6
+ ./graph_grammar/algo/__init__.py
7
+ ./graph_grammar/algo/tree_decomposition.py
8
+ ./graph_grammar/graph_grammar/__init__.py
9
+ ./graph_grammar/graph_grammar/base.py
10
+ ./graph_grammar/graph_grammar/corpus.py
11
+ ./graph_grammar/graph_grammar/hrg.py
12
+ ./graph_grammar/graph_grammar/symbols.py
13
+ ./graph_grammar/graph_grammar/utils.py
14
+ ./graph_grammar/io/__init__.py
15
+ ./graph_grammar/io/smi.py
16
+ ./graph_grammar/nn/__init__.py
17
+ ./graph_grammar/nn/dataset.py
18
+ ./graph_grammar/nn/decoder.py
19
+ ./graph_grammar/nn/encoder.py
20
+ ./graph_grammar/nn/graph.py
21
+ ./models/__init__.py
22
+ ./models/mhgvae.py
23
+ graph_grammar/__init__.py
24
+ graph_grammar/hypergraph.py
25
+ graph_grammar/algo/__init__.py
26
+ graph_grammar/algo/tree_decomposition.py
27
+ graph_grammar/graph_grammar/__init__.py
28
+ graph_grammar/graph_grammar/base.py
29
+ graph_grammar/graph_grammar/corpus.py
30
+ graph_grammar/graph_grammar/hrg.py
31
+ graph_grammar/graph_grammar/symbols.py
32
+ graph_grammar/graph_grammar/utils.py
33
+ graph_grammar/io/__init__.py
34
+ graph_grammar/io/smi.py
35
+ graph_grammar/nn/__init__.py
36
+ graph_grammar/nn/dataset.py
37
+ graph_grammar/nn/decoder.py
38
+ graph_grammar/nn/encoder.py
39
+ graph_grammar/nn/graph.py
40
+ mhg_gnn.egg-info/PKG-INFO
41
+ mhg_gnn.egg-info/SOURCES.txt
42
+ mhg_gnn.egg-info/dependency_links.txt
43
+ mhg_gnn.egg-info/requires.txt
44
+ mhg_gnn.egg-info/top_level.txt
45
+ models/__init__.py
46
+ models/mhgvae.py