diff --git a/.DS_Store b/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..bc34d1d00a09644e3a67e48e4c78fdc9ba58d104
Binary files /dev/null and b/.DS_Store differ
diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..26d33521af10bcc7fd8cea344038eaaeb78d0ef5
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,3 @@
+# Default ignored files
+/shelf/
+/workspace.xml
diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml
new file mode 100644
index 0000000000000000000000000000000000000000..2207299f02f75c9370600b47ce77051ca3b54130
--- /dev/null
+++ b/.idea/inspectionProfiles/Project_Default.xml
@@ -0,0 +1,29 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000000000000000000000000000000000000..105ce2da2d6447d11dfe32bfb846c3d5b199fc99
--- /dev/null
+++ b/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/materials.mhg-ged.iml b/.idea/materials.mhg-ged.iml
new file mode 100644
index 0000000000000000000000000000000000000000..039314de6c082718b0c4495bd64f8df7279e477f
--- /dev/null
+++ b/.idea/materials.mhg-ged.iml
@@ -0,0 +1,12 @@
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 0000000000000000000000000000000000000000..6c89280f83fddd489fc6c07c979f7f1ba7c4969f
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000000000000000000000000000000000000..35eb1ddfbbc029bcab630581847471d7f238ec53
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/README.md b/README.md
index 7b95401dc46245ac339fc25059d4a56d90b4cde5..9709c8cca394a4de06e80fb7ea7a711b8d3e6749 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,78 @@
----
-license: apache-2.0
----
+---
+license: apache-2.0
+---
+# mhg-gnn
+
+This repository provides PyTorch source code associated with our publication, "MHG-GNN: Combination of Molecular Hypergraph Grammar with Graph Neural Network"
+
+**Paper:** [Arxiv Link](https://arxiv.org/pdf/2309.16374)
+
+![mhg-gnn](images/mhg_example1.png)
+
+## Introduction
+
+We present MHG-GNN, an autoencoder architecture
+that has an encoder based on GNN and a decoder based on a sequential model with MHG.
+Since the encoder is a GNN variant, MHG-GNN can accept any molecule as input, and
+demonstrates high predictive performance on molecular graph data.
+In addition, the decoder inherits the theoretical guarantee of MHG on always generating a structurally valid molecule as output.
+
+## Table of Contents
+
+1. [Getting Started](#getting-started)
+ 1. [Pretrained Models and Training Logs](#pretrained-models-and-training-logs)
+ 2. [Installation](#installation)
+2. [Feature Extraction](#feature-extraction)
+
+## Getting Started
+
+**This code and environment have been tested on Intel E5-2667 CPUs at 3.30GHz and NVIDIA A100 Tensor Core GPUs.**
+
+### Pretrained Models and Training Logs
+
+We provide checkpoints of the MHG-GNN model pre-trained on a dataset of ~1.34M molecules curated from PubChem. For model weights, see the [HuggingFace Link]() (link to be added).
+
+Add the MHG-GNN `pre-trained weights.pt` to the `models/` directory according to your needs.
+
+### Installation
+
+We recommend creating a virtual environment. For example:
+
+```
+python3 -m venv .venv
+. .venv/bin/activate
+```
+
+Type the following command once the virtual environment is activated:
+
+```
+git clone git@github.ibm.com:CMD-TRL/mhg-gnn.git
+cd ./mhg-gnn
+pip install .
+```
+
+## Feature Extraction
+
+The example notebook [mhg-gnn_encoder_decoder_example.ipynb](notebooks/mhg-gnn_encoder_decoder_example.ipynb) contains code to load checkpoint files and use the pre-trained model for encoder and decoder tasks.
+
+To load mhg-gnn, you can simply use:
+
+```python
+import torch
+import load
+
+model = load.load()
+```
+
+To encode SMILES into embeddings, you can use:
+
+```python
+with torch.no_grad():
+ repr = model.encode(["CCO", "O=C=O", "OC(=O)c1ccccc1C(=O)O"])
+```
+
+To decode the embeddings back into SMILES strings, you can use:
+
+```python
+orig = model.decode(repr)
+```
\ No newline at end of file
diff --git a/__init__.py b/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed12f0d3c8ed176384f16d26197882c9ad47a36c
--- /dev/null
+++ b/__init__.py
@@ -0,0 +1,5 @@
+# -*- coding:utf-8 -*-
+# Rhizome
+# Version beta 0.0, August 2023
+# Property of IBM Research, Accelerated Discovery
+#
\ No newline at end of file
diff --git a/__pycache__/__init__.cpython-310.pyc b/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..82cf6ca01ddef5690bc9ced78d22113c2a6a7528
Binary files /dev/null and b/__pycache__/__init__.cpython-310.pyc differ
diff --git a/__pycache__/load.cpython-310.pyc b/__pycache__/load.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..96a31d1e584591bc095c35e8e539acd4f23a1c4e
Binary files /dev/null and b/__pycache__/load.cpython-310.pyc differ
diff --git a/graph_grammar/.DS_Store b/graph_grammar/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..c5f5f83e31db413d50194276d0d7497e37b1f13a
Binary files /dev/null and b/graph_grammar/.DS_Store differ
diff --git a/graph_grammar/__init__.py b/graph_grammar/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..26f82acdc3d0383157745e5f0fe8ddd870325145
--- /dev/null
+++ b/graph_grammar/__init__.py
@@ -0,0 +1,19 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Rhizome
+# Version beta 0.0, August 2023
+# Property of IBM Research, Accelerated Discovery
+#
+"""
+PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
+OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
+THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
+"""
+
+""" Title """
+
+__author__ = "Hiroshi Kajino "
+__copyright__ = "(c) Copyright IBM Corp. 2018"
+__version__ = "0.1"
+__date__ = "Jan 1 2018"
+
diff --git a/graph_grammar/__pycache__/__init__.cpython-310.pyc b/graph_grammar/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9ced0ce9fa98dbac640cd04dcdbeb90278633020
Binary files /dev/null and b/graph_grammar/__pycache__/__init__.cpython-310.pyc differ
diff --git a/graph_grammar/__pycache__/hypergraph.cpython-310.pyc b/graph_grammar/__pycache__/hypergraph.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..da6dafdb54e559d55d716ecb81242709a112a1b8
Binary files /dev/null and b/graph_grammar/__pycache__/hypergraph.cpython-310.pyc differ
diff --git a/graph_grammar/algo/__init__.py b/graph_grammar/algo/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e6f597c0a768b8dd64708ec70fe7071000953ce8
--- /dev/null
+++ b/graph_grammar/algo/__init__.py
@@ -0,0 +1,20 @@
+#!/usr/bin/env python
+# -*- coding:utf-8 -*-
+# Rhizome
+# Version beta 0.0, August 2023
+# Property of IBM Research, Accelerated Discovery
+#
+
+"""
+PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
+OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
+THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
+"""
+
+""" Title """
+
+__author__ = "Hiroshi Kajino "
+__copyright__ = "(c) Copyright IBM Corp. 2018"
+__version__ = "0.1"
+__date__ = "Jan 1 2018"
+
diff --git a/graph_grammar/algo/__pycache__/__init__.cpython-310.pyc b/graph_grammar/algo/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..18bd58db664f9fcabf573dcdc21fd6e7ef497e89
Binary files /dev/null and b/graph_grammar/algo/__pycache__/__init__.cpython-310.pyc differ
diff --git a/graph_grammar/algo/__pycache__/tree_decomposition.cpython-310.pyc b/graph_grammar/algo/__pycache__/tree_decomposition.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d050b46a7e787f475c853001d7510d3c8b1d4031
Binary files /dev/null and b/graph_grammar/algo/__pycache__/tree_decomposition.cpython-310.pyc differ
diff --git a/graph_grammar/algo/tree_decomposition.py b/graph_grammar/algo/tree_decomposition.py
new file mode 100644
index 0000000000000000000000000000000000000000..81cb7748f573c99597c8f0658555b9efd1171cfa
--- /dev/null
+++ b/graph_grammar/algo/tree_decomposition.py
@@ -0,0 +1,821 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Rhizome
+# Version beta 0.0, August 2023
+# Property of IBM Research, Accelerated Discovery
+#
+
+"""
+PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
+OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
+THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
+"""
+
+""" Title """
+
+__author__ = "Hiroshi Kajino "
+__copyright__ = "(c) Copyright IBM Corp. 2017"
+__version__ = "0.1"
+__date__ = "Dec 11 2017"
+
+from copy import deepcopy
+from itertools import combinations
+from ..hypergraph import Hypergraph
+import networkx as nx
+import numpy as np
+
+
+class CliqueTree(nx.Graph):
+ ''' clique tree object
+
+ Attributes
+ ----------
+ hg : Hypergraph
+ This hypergraph will be decomposed.
+ root_hg : Hypergraph
+ Hypergraph on the root node.
+ ident_node_dict : dict
+ ident_node_dict[key_node] gives a list of nodes that are identical (i.e., the adjacent hyperedges are common)
+ '''
+ def __init__(self, hg=None, **kwargs):
+ self.hg = deepcopy(hg)
+ if self.hg is not None:
+ self.ident_node_dict = self.hg.get_identical_node_dict()
+ else:
+ self.ident_node_dict = {}
+ super().__init__(**kwargs)
+
+ @property
+ def root_hg(self):
+ ''' return the hypergraph on the root node
+ '''
+ return self.nodes[0]['subhg']
+
+ @root_hg.setter
+ def root_hg(self, hypergraph):
+ ''' set the hypergraph on the root node
+ '''
+ self.nodes[0]['subhg'] = hypergraph
+
+ def insert_subhg(self, subhypergraph: Hypergraph) -> None:
+ ''' insert a subhypergraph, which is extracted from a root hypergraph, into the tree.
+
+ Parameters
+ ----------
+ subhg : Hypergraph
+ '''
+ num_nodes = self.number_of_nodes()
+ self.add_node(num_nodes, subhg=subhypergraph)
+ self.add_edge(num_nodes, 0)
+ adj_nodes = deepcopy(list(self.adj[0].keys()))
+ for each_node in adj_nodes:
+ if len(self.nodes[each_node]["subhg"].nodes.intersection(
+ self.nodes[num_nodes]["subhg"].nodes)\
+ - self.root_hg.nodes) != 0 and each_node != num_nodes:
+ self.remove_edge(0, each_node)
+ self.add_edge(each_node, num_nodes)
+
+ def to_irredundant(self) -> None:
+ ''' convert the clique tree to be irredundant
+ '''
+ for each_node in self.hg.nodes:
+ subtree = self.subgraph([
+ each_tree_node for each_tree_node in self.nodes()\
+ if each_node in self.nodes[each_tree_node]["subhg"].nodes]).copy()
+ leaf_node_list = [x for x in subtree.nodes() if subtree.degree(x)==1]
+ redundant_leaf_node_list = []
+ for each_leaf_node in leaf_node_list:
+ if len(self.nodes[each_leaf_node]["subhg"].adj_edges(each_node)) == 0:
+ redundant_leaf_node_list.append(each_leaf_node)
+ for each_red_leaf_node in redundant_leaf_node_list:
+ current_node = each_red_leaf_node
+ while subtree.degree(current_node) == 1 \
+ and len(subtree.nodes[current_node]["subhg"].adj_edges(each_node)) == 0:
+ self.nodes[current_node]["subhg"].remove_node(each_node)
+ remove_node = current_node
+ current_node = list(dict(subtree[remove_node]).keys())[0]
+ subtree.remove_node(remove_node)
+
+ fixed_node_set = deepcopy(self.nodes)
+ for each_node in fixed_node_set:
+ if self.nodes[each_node]["subhg"].num_edges == 0:
+ if len(self[each_node]) == 1:
+ self.remove_node(each_node)
+ elif len(self[each_node]) == 2:
+ self.add_edge(*self[each_node])
+ self.remove_node(each_node)
+ else:
+ pass
+ else:
+ pass
+
+ redundant = True
+ while redundant:
+ redundant = False
+ fixed_edge_set = deepcopy(self.edges)
+ remove_node_set = set()
+ for node_1, node_2 in fixed_edge_set:
+ if node_1 in remove_node_set or node_2 in remove_node_set:
+ pass
+ else:
+ if self.nodes[node_1]['subhg'].is_subhg(self.nodes[node_2]['subhg']):
+ redundant = True
+ adj_node_list = set(self.adj[node_1]) - {node_2}
+ self.remove_node(node_1)
+ remove_node_set.add(node_1)
+ for each_node in adj_node_list:
+ self.add_edge(node_2, each_node)
+
+ elif self.nodes[node_2]['subhg'].is_subhg(self.nodes[node_1]['subhg']):
+ redundant = True
+ adj_node_list = set(self.adj[node_2]) - {node_1}
+ self.remove_node(node_2)
+ remove_node_set.add(node_2)
+ for each_node in adj_node_list:
+ self.add_edge(node_1, each_node)
+
+ def node_update(self, key_node: str, subhg) -> None:
+        """ given a pair of a hypergraph, H, and its subhypergraph, sH, return the hypergraph H - sH.
+
+ Parameters
+ ----------
+ key_node : str
+ key node that must be removed.
+    subhg : Hypergraph
+ """
+ for each_edge in subhg.edges:
+ self.root_hg.remove_edge(each_edge)
+ self.root_hg.remove_nodes(self.ident_node_dict[key_node])
+
+ adj_node_list = list(subhg.nodes)
+ for each_node in subhg.nodes:
+ if each_node not in self.ident_node_dict[key_node]:
+ if set(self.root_hg.adj_edges(each_node)).issubset(subhg.edges):
+ self.root_hg.remove_node(each_node)
+ adj_node_list.remove(each_node)
+ else:
+ adj_node_list.remove(each_node)
+
+ for each_node_1, each_node_2 in combinations(adj_node_list, 2):
+ if not self.root_hg.is_adj(each_node_1, each_node_2):
+ self.root_hg.add_edge(set([each_node_1, each_node_2]), attr_dict=dict(tmp=True))
+
+ subhg.remove_edges_with_attr({'tmp' : True})
+ self.insert_subhg(subhg)
+
+ def update(self, subhg, remove_nodes=False):
+        """ given a pair of a hypergraph, H, and its subhypergraph, sH, return the hypergraph H - sH.
+
+ Parameters
+ ----------
+    subhg : Hypergraph
+ """
+ for each_edge in subhg.edges:
+ self.root_hg.remove_edge(each_edge)
+ if remove_nodes:
+ remove_edge_list = []
+ for each_edge in self.root_hg.edges:
+ if set(self.root_hg.nodes_in_edge(each_edge)).issubset(subhg.nodes)\
+ and self.root_hg.edge_attr(each_edge).get('tmp', False):
+ remove_edge_list.append(each_edge)
+ self.root_hg.remove_edges(remove_edge_list)
+
+ adj_node_list = list(subhg.nodes)
+ for each_node in subhg.nodes:
+ if self.root_hg.degree(each_node) == 0:
+ self.root_hg.remove_node(each_node)
+ adj_node_list.remove(each_node)
+
+ if len(adj_node_list) != 1 and not remove_nodes:
+ self.root_hg.add_edge(set(adj_node_list), attr_dict=dict(tmp=True))
+ '''
+ else:
+ for each_node_1, each_node_2 in combinations(adj_node_list, 2):
+ if not self.root_hg.is_adj(each_node_1, each_node_2):
+ self.root_hg.add_edge(
+ [each_node_1, each_node_2], attr_dict=dict(tmp=True))
+ '''
+ subhg.remove_edges_with_attr({'tmp':True})
+ self.insert_subhg(subhg)
+
+
+def _get_min_deg_node(hg, ident_node_dict: dict, mode='mol'):
+ if mode == 'standard':
+ degree_dict = hg.degrees()
+ min_deg_node = min(degree_dict, key=degree_dict.get)
+ min_deg_subhg = hg.adj_subhg(min_deg_node, ident_node_dict)
+ return min_deg_node, min_deg_subhg
+ elif mode == 'mol':
+ degree_dict = hg.degrees()
+ min_deg = min(degree_dict.values())
+ min_deg_node_list = [each_node for each_node in hg.nodes if degree_dict[each_node]==min_deg]
+ min_deg_subhg_list = [hg.adj_subhg(each_min_deg_node, ident_node_dict)
+ for each_min_deg_node in min_deg_node_list]
+ best_score = np.inf
+ best_idx = -1
+ for each_idx in range(len(min_deg_subhg_list)):
+ if min_deg_subhg_list[each_idx].num_nodes < best_score:
+ best_idx = each_idx
+ return min_deg_node_list[each_idx], min_deg_subhg_list[each_idx]
+ else:
+ raise ValueError
+
+
+def tree_decomposition(hg, irredundant=True):
+ """ compute a tree decomposition of the input hypergraph
+
+ Parameters
+ ----------
+ hg : Hypergraph
+ hypergraph to be decomposed
+ irredundant : bool
+ if True, irredundant tree decomposition will be computed.
+
+ Returns
+ -------
+ clique_tree : nx.Graph
+ each node contains a subhypergraph of `hg`
+ """
+ org_hg = hg.copy()
+ ident_node_dict = hg.get_identical_node_dict()
+ clique_tree = CliqueTree(org_hg)
+ clique_tree.add_node(0, subhg=org_hg)
+ while True:
+ degree_dict = org_hg.degrees()
+ min_deg_node = min(degree_dict, key=degree_dict.get)
+ min_deg_subhg = org_hg.adj_subhg(min_deg_node, ident_node_dict)
+ if org_hg.nodes == min_deg_subhg.nodes:
+ break
+
+ # org_hg and min_deg_subhg are divided
+ clique_tree.node_update(min_deg_node, min_deg_subhg)
+
+ clique_tree.root_hg.remove_edges_with_attr({'tmp' : True})
+
+ if irredundant:
+ clique_tree.to_irredundant()
+ return clique_tree
+
+
+def tree_decomposition_with_hrg(hg, hrg, irredundant=True, return_root=False):
+ ''' compute a tree decomposition given a hyperedge replacement grammar.
+ the resultant clique tree should induce a less compact HRG.
+
+ Parameters
+ ----------
+ hg : Hypergraph
+ hypergraph to be decomposed
+ hrg : HyperedgeReplacementGrammar
+ current HRG
+ irredundant : bool
+ if True, irredundant tree decomposition will be computed.
+
+ Returns
+ -------
+ clique_tree : nx.Graph
+ each node contains a subhypergraph of `hg`
+ '''
+ org_hg = hg.copy()
+ ident_node_dict = hg.get_identical_node_dict()
+ clique_tree = CliqueTree(org_hg)
+ clique_tree.add_node(0, subhg=org_hg)
+ root_node = 0
+
+ # construct a clique tree using HRG
+ success_any = True
+ while success_any:
+ success_any = False
+ for each_prod_rule in hrg.prod_rule_list:
+ org_hg, success, subhg = each_prod_rule.revert(org_hg, True)
+ if success:
+ if each_prod_rule.is_start_rule: root_node = clique_tree.number_of_nodes()
+ success_any = True
+ subhg.remove_edges_with_attr({'terminal' : False})
+ clique_tree.root_hg = org_hg
+ clique_tree.insert_subhg(subhg)
+
+ clique_tree.root_hg = org_hg
+
+ for each_edge in deepcopy(org_hg.edges):
+ if not org_hg.edge_attr(each_edge)['terminal']:
+ node_list = org_hg.nodes_in_edge(each_edge)
+ org_hg.remove_edge(each_edge)
+
+ for each_node_1, each_node_2 in combinations(node_list, 2):
+ if not org_hg.is_adj(each_node_1, each_node_2):
+ org_hg.add_edge([each_node_1, each_node_2], attr_dict=dict(tmp=True))
+
+ # construct a clique tree using the existing algorithm
+ degree_dict = org_hg.degrees()
+ if degree_dict:
+ while True:
+ min_deg_node, min_deg_subhg = _get_min_deg_node(org_hg, ident_node_dict)
+ if org_hg.nodes == min_deg_subhg.nodes: break
+
+ # org_hg and min_deg_subhg are divided
+ clique_tree.node_update(min_deg_node, min_deg_subhg)
+
+ clique_tree.root_hg.remove_edges_with_attr({'tmp' : True})
+ if irredundant:
+ clique_tree.to_irredundant()
+
+ if return_root:
+ if root_node == 0 and 0 not in clique_tree.nodes:
+ root_node = clique_tree.number_of_nodes()
+ while root_node not in clique_tree.nodes:
+ root_node -= 1
+ elif root_node not in clique_tree.nodes:
+ while root_node not in clique_tree.nodes:
+ root_node -= 1
+ else:
+ pass
+ return clique_tree, root_node
+ else:
+ return clique_tree
+
+
+def tree_decomposition_from_leaf(hg, irredundant=True):
+ """ compute a tree decomposition of the input hypergraph
+
+ Parameters
+ ----------
+ hg : Hypergraph
+ hypergraph to be decomposed
+ irredundant : bool
+ if True, irredundant tree decomposition will be computed.
+
+ Returns
+ -------
+ clique_tree : nx.Graph
+ each node contains a subhypergraph of `hg`
+ """
+ def apply_normal_decomposition(clique_tree):
+ degree_dict = clique_tree.root_hg.degrees()
+ min_deg_node = min(degree_dict, key=degree_dict.get)
+ min_deg_subhg = clique_tree.root_hg.adj_subhg(min_deg_node, clique_tree.ident_node_dict)
+ if clique_tree.root_hg.nodes == min_deg_subhg.nodes:
+ return clique_tree, False
+ clique_tree.node_update(min_deg_node, min_deg_subhg)
+ return clique_tree, True
+
+ def apply_min_edge_deg_decomposition(clique_tree):
+ edge_degree_dict = clique_tree.root_hg.edge_degrees()
+ non_tmp_edge_list = [each_edge for each_edge in clique_tree.root_hg.edges \
+ if not clique_tree.root_hg.edge_attr(each_edge).get('tmp')]
+ if not non_tmp_edge_list:
+ return clique_tree, False
+ min_deg_edge = None
+ min_deg = np.inf
+ for each_edge in non_tmp_edge_list:
+ if min_deg > edge_degree_dict[each_edge]:
+ min_deg_edge = each_edge
+ min_deg = edge_degree_dict[each_edge]
+ node_list = clique_tree.root_hg.nodes_in_edge(min_deg_edge)
+ min_deg_subhg = clique_tree.root_hg.get_subhg(
+ node_list, [min_deg_edge], clique_tree.ident_node_dict)
+ if clique_tree.root_hg.nodes == min_deg_subhg.nodes:
+ return clique_tree, False
+ clique_tree.update(min_deg_subhg)
+ return clique_tree, True
+
+ org_hg = hg.copy()
+ clique_tree = CliqueTree(org_hg)
+ clique_tree.add_node(0, subhg=org_hg)
+
+ success = True
+ while success:
+ clique_tree, success = apply_min_edge_deg_decomposition(clique_tree)
+ if not success:
+ clique_tree, success = apply_normal_decomposition(clique_tree)
+
+ clique_tree.root_hg.remove_edges_with_attr({'tmp' : True})
+ if irredundant:
+ clique_tree.to_irredundant()
+ return clique_tree
+
+def topological_tree_decomposition(
+ hg, irredundant=True, rip_labels=True, shrink_cycle=False, contract_cycles=False):
+ ''' compute a tree decomposition of the input hypergraph
+
+ Parameters
+ ----------
+ hg : Hypergraph
+ hypergraph to be decomposed
+ irredundant : bool
+ if True, irredundant tree decomposition will be computed.
+
+ Returns
+ -------
+ clique_tree : CliqueTree
+ each node contains a subhypergraph of `hg`
+ '''
+ def _contract_tree(clique_tree):
+ ''' contract a single leaf
+
+ Parameters
+ ----------
+ clique_tree : CliqueTree
+
+ Returns
+ -------
+ CliqueTree, bool
+ bool represents whether this operation succeeds or not.
+ '''
+ edge_degree_dict = clique_tree.root_hg.edge_degrees()
+ leaf_edge_list = [each_edge for each_edge in clique_tree.root_hg.edges \
+ if (not clique_tree.root_hg.edge_attr(each_edge).get('tmp'))\
+ and edge_degree_dict[each_edge] == 1]
+ if not leaf_edge_list:
+ return clique_tree, False
+ min_deg_edge = leaf_edge_list[0]
+ node_list = clique_tree.root_hg.nodes_in_edge(min_deg_edge)
+ min_deg_subhg = clique_tree.root_hg.get_subhg(
+ node_list, [min_deg_edge], clique_tree.ident_node_dict)
+ if clique_tree.root_hg.nodes == min_deg_subhg.nodes:
+ return clique_tree, False
+ clique_tree.update(min_deg_subhg)
+ return clique_tree, True
+
+ def _rip_labels_from_cycles(clique_tree, org_hg):
+ ''' rip hyperedge-labels off
+
+ Parameters
+ ----------
+ clique_tree : CliqueTree
+ org_hg : Hypergraph
+
+ Returns
+ -------
+ CliqueTree, bool
+ bool represents whether this operation succeeds or not.
+ '''
+ ident_node_dict = clique_tree.ident_node_dict #hg.get_identical_node_dict()
+ for each_edge in clique_tree.root_hg.edges:
+ if each_edge in org_hg.edges:
+ if org_hg.in_cycle(each_edge):
+ node_list = clique_tree.root_hg.nodes_in_edge(each_edge)
+ subhg = clique_tree.root_hg.get_subhg(
+ node_list, [each_edge], ident_node_dict)
+ if clique_tree.root_hg.nodes == subhg.nodes:
+ return clique_tree, False
+ clique_tree.update(subhg)
+ '''
+ in_cycle_dict = {each_node: org_hg.node_attr(each_node)['is_in_ring'] for each_node in node_list}
+ if not all(in_cycle_dict.values()):
+ node_not_in_cycle = [each_node for each_node in in_cycle_dict.keys() if not in_cycle_dict[each_node]][0]
+ node_list = [node_not_in_cycle]
+ node_list.extend(clique_tree.root_hg.adj_nodes(node_not_in_cycle))
+ edge_list = clique_tree.root_hg.adj_edges(node_not_in_cycle)
+ import pdb; pdb.set_trace()
+ subhg = clique_tree.root_hg.get_subhg(
+ node_list, edge_list, ident_node_dict)
+
+ clique_tree.update(subhg)
+ '''
+ return clique_tree, True
+ return clique_tree, False
+
+ def _shrink_cycle(clique_tree):
+ ''' shrink a cycle
+
+ Parameters
+ ----------
+ clique_tree : CliqueTree
+
+ Returns
+ -------
+ CliqueTree, bool
+ bool represents whether this operation succeeds or not.
+ '''
+ def filter_subhg(subhg, hg, key_node):
+ num_nodes_cycle = 0
+ nodes_in_cycle_list = []
+ for each_node in subhg.nodes:
+ if hg.in_cycle(each_node):
+ num_nodes_cycle += 1
+ if each_node != key_node:
+ nodes_in_cycle_list.append(each_node)
+ if num_nodes_cycle > 3:
+ break
+ if num_nodes_cycle != 3:
+ return False
+ else:
+ for each_edge in hg.edges:
+ if set(nodes_in_cycle_list).issubset(hg.nodes_in_edge(each_edge)):
+ return False
+ return True
+
+ #ident_node_dict = hg.get_identical_node_dict()
+ ident_node_dict = clique_tree.ident_node_dict
+ for each_node in clique_tree.root_hg.nodes:
+ if clique_tree.root_hg.in_cycle(each_node)\
+ and filter_subhg(clique_tree.root_hg.adj_subhg(each_node, ident_node_dict),
+ clique_tree.root_hg,
+ each_node):
+ target_node = each_node
+ target_subhg = clique_tree.root_hg.adj_subhg(target_node, ident_node_dict)
+ if clique_tree.root_hg.nodes == target_subhg.nodes:
+ return clique_tree, False
+ clique_tree.update(target_subhg)
+ return clique_tree, True
+ return clique_tree, False
+
+ def _contract_cycles(clique_tree):
+ '''
+ remove a subhypergraph that looks like a cycle on a leaf.
+
+ Parameters
+ ----------
+ clique_tree : CliqueTree
+
+ Returns
+ -------
+ CliqueTree, bool
+ bool represents whether this operation succeeds or not.
+ '''
+ def _divide_hg(hg):
+ ''' divide a hypergraph into subhypergraphs such that
+ each subhypergraph is connected to each other in a tree-like way.
+
+ Parameters
+ ----------
+ hg : Hypergraph
+
+ Returns
+ -------
+ list of Hypergraphs
+ each element corresponds to a subhypergraph of `hg`
+ '''
+ for each_node in hg.nodes:
+ if hg.is_dividable(each_node):
+ adj_edges_dict = {each_edge: hg.in_cycle(each_edge) for each_edge in hg.adj_edges(each_node)}
+ '''
+ if any(adj_edges_dict.values()):
+ import pdb; pdb.set_trace()
+ edge_in_cycle = [each_key for each_key, each_val in adj_edges_dict.items() if each_val][0]
+ subhg1, subhg2, subhg3 = hg.divide(each_node, edge_in_cycle)
+ return _divide_hg(subhg1) + _divide_hg(subhg2) + _divide_hg(subhg3)
+ else:
+ '''
+ subhg1, subhg2 = hg.divide(each_node)
+ return _divide_hg(subhg1) + _divide_hg(subhg2)
+ return [hg]
+
+ def _is_leaf(hg, divided_subhg) -> bool:
+ ''' judge whether subhg is a leaf-like in the original hypergraph
+
+ Parameters
+ ----------
+ hg : Hypergraph
+ divided_subhg : Hypergraph
+ `divided_subhg` is a subhypergraph of `hg`
+
+ Returns
+ -------
+ bool
+ '''
+ '''
+ adj_edges_set = set([])
+ for each_node in divided_subhg.nodes:
+ adj_edges_set.update(set(hg.adj_edges(each_node)))
+
+
+ _hg = deepcopy(hg)
+ _hg.remove_subhg(divided_subhg)
+ if nx.is_connected(_hg.hg) != (len(adj_edges_set - divided_subhg.edges) == 1):
+ import pdb; pdb.set_trace()
+ return len(adj_edges_set - divided_subhg.edges) == 1
+ '''
+ _hg = deepcopy(hg)
+ _hg.remove_subhg(divided_subhg)
+ return nx.is_connected(_hg.hg)
+
+ subhg_list = _divide_hg(clique_tree.root_hg)
+ if len(subhg_list) == 1:
+ return clique_tree, False
+ else:
+ while len(subhg_list) > 1:
+ max_leaf_subhg = None
+ for each_subhg in subhg_list:
+ if _is_leaf(clique_tree.root_hg, each_subhg):
+ if max_leaf_subhg is None:
+ max_leaf_subhg = each_subhg
+ elif max_leaf_subhg.num_nodes < each_subhg.num_nodes:
+ max_leaf_subhg = each_subhg
+ clique_tree.update(max_leaf_subhg)
+ subhg_list.remove(max_leaf_subhg)
+ return clique_tree, True
+
+ org_hg = hg.copy()
+ clique_tree = CliqueTree(org_hg)
+ clique_tree.add_node(0, subhg=org_hg)
+
+ success = True
+ while success:
+ '''
+ clique_tree, success = _rip_labels_from_cycles(clique_tree, hg)
+ if not success:
+ clique_tree, success = _contract_cycles(clique_tree)
+ '''
+ clique_tree, success = _contract_tree(clique_tree)
+ if not success:
+ if rip_labels:
+ clique_tree, success = _rip_labels_from_cycles(clique_tree, hg)
+ if not success:
+ if shrink_cycle:
+ clique_tree, success = _shrink_cycle(clique_tree)
+ if not success:
+ if contract_cycles:
+ clique_tree, success = _contract_cycles(clique_tree)
+ clique_tree.root_hg.remove_edges_with_attr({'tmp' : True})
+ if irredundant:
+ clique_tree.to_irredundant()
+ return clique_tree
+
+def molecular_tree_decomposition(hg, irredundant=True):
+ """ compute a tree decomposition of the input molecular hypergraph
+
+ Parameters
+ ----------
+ hg : Hypergraph
+ molecular hypergraph to be decomposed
+ irredundant : bool
+ if True, irredundant tree decomposition will be computed.
+
+ Returns
+ -------
+ clique_tree : CliqueTree
+ each node contains a subhypergraph of `hg`
+ """
+ def _divide_hg(hg):
+ ''' divide a hypergraph into subhypergraphs such that
+ each subhypergraph is connected to each other in a tree-like way.
+
+ Parameters
+ ----------
+ hg : Hypergraph
+
+ Returns
+ -------
+ list of Hypergraphs
+ each element corresponds to a subhypergraph of `hg`
+ '''
+ is_ring = False
+ for each_node in hg.nodes:
+ if hg.node_attr(each_node)['is_in_ring']:
+ is_ring = True
+ if not hg.node_attr(each_node)['is_in_ring'] \
+ and hg.degree(each_node) == 2:
+ subhg1, subhg2 = hg.divide(each_node)
+ return _divide_hg(subhg1) + _divide_hg(subhg2)
+
+ if is_ring:
+ subhg_list = []
+ remove_edge_list = []
+ remove_node_list = []
+ for each_edge in hg.edges:
+ node_list = hg.nodes_in_edge(each_edge)
+ subhg = hg.get_subhg(node_list, [each_edge], hg.get_identical_node_dict())
+ subhg_list.append(subhg)
+ remove_edge_list.append(each_edge)
+ for each_node in node_list:
+ if not hg.node_attr(each_node)['is_in_ring']:
+ remove_node_list.append(each_node)
+ hg.remove_edges(remove_edge_list)
+ hg.remove_nodes(remove_node_list, False)
+ return subhg_list + [hg]
+ else:
+ return [hg]
+
+ org_hg = hg.copy()
+ clique_tree = CliqueTree(org_hg)
+ clique_tree.add_node(0, subhg=org_hg)
+
+ subhg_list = _divide_hg(deepcopy(clique_tree.root_hg))
+ #_subhg_list = deepcopy(subhg_list)
+ if len(subhg_list) == 1:
+ pass
+ else:
+ while len(subhg_list) > 1:
+ max_leaf_subhg = None
+ for each_subhg in subhg_list:
+ if _is_leaf(clique_tree.root_hg, each_subhg) and not _is_ring(each_subhg):
+ if max_leaf_subhg is None:
+ max_leaf_subhg = each_subhg
+ elif max_leaf_subhg.num_nodes < each_subhg.num_nodes:
+ max_leaf_subhg = each_subhg
+
+ if max_leaf_subhg is None:
+ for each_subhg in subhg_list:
+ if _is_ring_label(clique_tree.root_hg, each_subhg):
+ if max_leaf_subhg is None:
+ max_leaf_subhg = each_subhg
+ elif max_leaf_subhg.num_nodes < each_subhg.num_nodes:
+ max_leaf_subhg = each_subhg
+ if max_leaf_subhg is not None:
+ clique_tree.update(max_leaf_subhg)
+ subhg_list.remove(max_leaf_subhg)
+ else:
+ for each_subhg in subhg_list:
+ if _is_leaf(clique_tree.root_hg, each_subhg):
+ if max_leaf_subhg is None:
+ max_leaf_subhg = each_subhg
+ elif max_leaf_subhg.num_nodes < each_subhg.num_nodes:
+ max_leaf_subhg = each_subhg
+ if max_leaf_subhg is not None:
+ clique_tree.update(max_leaf_subhg, True)
+ subhg_list.remove(max_leaf_subhg)
+ else:
+ break
+ if len(subhg_list) > 1:
+ '''
+ for each_idx, each_subhg in enumerate(subhg_list):
+ each_subhg.draw(f'{each_idx}', True)
+ clique_tree.root_hg.draw('root', True)
+ import pickle
+ with open('buggy_hg.pkl', 'wb') as f:
+ pickle.dump(hg, f)
+ return clique_tree, subhg_list, _subhg_list
+ '''
+ raise RuntimeError('bug in tree decomposition algorithm')
+ clique_tree.root_hg.remove_edges_with_attr({'tmp' : True})
+
+ '''
+ for each_tree_node in clique_tree.adj[0]:
+ subhg = clique_tree.nodes[each_tree_node]['subhg']
+ for each_edge in subhg.edges:
+ if set(subhg.nodes_in_edge(each_edge)).issubset(clique_tree.root_hg.nodes):
+ clique_tree.root_hg.add_edge(set(subhg.nodes_in_edge(each_edge)), attr_dict=dict(tmp=True))
+ '''
+ if irredundant:
+ clique_tree.to_irredundant()
+ return clique_tree #, _subhg_list
+
+def _is_leaf(hg, subhg) -> bool:
+ ''' judge whether subhg is a leaf-like in the original hypergraph
+
+ Parameters
+ ----------
+ hg : Hypergraph
+ subhg : Hypergraph
+ `subhg` is a subhypergraph of `hg`
+
+ Returns
+ -------
+ bool
+ '''
+ if len(subhg.edges) == 0:
+ adj_edge_set = set([])
+ subhg_edge_set = set([])
+ for each_edge in hg.edges:
+ if set(hg.nodes_in_edge(each_edge)).issubset(subhg.nodes) and hg.edge_attr(each_edge).get('tmp', False):
+ subhg_edge_set.add(each_edge)
+ for each_node in subhg.nodes:
+ adj_edge_set.update(set(hg.adj_edges(each_node)))
+ if subhg_edge_set.issubset(adj_edge_set) and len(adj_edge_set.difference(subhg_edge_set)) == 1:
+ return True
+ else:
+ return False
+ elif len(subhg.edges) == 1:
+ adj_edge_set = set([])
+ subhg_edge_set = subhg.edges
+ for each_node in subhg.nodes:
+ for each_adj_edge in hg.adj_edges(each_node):
+ adj_edge_set.add(each_adj_edge)
+ if subhg_edge_set.issubset(adj_edge_set) and len(adj_edge_set.difference(subhg_edge_set)) == 1:
+ return True
+ else:
+ return False
+ else:
+ raise ValueError('subhg should be nodes only or one-edge hypergraph.')
+
+def _is_ring_label(hg, subhg):
+ if len(subhg.edges) != 1:
+ return False
+ edge_name = list(subhg.edges)[0]
+ #assert edge_name in hg.edges, f'{edge_name}'
+ is_in_ring = False
+ for each_node in subhg.nodes:
+ if subhg.node_attr(each_node)['is_in_ring']:
+ is_in_ring = True
+ else:
+ adj_edge_list = list(hg.adj_edges(each_node))
+ adj_edge_list.remove(edge_name)
+ if len(adj_edge_list) == 1:
+ if not hg.edge_attr(adj_edge_list[0]).get('tmp', False):
+ return False
+ elif len(adj_edge_list) == 0:
+ pass
+ else:
+ raise ValueError
+ if is_in_ring:
+ return True
+ else:
+ return False
+
+def _is_ring(hg):
+ for each_node in hg.nodes:
+ if not hg.node_attr(each_node)['is_in_ring']:
+ return False
+ return True
+
diff --git a/graph_grammar/graph_grammar/__init__.py b/graph_grammar/graph_grammar/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..85e6131daba8a4f601ae72d37e6eb035d9503045
--- /dev/null
+++ b/graph_grammar/graph_grammar/__init__.py
@@ -0,0 +1,20 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Rhizome
+# Version beta 0.0, August 2023
+# Property of IBM Research, Accelerated Discovery
+#
+
+"""
+PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
+OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
+THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
+"""
+
+""" Title """
+
+__author__ = "Hiroshi Kajino "
+__copyright__ = "(c) Copyright IBM Corp. 2018"
+__version__ = "0.1"
+__date__ = "Jan 1 2018"
+
diff --git a/graph_grammar/graph_grammar/__pycache__/__init__.cpython-310.pyc b/graph_grammar/graph_grammar/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..214041c778104f480f6870e377626ee4c7021c7e
Binary files /dev/null and b/graph_grammar/graph_grammar/__pycache__/__init__.cpython-310.pyc differ
diff --git a/graph_grammar/graph_grammar/__pycache__/base.cpython-310.pyc b/graph_grammar/graph_grammar/__pycache__/base.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c09d5018842a72c2fa5452afbbc4e2d19fdda8a3
Binary files /dev/null and b/graph_grammar/graph_grammar/__pycache__/base.cpython-310.pyc differ
diff --git a/graph_grammar/graph_grammar/__pycache__/corpus.cpython-310.pyc b/graph_grammar/graph_grammar/__pycache__/corpus.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..554af9638f063de67d4c34092aa0aacc2b6d7681
Binary files /dev/null and b/graph_grammar/graph_grammar/__pycache__/corpus.cpython-310.pyc differ
diff --git a/graph_grammar/graph_grammar/__pycache__/hrg.cpython-310.pyc b/graph_grammar/graph_grammar/__pycache__/hrg.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..14ca25470c52e7547eba204d4019f21c321558ba
Binary files /dev/null and b/graph_grammar/graph_grammar/__pycache__/hrg.cpython-310.pyc differ
diff --git a/graph_grammar/graph_grammar/__pycache__/symbols.cpython-310.pyc b/graph_grammar/graph_grammar/__pycache__/symbols.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..208465ba58169d7ac7b468ebde4ed3953cdec08b
Binary files /dev/null and b/graph_grammar/graph_grammar/__pycache__/symbols.cpython-310.pyc differ
diff --git a/graph_grammar/graph_grammar/__pycache__/utils.cpython-310.pyc b/graph_grammar/graph_grammar/__pycache__/utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f5fda6d465da19e2b79cbeaf39bde2d7ac1f4a77
Binary files /dev/null and b/graph_grammar/graph_grammar/__pycache__/utils.cpython-310.pyc differ
diff --git a/graph_grammar/graph_grammar/base.py b/graph_grammar/graph_grammar/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..c5977dff873dc004b0e1f1fbae1e65af5b52052c
--- /dev/null
+++ b/graph_grammar/graph_grammar/base.py
@@ -0,0 +1,30 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Rhizome
+# Version beta 0.0, August 2023
+# Property of IBM Research, Accelerated Discovery
+#
+
+"""
+PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
+OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
+THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
+"""
+
+""" Title """
+
+__author__ = "Hiroshi Kajino "
+__copyright__ = "(c) Copyright IBM Corp. 2017"
+__version__ = "0.1"
+__date__ = "Dec 11 2017"
+
+from abc import ABCMeta, abstractmethod
+
class GraphGrammarBase(metaclass=ABCMeta):
    """ Abstract interface for a graph grammar.

    Concrete grammars must be able to learn production rules from data and
    sample new graphs from the learned grammar.
    """

    @abstractmethod
    def learn(self):
        """ Learn production rules from a collection of input graphs. """
        pass

    @abstractmethod
    def sample(self):
        """ Sample a new graph from the learned grammar. """
        pass
diff --git a/graph_grammar/graph_grammar/corpus.py b/graph_grammar/graph_grammar/corpus.py
new file mode 100644
index 0000000000000000000000000000000000000000..dad81a13d0b4873b2929e32a9031f3105811c808
--- /dev/null
+++ b/graph_grammar/graph_grammar/corpus.py
@@ -0,0 +1,152 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Rhizome
+# Version beta 0.0, August 2023
+# Property of IBM Research, Accelerated Discovery
+#
+
+"""
+PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
+OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
+THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
+"""
+
+""" Title """
+
+__author__ = "Hiroshi Kajino "
+__copyright__ = "(c) Copyright IBM Corp. 2018"
+__version__ = "0.1"
+__date__ = "Jun 4 2018"
+
+from collections import Counter
+from functools import partial
+from .utils import _easy_node_match, _edge_match, _node_match, common_node_list, _node_match_prod_rule
+from networkx.algorithms.isomorphism import GraphMatcher
+import os
+
+
class CliqueTreeCorpus(object):

    ''' clique tree corpus

    Maintains the set of unique subhypergraphs (deduplicated up to
    isomorphism) seen across a collection of clique trees, and tags each
    clique-tree node with the index of its subhypergraph in that set.

    Attributes
    ----------
    clique_tree_list : list of CliqueTree
    subhg_list : list of Hypergraph
        unique subhypergraphs; `add_subhg` performs the deduplication.
    '''

    def __init__(self):
        self.clique_tree_list = []
        self.subhg_list = []

    @property
    def size(self):
        ''' number of unique subhypergraphs collected so far '''
        return len(self.subhg_list)

    def add_clique_tree(self, clique_tree):
        ''' register a clique tree, storing each node's subhypergraph index.

        Parameters
        ----------
        clique_tree : networkx.Graph
            each node must carry a 'subhg' attribute; a 'subhg_idx'
            attribute is added to each node as a side effect.
        '''
        for each_node in clique_tree.nodes:
            subhg = clique_tree.nodes[each_node]['subhg']
            # bugfix: add_subhg returns (index, is_new); previously the whole
            # tuple was stored under 'subhg_idx' (cf. add_to_subhg_list, which
            # unpacks the pair correctly).
            subhg_idx, _ = self.add_subhg(subhg)
            clique_tree.nodes[each_node]['subhg_idx'] = subhg_idx
        self.clique_tree_list.append(clique_tree)

    def add_to_subhg_list(self, clique_tree, root_node):
        ''' prepare each clique-tree node's subhypergraph and deduplicate it
        into `subhg_list`.

        Parameters
        ----------
        clique_tree : networkx.Graph
        root_node : node id of the root used for the traversal

        Returns
        -------
        clique_tree : the same tree, with 'subhg_idx' set on every node
        '''
        # First pass: walk the tree from the root and, for every parent-child
        # pair, add a temporary ('tmp') hyperedge over their common nodes to
        # the parent's subhypergraph.
        parent_node_dict = {}
        current_node = None
        parent_node_dict[root_node] = None
        stack = [root_node]
        while stack:
            current_node = stack.pop()
            current_subhg = clique_tree.nodes[current_node]['subhg']
            for each_child in clique_tree.adj[current_node]:
                if each_child != parent_node_dict[current_node]:
                    stack.append(each_child)
                    parent_node_dict[each_child] = current_node
            if parent_node_dict[current_node] is not None:
                parent_subhg = clique_tree.nodes[parent_node_dict[current_node]]['subhg']
                common, _ = common_node_list(parent_subhg, current_subhg)
                parent_subhg.add_edge(set(common), attr_dict={'tmp': True})

        # Second pass: mark the nodes shared with the parent as external
        # ('ext_id') and deduplicate each subhypergraph into subhg_list.
        parent_node_dict = {}
        current_node = None
        parent_node_dict[root_node] = None
        stack = [root_node]
        while stack:
            current_node = stack.pop()
            current_subhg = clique_tree.nodes[current_node]['subhg']
            for each_child in clique_tree.adj[current_node]:
                if each_child != parent_node_dict[current_node]:
                    stack.append(each_child)
                    parent_node_dict[each_child] = current_node
            if parent_node_dict[current_node] is not None:
                parent_subhg = clique_tree.nodes[parent_node_dict[current_node]]['subhg']
                common, _ = common_node_list(parent_subhg, current_subhg)
                for each_idx, each_node in enumerate(common):
                    current_subhg.set_node_attr(each_node, {'ext_id': each_idx})

            subhg_idx, is_new = self.add_subhg(current_subhg)
            clique_tree.nodes[current_node]['subhg_idx'] = subhg_idx
        return clique_tree

    def add_subhg(self, subhg):
        ''' add `subhg` to the corpus unless an isomorphic one is present.

        As a side effect, every node of `subhg` gets an 'order4hrg' attribute
        (a canonical node ordering; copied from the matched corpus entry when
        a duplicate is found).

        Parameters
        ----------
        subhg : Hypergraph

        Returns
        -------
        (int, bool)
            index of `subhg` in `subhg_list`, and whether it was new.
        '''
        if len(self.subhg_list) == 0:
            # First entry: order nodes by the hash of their symbol.
            node_dict = {}
            for each_node in subhg.nodes:
                node_dict[each_node] = subhg.node_attr(each_node)['symbol'].__hash__()
            node_list = []
            for each_key, _ in sorted(node_dict.items(), key=lambda x: x[1]):
                node_list.append(each_key)
            for each_idx, each_node in enumerate(node_list):
                subhg.node_attr(each_node)['order4hrg'] = each_idx
            self.subhg_list.append(subhg)
            return 0, True
        else:
            match = False
            # Cheap necessary conditions (symbol multisets) are checked before
            # running the expensive isomorphism test.
            # NOTE(review): nodes appear to represent bonds and edges atoms in
            # this molecular hypergraph encoding — confirm against Hypergraph.
            subhg_bond_symbol_counter \
                = Counter([subhg.node_attr(each_node)['symbol']
                           for each_node in subhg.nodes])
            subhg_atom_symbol_counter \
                = Counter([subhg.edge_attr(each_edge).get('symbol', None)
                           for each_edge in subhg.edges])
            for each_idx, each_subhg in enumerate(self.subhg_list):
                each_bond_symbol_counter \
                    = Counter([each_subhg.node_attr(each_node)['symbol']
                               for each_node in each_subhg.nodes])
                each_atom_symbol_counter \
                    = Counter([each_subhg.edge_attr(each_edge).get('symbol', None)
                               for each_edge in each_subhg.edges])
                if not match \
                   and (subhg.num_nodes == each_subhg.num_nodes
                        and subhg.num_edges == each_subhg.num_edges
                        and subhg_bond_symbol_counter == each_bond_symbol_counter
                        and subhg_atom_symbol_counter == each_atom_symbol_counter):
                    gm = GraphMatcher(each_subhg.hg,
                                      subhg.hg,
                                      node_match=_easy_node_match,
                                      edge_match=_edge_match)
                    try:
                        isomap = next(gm.isomorphisms_iter())
                        match = True
                        # Copy canonical ordering (and ext ids) from the
                        # matched corpus entry onto the incoming subhg.
                        for each_node in each_subhg.nodes:
                            subhg.node_attr(isomap[each_node])['order4hrg'] \
                                = each_subhg.node_attr(each_node)['order4hrg']
                            if 'ext_id' in each_subhg.node_attr(each_node):
                                subhg.node_attr(isomap[each_node])['ext_id'] \
                                    = each_subhg.node_attr(each_node)['ext_id']
                        return each_idx, False
                    except StopIteration:
                        match = False
            if not match:
                node_dict = {}
                for each_node in subhg.nodes:
                    node_dict[each_node] = subhg.node_attr(each_node)['symbol'].__hash__()
                node_list = []
                for each_key, _ in sorted(node_dict.items(), key=lambda x: x[1]):
                    node_list.append(each_key)
                for each_idx, each_node in enumerate(node_list):
                    subhg.node_attr(each_node)['order4hrg'] = each_idx

                self.subhg_list.append(subhg)
                return len(self.subhg_list) - 1, True
diff --git a/graph_grammar/graph_grammar/hrg.py b/graph_grammar/graph_grammar/hrg.py
new file mode 100644
index 0000000000000000000000000000000000000000..49adf224b06b6b0fbac9040865f46b0c4f20a85a
--- /dev/null
+++ b/graph_grammar/graph_grammar/hrg.py
@@ -0,0 +1,1065 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Rhizome
+# Version beta 0.0, August 2023
+# Property of IBM Research, Accelerated Discovery
+#
+
+"""
+PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
+OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
+THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
+"""
+
+""" Title """
+
+__author__ = "Hiroshi Kajino "
+__copyright__ = "(c) Copyright IBM Corp. 2017"
+__version__ = "0.1"
+__date__ = "Dec 11 2017"
+
+from .corpus import CliqueTreeCorpus
+from .base import GraphGrammarBase
+from .symbols import TSymbol, NTSymbol, BondSymbol
+from .utils import _node_match, _node_match_prod_rule, _edge_match, masked_softmax, common_node_list
+from ..hypergraph import Hypergraph
+from collections import Counter
+from copy import deepcopy
+from ..algo.tree_decomposition import (
+ tree_decomposition,
+ tree_decomposition_with_hrg,
+ tree_decomposition_from_leaf,
+ topological_tree_decomposition,
+ molecular_tree_decomposition)
+from functools import partial
+from networkx.algorithms.isomorphism import GraphMatcher
+from typing import List, Dict, Tuple
+import networkx as nx
+import numpy as np
+import torch
+import os
+import random
+
+DEBUG = False
+
+
class ProductionRule(object):
    """ A class of a production rule

    Attributes
    ----------
    lhs : Hypergraph or None
        the left hand side of the production rule.
        if None, the rule is a starting rule.
    rhs : Hypergraph
        the right hand side of the production rule.
    """
    def __init__(self, lhs, rhs):
        self.lhs = lhs
        self.rhs = rhs

    @property
    def is_start_rule(self) -> bool:
        """ True iff this rule starts a derivation (its lhs has no nodes). """
        return self.lhs.num_nodes == 0

    @property
    def ext_node(self) -> Dict[int, str]:
        """ return a dict of external nodes, mapping ext_id -> lhs node name.

        A start rule has no external nodes.
        """
        if self.is_start_rule:
            return {}
        else:
            ext_node_dict = {}
            for each_node in self.lhs.nodes:
                ext_node_dict[self.lhs.node_attr(each_node)["ext_id"]] = each_node
            return ext_node_dict

    @property
    def lhs_nt_symbol(self) -> 'NTSymbol':
        """ the non-terminal symbol on the lhs; a degree-0 symbol for start rules. """
        if self.is_start_rule:
            return NTSymbol(degree=0, is_aromatic=False, bond_symbol_list=[])
        else:
            return self.lhs.edge_attr(list(self.lhs.edges)[0])['symbol']

    def rhs_adj_mat(self, node_edge_list):
        ''' return the adjacency matrix of rhs of the production rule

        Parameters
        ----------
        node_edge_list : list
            row/column ordering of the returned matrix.
        '''
        return nx.adjacency_matrix(self.rhs.hg, node_edge_list)

    def draw(self, file_path=None):
        ''' draw the rhs hypergraph (see Hypergraph.draw). '''
        return self.rhs.draw(file_path)

    def is_same(self, prod_rule, ignore_order=False):
        """ judge whether this production rule is
        the same as the input one, `prod_rule`

        Parameters
        ----------
        prod_rule : ProductionRule
            production rule to be compared

        Returns
        -------
        is_same : bool
        isomap : dict
            isomorphism of nodes and hyperedges.
            ex) {'bond_42': 'bond_37', 'bond_2': 'bond_1',
                 'e36': 'e11', 'e16': 'e12', 'e25': 'e18',
                 'bond_40': 'bond_38', 'e26': 'e21', 'bond_41': 'bond_39'}.
            key comes from `prod_rule`, value comes from `self`.
        """
        if self.is_start_rule:
            if not prod_rule.is_start_rule:
                return False, {}
        else:
            if prod_rule.is_start_rule:
                return False, {}
            else:
                if prod_rule.lhs.num_nodes != self.lhs.num_nodes:
                    return False, {}

        if prod_rule.rhs.num_nodes != self.rhs.num_nodes:
            return False, {}
        if prod_rule.rhs.num_edges != self.rhs.num_edges:
            return False, {}

        # Cheap necessary conditions before the expensive isomorphism test:
        # the multisets of node symbols and edge symbols must agree.
        subhg_bond_symbol_counter \
            = Counter([prod_rule.rhs.node_attr(each_node)['symbol']
                       for each_node in prod_rule.rhs.nodes])
        each_bond_symbol_counter \
            = Counter([self.rhs.node_attr(each_node)['symbol']
                       for each_node in self.rhs.nodes])
        if subhg_bond_symbol_counter != each_bond_symbol_counter:
            return False, {}

        subhg_atom_symbol_counter \
            = Counter([prod_rule.rhs.edge_attr(each_edge)['symbol']
                       for each_edge in prod_rule.rhs.edges])
        each_atom_symbol_counter \
            = Counter([self.rhs.edge_attr(each_edge)['symbol']
                       for each_edge in self.rhs.edges])
        if subhg_atom_symbol_counter != each_atom_symbol_counter:
            return False, {}

        gm = GraphMatcher(prod_rule.rhs.hg,
                          self.rhs.hg,
                          partial(_node_match_prod_rule,
                                  ignore_order=ignore_order),
                          partial(_edge_match,
                                  ignore_order=ignore_order))
        try:
            return True, next(gm.isomorphisms_iter())
        except StopIteration:
            return False, {}

    def applied_to(self,
                   hg: 'Hypergraph',
                   edge: str) -> Tuple['Hypergraph', List[str]]:
        """ augment `hg` by replacing `edge` with `self.rhs`.

        Parameters
        ----------
        hg : Hypergraph
        edge : str
            `edge` must belong to `hg`

        Returns
        -------
        hg : Hypergraph
            resultant hypergraph
        nt_edge_list : list
            list of non-terminal edges

        Raises
        ------
        ValueError
            if `hg`/`edge` are inconsistent with this rule.
        """
        nt_edge_dict = {}
        if self.is_start_rule:
            # bugfix: the ValueError was previously constructed but never raised.
            if (edge is not None) or (hg is not None):
                raise ValueError("edge and hg must be None for this prod rule.")
            hg = Hypergraph()
            node_map_rhs = {}  # node id in rhs -> node id in hg, where rhs is augmented.
            for num_idx, each_node in enumerate(self.rhs.nodes):
                hg.add_node(f"bond_{num_idx}",
                            attr_dict=self.rhs.node_attr(each_node))
                node_map_rhs[each_node] = f"bond_{num_idx}"
            for each_edge in self.rhs.edges:
                node_list = []
                for each_node in self.rhs.nodes_in_edge(each_edge):
                    node_list.append(node_map_rhs[each_node])
                if isinstance(self.rhs.nodes_in_edge(each_edge), set):
                    node_list = set(node_list)
                edge_id = hg.add_edge(
                    node_list,
                    attr_dict=self.rhs.edge_attr(each_edge))
                if "nt_idx" in hg.edge_attr(edge_id):
                    nt_edge_dict[hg.edge_attr(edge_id)["nt_idx"]] = edge_id
            nt_edge_list = [nt_edge_dict[key] for key in range(len(nt_edge_dict))]
            return hg, nt_edge_list
        else:
            if edge not in hg.edges:
                raise ValueError("the input hyperedge does not exist.")
            if hg.edge_attr(edge)["terminal"]:
                raise ValueError("the input hyperedge is terminal.")
            if hg.edge_attr(edge)['symbol'] != self.lhs_nt_symbol:
                print(hg.edge_attr(edge)['symbol'], self.lhs_nt_symbol)
                raise ValueError("the input hyperedge and lhs have inconsistent number of nodes.")
            if DEBUG:
                # extra consistency check between the lhs and the replaced edge
                for node_idx, each_node in enumerate(hg.nodes_in_edge(edge)):
                    other_node = self.lhs.nodes_in_edge(list(self.lhs.edges)[0])[node_idx]
                    attr = deepcopy(self.lhs.node_attr(other_node))
                    attr.pop('ext_id')
                    if hg.node_attr(each_node) != attr:
                        raise ValueError('node attributes are inconsistent.')

            # order of nodes that belong to the non-terminal edge in hg
            nt_order_dict = {}  # hg_node -> order ("bond_17" : 1)
            nt_order_dict_inv = {}  # order -> hg_node
            for each_idx, each_node in enumerate(hg.nodes_in_edge(edge)):
                nt_order_dict[each_node] = each_idx
                nt_order_dict_inv[each_idx] = each_node

            # construct a node_map_rhs: rhs -> new hg
            node_map_rhs = {}  # node id in rhs -> node id in hg, where rhs is augmented.
            node_idx = hg.num_nodes
            for each_node in self.rhs.nodes:
                if "ext_id" in self.rhs.node_attr(each_node):
                    # external nodes are glued onto the nodes of the removed edge
                    node_map_rhs[each_node] \
                        = nt_order_dict_inv[
                            self.rhs.node_attr(each_node)["ext_id"]]
                else:
                    node_map_rhs[each_node] = f"bond_{node_idx}"
                    node_idx += 1

            # delete non-terminal
            hg.remove_edge(edge)

            # add nodes to hg
            for each_node in self.rhs.nodes:
                hg.add_node(node_map_rhs[each_node],
                            attr_dict=self.rhs.node_attr(each_node))

            # add hyperedges to hg
            for each_edge in self.rhs.edges:
                node_list_hg = []
                for each_node in self.rhs.nodes_in_edge(each_edge):
                    node_list_hg.append(node_map_rhs[each_node])
                edge_id = hg.add_edge(
                    node_list_hg,
                    attr_dict=self.rhs.edge_attr(each_edge))
                if "nt_idx" in hg.edge_attr(edge_id):
                    nt_edge_dict[hg.edge_attr(edge_id)["nt_idx"]] = edge_id
            nt_edge_list = [nt_edge_dict[key] for key in range(len(nt_edge_dict))]
            return hg, nt_edge_list

    def revert(self, hg: 'Hypergraph', return_subhg=False):
        ''' revert applying this production rule.
        i.e., if there exists a subhypergraph that matches the r.h.s. of this production rule,
        this method replaces the subhypergraph with a non-terminal hyperedge.

        Parameters
        ----------
        hg : Hypergraph
            hypergraph to be reverted
        return_subhg : bool
            if True, the removed subhypergraph will be returned.

        Returns
        -------
        hg : Hypergraph
            the resultant hypergraph. if it cannot be reverted, the original one is returned without any replacement.
        success : bool
            this indicates whether reverting succeeded or not.
        '''
        gm = GraphMatcher(hg.hg, self.rhs.hg, node_match=_node_match_prod_rule,
                          edge_match=_edge_match)
        try:
            # in case when the matched subhg is connected to the other part via external nodes and more.
            # NOTE(review): a fresh isomorphism iterator is created above, so
            # only the first match is ever examined; a failing first match
            # returns immediately — confirm this is intended.
            not_iso = True
            while not_iso:
                isomap = next(gm.subgraph_isomorphisms_iter())
                adj_node_set = set([])  # reachable nodes from the internal nodes
                subhg_node_set = set(isomap.keys())  # nodes in subhg
                for each_node in subhg_node_set:
                    adj_node_set.add(each_node)
                    if isomap[each_node] not in self.ext_node.values():
                        adj_node_set.update(hg.hg.adj[each_node])
                if adj_node_set == subhg_node_set:
                    not_iso = False
                else:
                    if return_subhg:
                        return hg, False, Hypergraph()
                    else:
                        return hg, False
            inv_isomap = {v: k for k, v in isomap.items()}
            '''
            isomap = {'e35': 'e8', 'bond_13': 'bond_18', 'bond_14': 'bond_19',
                      'bond_15': 'bond_17', 'e29': 'e23', 'bond_12': 'bond_20'}
            where keys come from `hg` and values come from `self.rhs`
            '''
        except StopIteration:
            if return_subhg:
                return hg, False, Hypergraph()
            else:
                return hg, False

        if return_subhg:
            # build a copy of the matched subhypergraph before removing it
            subhg = Hypergraph()
            for each_node in hg.nodes:
                if each_node in isomap:
                    subhg.add_node(each_node, attr_dict=hg.node_attr(each_node))
            for each_edge in hg.edges:
                if each_edge in isomap:
                    subhg.add_edge(hg.nodes_in_edge(each_edge),
                                   attr_dict=hg.edge_attr(each_edge),
                                   edge_name=each_edge)
            subhg.edge_idx = hg.edge_idx

        # remove subhg except for the external nodes
        for each_key, each_val in isomap.items():
            if each_key.startswith('e'):
                hg.remove_edge(each_key)
        for each_key, each_val in isomap.items():
            if each_key.startswith('bond_'):
                if each_val not in self.ext_node.values():
                    hg.remove_node(each_key)

        # add non-terminal hyperedge over the external nodes
        nt_node_list = []
        for each_ext_id in self.ext_node.keys():
            nt_node_list.append(inv_isomap[self.ext_node[each_ext_id]])

        hg.add_edge(nt_node_list,
                    attr_dict=dict(
                        terminal=False,
                        symbol=self.lhs_nt_symbol))
        if return_subhg:
            return hg, True, subhg
        else:
            return hg, True
+
+
class ProductionRuleCorpus(object):

    '''
    A corpus of production rules.
    This class maintains
    (i) list of unique production rules,
    (ii) list of unique edge symbols (both terminal and non-terminal), and
    (iii) list of unique node symbols.

    Attributes
    ----------
    prod_rule_list : list
        list of unique production rules
    edge_symbol_list : list
        list of unique symbols (including both terminal and non-terminal)
    node_symbol_list : list
        list of node symbols
    nt_symbol_list : list
        list of unique lhs symbols
    ext_id_list : list
        list of ext_ids
    lhs_in_prod_rule : array
        a matrix of lhs vs prod_rule (= lhs_in_prod_rule)
    '''

    def __init__(self):
        self.prod_rule_list = []
        self.edge_symbol_list = []
        self.edge_symbol_dict = {}  # edge symbol -> index into edge_symbol_list
        self.node_symbol_list = []
        self.node_symbol_dict = {}  # node symbol -> index into node_symbol_list
        self.nt_symbol_list = []
        self.ext_id_list = []
        # sparse COO representation of the lhs-vs-rule mask; the dense matrix
        # is built lazily by the `lhs_in_prod_rule` property and invalidated
        # (set to None) whenever a new rule is appended.
        self._lhs_in_prod_rule = None
        self.lhs_in_prod_rule_row_list = []
        self.lhs_in_prod_rule_col_list = []

    @property
    def lhs_in_prod_rule(self):
        ''' dense 0/1 matrix of shape (#nt symbols, #prod rules); entry (i, j)
        is 1 iff rule j's lhs is nt_symbol_list[i]. Built lazily.
        '''
        if self._lhs_in_prod_rule is None:
            # NOTE(review): torch.sparse.FloatTensor is a legacy constructor;
            # torch.sparse_coo_tensor is the modern equivalent — verify before migrating.
            self._lhs_in_prod_rule = torch.sparse.FloatTensor(
                torch.LongTensor(list(zip(self.lhs_in_prod_rule_row_list, self.lhs_in_prod_rule_col_list))).t(),
                torch.FloatTensor([1.0]*len(self.lhs_in_prod_rule_col_list)),
                torch.Size([len(self.nt_symbol_list), len(self.prod_rule_list)])
            ).to_dense()
        return self._lhs_in_prod_rule

    @property
    def num_prod_rule(self):
        ''' return the number of production rules

        Returns
        -------
        int : the number of unique production rules
        '''
        return len(self.prod_rule_list)

    @property
    def start_rule_list(self):
        ''' return a list of start rules

        Returns
        -------
        list : list of start rules
        '''
        start_rule_list = []
        for each_prod_rule in self.prod_rule_list:
            if each_prod_rule.is_start_rule:
                start_rule_list.append(each_prod_rule)
        return start_rule_list

    @property
    def num_edge_symbol(self):
        ''' number of unique edge symbols seen so far '''
        return len(self.edge_symbol_list)

    @property
    def num_node_symbol(self):
        ''' number of unique node symbols seen so far '''
        return len(self.node_symbol_list)

    @property
    def num_ext_id(self):
        ''' number of unique external-node ids seen so far '''
        return len(self.ext_id_list)

    def construct_feature_vectors(self):
        ''' this method constructs feature vectors for the production rules collected so far.
        currently, NTSymbol and TSymbol are treated in the same manner.

        Returns
        -------
        feature_dict : dict
            maps each symbol (and each ('ext_id', id) pair) to a sparse
            multi-hot feature vector over symbol class + attribute values.
        dim : int
            dimensionality of the feature vectors.
        '''
        # First pass: assign a feature index to every (attribute, value) pair
        # observed on any symbol. List-valued attributes are tuple-ized so
        # they can be used as dict keys.
        feature_id_dict = {}
        feature_id_dict['TSymbol'] = 0
        feature_id_dict['NTSymbol'] = 1
        feature_id_dict['BondSymbol'] = 2
        for each_edge_symbol in self.edge_symbol_list:
            for each_attr in each_edge_symbol.__dict__.keys():
                each_val = each_edge_symbol.__dict__[each_attr]
                if isinstance(each_val, list):
                    each_val = tuple(each_val)
                if (each_attr, each_val) not in feature_id_dict:
                    feature_id_dict[(each_attr, each_val)] = len(feature_id_dict)

        for each_node_symbol in self.node_symbol_list:
            for each_attr in each_node_symbol.__dict__.keys():
                each_val = each_node_symbol.__dict__[each_attr]
                if isinstance(each_val, list):
                    each_val = tuple(each_val)
                if (each_attr, each_val) not in feature_id_dict:
                    feature_id_dict[(each_attr, each_val)] = len(feature_id_dict)
        for each_ext_id in self.ext_id_list:
            feature_id_dict[('ext_id', each_ext_id)] = len(feature_id_dict)
        dim = len(feature_id_dict)

        # Second pass: one sparse multi-hot vector per symbol, with ones at
        # the indices of its class name and its attribute values.
        feature_dict = {}
        for each_edge_symbol in self.edge_symbol_list:
            idx_list = []
            idx_list.append(feature_id_dict[each_edge_symbol.__class__.__name__])
            for each_attr in each_edge_symbol.__dict__.keys():
                each_val = each_edge_symbol.__dict__[each_attr]
                if isinstance(each_val, list):
                    each_val = tuple(each_val)
                idx_list.append(feature_id_dict[(each_attr, each_val)])
            feature = torch.sparse.LongTensor(
                torch.LongTensor([idx_list]),
                torch.ones(len(idx_list)),
                torch.Size([len(feature_id_dict)])
            )
            feature_dict[each_edge_symbol] = feature

        for each_node_symbol in self.node_symbol_list:
            idx_list = []
            idx_list.append(feature_id_dict[each_node_symbol.__class__.__name__])
            for each_attr in each_node_symbol.__dict__.keys():
                each_val = each_node_symbol.__dict__[each_attr]
                if isinstance(each_val, list):
                    each_val = tuple(each_val)
                idx_list.append(feature_id_dict[(each_attr, each_val)])
            feature = torch.sparse.LongTensor(
                torch.LongTensor([idx_list]),
                torch.ones(len(idx_list)),
                torch.Size([len(feature_id_dict)])
            )
            feature_dict[each_node_symbol] = feature
        for each_ext_id in self.ext_id_list:
            idx_list = [feature_id_dict[('ext_id', each_ext_id)]]
            feature_dict[('ext_id', each_ext_id)] \
                = torch.sparse.LongTensor(
                    torch.LongTensor([idx_list]),
                    torch.ones(len(idx_list)),
                    torch.Size([len(feature_id_dict)])
                )
        return feature_dict, dim

    def edge_symbol_idx(self, symbol):
        ''' index of an edge symbol in `edge_symbol_list` '''
        return self.edge_symbol_dict[symbol]

    def node_symbol_idx(self, symbol):
        ''' index of a node symbol in `node_symbol_list` '''
        return self.node_symbol_dict[symbol]

    def append(self, prod_rule: ProductionRule) -> Tuple[int, ProductionRule]:
        """ return whether the input production rule is new or not, and its production rule id.
        Production rules are regarded as the same if
        i) there exists a one-to-one mapping of nodes and edges, and
        ii) all the attributes associated with nodes and hyperedges are the same.

        Parameters
        ----------
        prod_rule : ProductionRule

        Returns
        -------
        prod_rule_id : int
            production rule index. if new, a new index will be assigned.
        prod_rule : ProductionRule
        """
        num_lhs = len(self.nt_symbol_list)
        for each_idx, each_prod_rule in enumerate(self.prod_rule_list):
            is_same, isomap = prod_rule.is_same(each_prod_rule)
            if is_same:
                # we do not care about edge and node names, but care about the order of non-terminal edges.
                for key, val in isomap.items():  # key : edges & nodes in each_prod_rule.rhs , val : those in prod_rule.rhs
                    if key.startswith("bond_"):
                        continue

                    # rewrite `nt_idx` in `prod_rule` for further processing
                    if "nt_idx" in prod_rule.rhs.edge_attr(val).keys():
                        if "nt_idx" not in each_prod_rule.rhs.edge_attr(key).keys():
                            raise ValueError
                        prod_rule.rhs.set_edge_attr(
                            val,
                            {'nt_idx': each_prod_rule.rhs.edge_attr(key)["nt_idx"]})
                return each_idx, prod_rule
        # new rule: register it and update the symbol bookkeeping
        self.prod_rule_list.append(prod_rule)
        self._update_edge_symbol_list(prod_rule)
        self._update_node_symbol_list(prod_rule)
        self._update_ext_id_list(prod_rule)

        lhs_idx = self.nt_symbol_list.index(prod_rule.lhs_nt_symbol)
        self.lhs_in_prod_rule_row_list.append(lhs_idx)
        self.lhs_in_prod_rule_col_list.append(len(self.prod_rule_list)-1)
        # invalidate the cached dense mask; rebuilt on next access
        self._lhs_in_prod_rule = None
        return len(self.prod_rule_list)-1, prod_rule

    def get_prod_rule(self, prod_rule_idx: int) -> ProductionRule:
        ''' return the production rule at index `prod_rule_idx` '''
        return self.prod_rule_list[prod_rule_idx]

    def sample(self, unmasked_logit_array, nt_symbol, deterministic=False):
        ''' sample a production rule whose lhs is `nt_symbol`, following `unmasked_logit_array`.

        Parameters
        ----------
        unmasked_logit_array : array-like, length `num_prod_rule`
        nt_symbol : NTSymbol
        deterministic : bool
            if True, return the argmax rule instead of sampling.
        '''
        if not isinstance(unmasked_logit_array, np.ndarray):
            unmasked_logit_array = unmasked_logit_array.numpy().astype(np.float64)
        if deterministic:
            prob = masked_softmax(unmasked_logit_array,
                                  self.lhs_in_prod_rule[self.nt_symbol_list.index(nt_symbol)].numpy().astype(np.float64))
            return self.prod_rule_list[np.argmax(prob)]
        else:
            return np.random.choice(
                self.prod_rule_list, 1,
                p=masked_softmax(unmasked_logit_array,
                                 self.lhs_in_prod_rule[self.nt_symbol_list.index(nt_symbol)].numpy().astype(np.float64)))[0]

    def masked_logprob(self, unmasked_logit_array, nt_symbol):
        ''' log-probabilities over rules compatible with `nt_symbol` (masked softmax). '''
        if not isinstance(unmasked_logit_array, np.ndarray):
            unmasked_logit_array = unmasked_logit_array.numpy().astype(np.float64)
        prob = masked_softmax(unmasked_logit_array,
                              self.lhs_in_prod_rule[self.nt_symbol_list.index(nt_symbol)].numpy().astype(np.float64))
        return np.log(prob)

    def _update_edge_symbol_list(self, prod_rule: ProductionRule):
        ''' update edge symbol list

        Also stores each edge's symbol index into its 'symbol_idx' attribute.

        Parameters
        ----------
        prod_rule : ProductionRule
        '''
        if prod_rule.lhs_nt_symbol not in self.nt_symbol_list:
            self.nt_symbol_list.append(prod_rule.lhs_nt_symbol)

        for each_edge in prod_rule.rhs.edges:
            if prod_rule.rhs.edge_attr(each_edge)['symbol'] not in self.edge_symbol_dict:
                edge_symbol_idx = len(self.edge_symbol_list)
                self.edge_symbol_list.append(prod_rule.rhs.edge_attr(each_edge)['symbol'])
                self.edge_symbol_dict[prod_rule.rhs.edge_attr(each_edge)['symbol']] = edge_symbol_idx
            else:
                edge_symbol_idx = self.edge_symbol_dict[prod_rule.rhs.edge_attr(each_edge)['symbol']]
            prod_rule.rhs.edge_attr(each_edge)['symbol_idx'] = edge_symbol_idx
        pass

    def _update_node_symbol_list(self, prod_rule: ProductionRule):
        ''' update node symbol list

        Also stores each node's symbol index into its 'symbol_idx' attribute.

        Parameters
        ----------
        prod_rule : ProductionRule
        '''
        for each_node in prod_rule.rhs.nodes:
            if prod_rule.rhs.node_attr(each_node)['symbol'] not in self.node_symbol_dict:
                node_symbol_idx = len(self.node_symbol_list)
                self.node_symbol_list.append(prod_rule.rhs.node_attr(each_node)['symbol'])
                self.node_symbol_dict[prod_rule.rhs.node_attr(each_node)['symbol']] = node_symbol_idx
            else:
                node_symbol_idx = self.node_symbol_dict[prod_rule.rhs.node_attr(each_node)['symbol']]
            prod_rule.rhs.node_attr(each_node)['symbol_idx'] = node_symbol_idx

    def _update_ext_id_list(self, prod_rule: ProductionRule):
        ''' collect any previously-unseen 'ext_id' values from the rhs nodes '''
        for each_node in prod_rule.rhs.nodes:
            if 'ext_id' in prod_rule.rhs.node_attr(each_node):
                if prod_rule.rhs.node_attr(each_node)['ext_id'] not in self.ext_id_list:
                    self.ext_id_list.append(prod_rule.rhs.node_attr(each_node)['ext_id'])
+
+
+class HyperedgeReplacementGrammar(GraphGrammarBase):
+ """
+ Learn a hyperedge replacement grammar from a set of hypergraphs.
+
+ Attributes
+ ----------
+ prod_rule_list : list of ProductionRule
+ production rules learned from the input hypergraphs
+ """
+ def __init__(self,
+ tree_decomposition=molecular_tree_decomposition,
+ ignore_order=False, **kwargs):
+ from functools import partial
+ self.prod_rule_corpus = ProductionRuleCorpus()
+ self.clique_tree_corpus = CliqueTreeCorpus()
+ self.ignore_order = ignore_order
+ self.tree_decomposition = partial(tree_decomposition, **kwargs)
+
+ @property
+ def num_prod_rule(self):
+ ''' return the number of production rules
+
+ Returns
+ -------
+ int : the number of unique production rules
+ '''
+ return self.prod_rule_corpus.num_prod_rule
+
+ @property
+ def start_rule_list(self):
+ ''' return a list of start rules
+
+ Returns
+ -------
+ list : list of start rules
+ '''
+ return self.prod_rule_corpus.start_rule_list
+
+ @property
+ def prod_rule_list(self):
+ return self.prod_rule_corpus.prod_rule_list
+
+ def learn(self, hg_list, logger=print, max_mol=np.inf, print_freq=500):
+ """ learn from a list of hypergraphs
+
+ Parameters
+ ----------
+ hg_list : list of Hypergraph
+
+ Returns
+ -------
+ prod_rule_seq_list : list of integers
+ each element corresponds to a sequence of production rules to generate each hypergraph.
+ """
+ prod_rule_seq_list = []
+ idx = 0
+ for each_idx, each_hg in enumerate(hg_list):
+ clique_tree = self.tree_decomposition(each_hg)
+
+ # get a pair of myself and children
+ root_node = _find_root(clique_tree)
+ clique_tree = self.clique_tree_corpus.add_to_subhg_list(clique_tree, root_node)
+ prod_rule_seq = []
+ stack = []
+
+ children = sorted(list(clique_tree[root_node].keys()))
+
+ # extract a temporary production rule
+ prod_rule = extract_prod_rule(
+ None,
+ clique_tree.nodes[root_node]["subhg"],
+ [clique_tree.nodes[each_child]["subhg"]
+ for each_child in children],
+ clique_tree.nodes[root_node].get('subhg_idx', None))
+
+ # update the production rule list
+ prod_rule_id, prod_rule = self.update_prod_rule_list(prod_rule)
+ children = reorder_children(root_node,
+ children,
+ prod_rule,
+ clique_tree)
+ stack.extend([(root_node, each_child) for each_child in children[::-1]])
+ prod_rule_seq.append(prod_rule_id)
+
+ while len(stack) != 0:
+ # get a triple of parent, myself, and children
+ parent, myself = stack.pop()
+ children = sorted(list(dict(clique_tree[myself]).keys()))
+ children.remove(parent)
+
+ # extract a temp prod rule
+ prod_rule = extract_prod_rule(
+ clique_tree.nodes[parent]["subhg"],
+ clique_tree.nodes[myself]["subhg"],
+ [clique_tree.nodes[each_child]["subhg"]
+ for each_child in children],
+ clique_tree.nodes[myself].get('subhg_idx', None))
+
+ # update the prod rule list
+ prod_rule_id, prod_rule = self.update_prod_rule_list(prod_rule)
+ children = reorder_children(myself,
+ children,
+ prod_rule,
+ clique_tree)
+ stack.extend([(myself, each_child)
+ for each_child in children[::-1]])
+ prod_rule_seq.append(prod_rule_id)
+ prod_rule_seq_list.append(prod_rule_seq)
+ if (each_idx+1) % print_freq == 0:
+ msg = f'#(molecules processed)={each_idx+1}\t'\
+ f'#(production rules)={self.prod_rule_corpus.num_prod_rule}\t#(subhg in corpus)={self.clique_tree_corpus.size}'
+ logger(msg)
+ if each_idx > max_mol:
+ break
+
+ print(f'corpus_size = {self.clique_tree_corpus.size}')
+ return prod_rule_seq_list
+
    def sample(self, z, deterministic=False):
        """ sample a new hypergraph from HRG.

        Parameters
        ----------
        z : array-like, shape (len, num_prod_rule)
            logit
        deterministic : bool
            if True, deterministic sampling

        Returns
        -------
        Hypergraph

        Raises
        ------
        RuntimeError
            if the logit sequence runs out while non-terminal edges
            remain to be expanded.
        """
        seq_idx = 0
        stack = []
        # drop the last logit column; presumably it encodes a "stop"/padding
        # symbol that the corpus sampler must not see — TODO confirm
        z = z[:, :-1]
        # the first logit row selects a start rule, whose lhs is the empty
        # non-terminal symbol (degree 0, no bonds)
        init_prod_rule = self.prod_rule_corpus.sample(z[0], NTSymbol(degree=0,
                                                                     is_aromatic=False,
                                                                     bond_symbol_list=[]),
                                                      deterministic=deterministic)
        hg, nt_edge_list = init_prod_rule.applied_to(None, None)
        # depth-first expansion: non-terminals are pushed reversed so the
        # first one produced is the first one expanded
        stack = deepcopy(nt_edge_list[::-1])
        while len(stack) != 0 and seq_idx < z.shape[0]-1:
            seq_idx += 1
            nt_edge = stack.pop()
            nt_symbol = hg.edge_attr(nt_edge)['symbol']
            # each subsequent logit row selects a rule compatible with the
            # current non-terminal symbol
            prod_rule = self.prod_rule_corpus.sample(z[seq_idx], nt_symbol, deterministic=deterministic)
            hg, nt_edge_list = prod_rule.applied_to(hg, nt_edge)
            stack.extend(nt_edge_list[::-1])
        if len(stack) != 0:
            raise RuntimeError(f'{len(stack)} non-terminals are left.')
        return hg
+
+ def construct(self, prod_rule_seq):
+ """ construct a hypergraph following `prod_rule_seq`
+
+ Parameters
+ ----------
+ prod_rule_seq : list of integers
+ a sequence of production rules.
+
+ Returns
+ -------
+ UndirectedHypergraph
+ """
+ seq_idx = 0
+ init_prod_rule = self.prod_rule_corpus.get_prod_rule(prod_rule_seq[seq_idx])
+ hg, nt_edge_list = init_prod_rule.applied_to(None, None)
+ stack = deepcopy(nt_edge_list[::-1])
+ while len(stack) != 0:
+ seq_idx += 1
+ nt_edge = stack.pop()
+ hg, nt_edge_list = self.prod_rule_corpus.get_prod_rule(prod_rule_seq[seq_idx]).applied_to(hg, nt_edge)
+ stack.extend(nt_edge_list[::-1])
+ return hg
+
+ def update_prod_rule_list(self, prod_rule):
+ """ return whether the input production rule is new or not, and its production rule id.
+ Production rules are regarded as the same if
+ i) there exists a one-to-one mapping of nodes and edges, and
+ ii) all the attributes associated with nodes and hyperedges are the same.
+
+ Parameters
+ ----------
+ prod_rule : ProductionRule
+
+ Returns
+ -------
+ is_new : bool
+ if True, this production rule is new
+ prod_rule_id : int
+ production rule index. if new, a new index will be assigned.
+ """
+ return self.prod_rule_corpus.append(prod_rule)
+
+
class IncrementalHyperedgeReplacementGrammar(HyperedgeReplacementGrammar):
    '''
    This class learns HRG incrementally leveraging the previously obtained production rules.

    NOTE(review): `__init__` does not call the parent constructor, and it binds
    `self.prod_rule_list` to a plain list although the parent class exposes
    `prod_rule_list` as a read-only property — confirm this class is still used.
    '''
    def __init__(self, tree_decomposition=tree_decomposition_with_hrg, ignore_order=False):
        self.prod_rule_list = []
        self.tree_decomposition = tree_decomposition
        self.ignore_order = ignore_order

    def learn(self, hg_list):
        """ learn from a list of hypergraphs

        Parameters
        ----------
        hg_list : list of UndirectedHypergraph

        Returns
        -------
        prod_rule_seq_list : list of integers
            each element corresponds to a sequence of production rules to generate each hypergraph.
        """
        prod_rule_seq_list = []
        for each_hg in hg_list:
            # NOTE(review): calls `tree_decomposition_with_hrg` directly rather
            # than `self.tree_decomposition` — presumably intentional, verify.
            clique_tree, root_node = tree_decomposition_with_hrg(each_hg, self, return_root=True)

            prod_rule_seq = []
            stack = []

            # get a pair of myself and children
            children = sorted(list(clique_tree[root_node].keys()))

            # extract a temporary production rule
            prod_rule = extract_prod_rule(None, clique_tree.nodes[root_node]["subhg"],
                                          [clique_tree.nodes[each_child]["subhg"] for each_child in children])

            # update the production rule list
            prod_rule_id, prod_rule = self.update_prod_rule_list(prod_rule)
            children = reorder_children(root_node, children, prod_rule, clique_tree)
            stack.extend([(root_node, each_child) for each_child in children[::-1]])
            prod_rule_seq.append(prod_rule_id)

            while len(stack) != 0:
                # get a triple of parent, myself, and children
                parent, myself = stack.pop()
                children = sorted(list(dict(clique_tree[myself]).keys()))
                children.remove(parent)

                # extract a temp prod rule
                prod_rule = extract_prod_rule(
                    clique_tree.nodes[parent]["subhg"], clique_tree.nodes[myself]["subhg"],
                    [clique_tree.nodes[each_child]["subhg"] for each_child in children])

                # update the prod rule list
                prod_rule_id, prod_rule = self.update_prod_rule_list(prod_rule)
                children = reorder_children(myself, children, prod_rule, clique_tree)
                stack.extend([(myself, each_child) for each_child in children[::-1]])
                prod_rule_seq.append(prod_rule_id)
            prod_rule_seq_list.append(prod_rule_seq)
        # NOTE(review): `_compute_stats` is not defined in this class or its
        # parent within this file — confirm it exists elsewhere.
        self._compute_stats()
        return prod_rule_seq_list
+
+
def reorder_children(myself, children, prod_rule, clique_tree):
    """ reorder children so that they match the order in `prod_rule`.

    Parameters
    ----------
    myself : int
        clique-tree node whose children are being reordered
    children : list of int
    prod_rule : ProductionRule
    clique_tree : nx.Graph

    Returns
    -------
    new_children : list of str
        children reordered to follow the rule's `nt_idx` numbering
    """
    nt_idx_to_child = {}
    for each_edge in prod_rule.rhs.edges:
        edge_attr = prod_rule.rhs.edge_attr(each_edge)
        if "nt_idx" not in edge_attr.keys():
            continue
        for each_child in children:
            # a child corresponds to this non-terminal edge iff the nodes it
            # shares with `myself` coincide with the edge's node set
            shared_nodes = set(
                common_node_list(clique_tree.nodes[myself]["subhg"],
                                 clique_tree.nodes[each_child]["subhg"])[0])
            if set(prod_rule.rhs.nodes_in_edge(each_edge)) == shared_nodes:
                assert edge_attr["nt_idx"] not in nt_idx_to_child
                nt_idx_to_child[edge_attr["nt_idx"]] = each_child
    assert len(nt_idx_to_child) == len(children)
    return [nt_idx_to_child[each_idx] for each_idx in range(len(nt_idx_to_child))]
+
+
def extract_prod_rule(parent_hg, myself_hg, children_hg_list, subhg_idx=None):
    """ extract a production rule from a triple of `parent_hg`, `myself_hg`, and `children_hg_list`.

    Parameters
    ----------
    parent_hg : Hypergraph or None
        if None, a start rule (with an empty left-hand side) is extracted.
    myself_hg : Hypergraph
    children_hg_list : list of Hypergraph, or None
    subhg_idx : int or None
        index of `myself_hg` in the subhypergraph corpus; stored on the rule.

    Returns
    -------
    ProductionRule, consisting of
        lhs : Hypergraph or None
        rhs : Hypergraph

    Raises
    ------
    RuntimeError
        if external-id annotations on the shared nodes are inconsistent
        (previously this dropped into the pdb debugger).
    """
    def _add_ext_node(hg, ext_nodes):
        """ mark nodes to be external (ordered ids are assigned)

        Parameters
        ----------
        hg : UndirectedHypergraph
        ext_nodes : list of str
            list of external nodes

        Returns
        -------
        hg : Hypergraph
            nodes in `ext_nodes` are marked to be external

        Raises
        ------
        ValueError
            if some but not all of `ext_nodes` already carry an 'ext_id'
        """
        ext_id = 0
        ext_id_exists = []
        for each_node in ext_nodes:
            ext_id_exists.append('ext_id' in hg.node_attr(each_node))
        if ext_id_exists and any(ext_id_exists) != all(ext_id_exists):
            raise ValueError('ext_id is set on some, but not all, external nodes')
        if not all(ext_id_exists):
            for each_node in ext_nodes:
                hg.node_attr(each_node)['ext_id'] = ext_id
                ext_id += 1
        return hg

    def _check_aromatic(hg, node_list):
        # whether any node in `node_list` is aromatic, plus per-node flags
        is_aromatic = False
        node_aromatic_list = []
        for each_node in node_list:
            if hg.node_attr(each_node)['symbol'].is_aromatic:
                is_aromatic = True
                node_aromatic_list.append(True)
            else:
                node_aromatic_list.append(False)
        return is_aromatic, node_aromatic_list

    def _check_ring(hg):
        # True iff every edge is either temporary or non-terminal, i.e.
        # the subhypergraph contains no completed terminal edge
        for each_edge in hg.edges:
            if not ('tmp' in hg.edge_attr(each_edge) or (not hg.edge_attr(each_edge)['terminal'])):
                return False
        return True

    if parent_hg is None:
        # start rule: the left-hand side is empty
        lhs = Hypergraph()
        node_list = []
    else:
        # lhs is a single non-terminal edge over the nodes shared with parent
        lhs = Hypergraph()
        node_list, edge_exists = common_node_list(parent_hg, myself_hg)
        for each_node in node_list:
            lhs.add_node(each_node,
                         deepcopy(myself_hg.node_attr(each_node)))
        is_aromatic, _ = _check_aromatic(parent_hg, node_list)
        for_ring = _check_ring(myself_hg)
        bond_symbol_list = []
        for each_node in node_list:
            bond_symbol_list.append(parent_hg.node_attr(each_node)['symbol'])
        lhs.add_edge(
            node_list,
            attr_dict=dict(
                terminal=False,
                edge_exists=edge_exists,
                symbol=NTSymbol(
                    degree=len(node_list),
                    is_aromatic=is_aromatic,
                    bond_symbol_list=bond_symbol_list,
                    for_ring=for_ring)))
        try:
            lhs = _add_ext_node(lhs, node_list)
        except ValueError as err:
            # bug fix: was `import pdb; pdb.set_trace()` — fail loudly instead
            raise RuntimeError('inconsistent ext_id annotations on lhs nodes') from err

    rhs = remove_tmp_edge(deepcopy(myself_hg))
    try:
        rhs = _add_ext_node(rhs, node_list)
    except ValueError as err:
        # bug fix: was `import pdb; pdb.set_trace()` — fail loudly instead
        raise RuntimeError('inconsistent ext_id annotations on rhs nodes') from err

    # add one non-terminal edge per child, numbered consecutively by nt_idx
    nt_idx = 0
    if children_hg_list is not None:
        for each_child_hg in children_hg_list:
            node_list, edge_exists = common_node_list(myself_hg, each_child_hg)
            is_aromatic, _ = _check_aromatic(myself_hg, node_list)
            for_ring = _check_ring(each_child_hg)
            bond_symbol_list = []
            for each_node in node_list:
                bond_symbol_list.append(myself_hg.node_attr(each_node)['symbol'])
            rhs.add_edge(
                node_list,
                attr_dict=dict(
                    terminal=False,
                    nt_idx=nt_idx,
                    edge_exists=edge_exists,
                    symbol=NTSymbol(degree=len(node_list),
                                    is_aromatic=is_aromatic,
                                    bond_symbol_list=bond_symbol_list,
                                    for_ring=for_ring)))
            nt_idx += 1
    prod_rule = ProductionRule(lhs, rhs)
    prod_rule.subhg_idx = subhg_idx
    if DEBUG:
        # external ids must form a contiguous range 0..k-1
        if sorted(list(prod_rule.ext_node.keys())) \
           != list(np.arange(len(prod_rule.ext_node))):
            raise RuntimeError('ext_id is not continuous')
    return prod_rule
+
+
+def _find_root(clique_tree):
+ max_node = None
+ num_nodes_max = -np.inf
+ for each_node in clique_tree.nodes:
+ if clique_tree.nodes[each_node]['subhg'].num_nodes > num_nodes_max:
+ max_node = each_node
+ num_nodes_max = clique_tree.nodes[each_node]['subhg'].num_nodes
+ '''
+ children = sorted(list(clique_tree[each_node].keys()))
+ prod_rule = extract_prod_rule(None,
+ clique_tree.nodes[each_node]["subhg"],
+ [clique_tree.nodes[each_child]["subhg"]
+ for each_child in children])
+ for each_start_rule in start_rule_list:
+ if prod_rule.is_same(each_start_rule):
+ return each_node
+ '''
+ return max_node
+
def remove_ext_node(hg):
    """ strip the 'ext_id' annotation from every node of `hg` (in place).

    Parameters
    ----------
    hg : Hypergraph

    Returns
    -------
    hg : Hypergraph
        the same object, with all external-node marks removed
    """
    for node in hg.nodes:
        hg.node_attr(node).pop('ext_id', None)
    return hg
+
def remove_nt_edge(hg):
    """ remove every non-terminal hyperedge from `hg` (in place).

    Parameters
    ----------
    hg : Hypergraph

    Returns
    -------
    hg : Hypergraph
        the same object, with all non-terminal edges removed
    """
    nt_edges = [each for each in hg.edges
                if not hg.edge_attr(each)['terminal']]
    hg.remove_edges(nt_edges)
    return hg
+
def remove_tmp_edge(hg):
    """ remove every hyperedge flagged 'tmp' from `hg` (in place).

    Parameters
    ----------
    hg : Hypergraph

    Returns
    -------
    hg : Hypergraph
        the same object, with all temporary edges removed
    """
    tmp_edges = [each for each in hg.edges
                 if hg.edge_attr(each).get('tmp', False)]
    hg.remove_edges(tmp_edges)
    return hg
diff --git a/graph_grammar/graph_grammar/symbols.py b/graph_grammar/graph_grammar/symbols.py
new file mode 100644
index 0000000000000000000000000000000000000000..a024fb263c0aed40f9dc3b816e3c87913594f96c
--- /dev/null
+++ b/graph_grammar/graph_grammar/symbols.py
@@ -0,0 +1,180 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Rhizome
+# Version beta 0.0, August 2023
+# Property of IBM Research, Accelerated Discovery
+#
+
+"""
+PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
+OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
+THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
+"""
+
+
+""" Title """
+
+__author__ = "Hiroshi Kajino "
+__copyright__ = "(c) Copyright IBM Corp. 2018"
+__version__ = "0.1"
+__date__ = "Jan 1 2018"
+
+from typing import List
+
class TSymbol(object):

    ''' terminal symbol

    Attributes
    ----------
    degree : int
        the number of nodes in a hyperedge
    is_aromatic : bool
        whether or not the hyperedge is in an aromatic ring
    symbol : str
        atomic symbol
    num_explicit_Hs : int
        the number of hydrogens associated to this hyperedge
    formal_charge : int
        charge
    chirality : int
        chirality
    '''

    def __init__(self, degree, is_aromatic,
                 symbol, num_explicit_Hs, formal_charge, chirality):
        self.degree = degree
        self.is_aromatic = is_aromatic
        self.symbol = symbol
        self.num_explicit_Hs = num_explicit_Hs
        self.formal_charge = formal_charge
        self.chirality = chirality

    @property
    def terminal(self):
        # a terminal symbol always reports terminal=True
        return True

    def _key(self):
        # attribute tuple used for equality comparison
        return (self.degree, self.is_aromatic, self.symbol,
                self.num_explicit_Hs, self.formal_charge, self.chirality)

    def __eq__(self, other):
        # equal iff the other object is a TSymbol with identical attributes
        if not isinstance(other, TSymbol):
            return False
        return self._key() == other._key()

    def __hash__(self):
        # hash via the string form so it stays consistent with __eq__
        return hash(str(self))

    def __str__(self):
        return f'degree={self.degree}, is_aromatic={self.is_aromatic}, '\
               f'symbol={self.symbol}, '\
               f'num_explicit_Hs={self.num_explicit_Hs}, '\
               f'formal_charge={self.formal_charge}, chirality={self.chirality}'
+
+
class NTSymbol(object):

    ''' non-terminal symbol

    Attributes
    ----------
    degree : int
        degree of the hyperedge
    is_aromatic : bool
        if True, at least one of the associated bonds must be aromatic.
    bond_symbol_list : list of BondSymbol
        bond symbol of each node in the hyperedge
    for_ring : bool
        if True, this non-terminal stands for a ring substructure.
    '''

    def __init__(self, degree: int, is_aromatic: bool,
                 bond_symbol_list: list,
                 for_ring=False):
        self.degree = degree
        self.is_aromatic = is_aromatic
        self.for_ring = for_ring
        self.bond_symbol_list = bond_symbol_list

    @property
    def terminal(self) -> bool:
        # non-terminal symbols always report terminal=False
        return False

    @property
    def symbol(self):
        return f'NT{self.degree}'

    def __eq__(self, other) -> bool:
        if not isinstance(other, NTSymbol):
            return False

        if self.degree != other.degree:
            return False
        if self.is_aromatic != other.is_aromatic:
            return False
        if self.for_ring != other.for_ring:
            return False
        if len(self.bond_symbol_list) != len(other.bond_symbol_list):
            return False
        for each_idx in range(len(self.bond_symbol_list)):
            if self.bond_symbol_list[each_idx] != other.bond_symbol_list[each_idx]:
                return False
        return True

    def __hash__(self):
        # hash via the string form so it stays consistent with __eq__
        return self.__str__().__hash__()

    def __str__(self) -> str:
        # bug fix: a ', ' separator was missing between bond_symbol_list and
        # for_ring, which glued the two fields together in the output (and
        # therefore in the hash input).
        return f'degree={self.degree}, is_aromatic={self.is_aromatic}, '\
               f'bond_symbol_list={[str(each_symbol) for each_symbol in self.bond_symbol_list]}, '\
               f'for_ring={self.for_ring}'
+
+
class BondSymbol(object):


    ''' Bond symbol

    Attributes
    ----------
    is_aromatic : bool
        if True, at least one of the associated bonds must be aromatic.
    bond_type : int
        bond type of the bond
    stereo : int
        stereo configuration of the bond
    '''

    def __init__(self, is_aromatic: bool,
                 bond_type: int,
                 stereo: int):
        self.is_aromatic = is_aromatic
        self.bond_type = bond_type
        self.stereo = stereo

    def _key(self):
        # attribute tuple used for equality comparison
        return (self.is_aromatic, self.bond_type, self.stereo)

    def __eq__(self, other) -> bool:
        # equal iff the other object is a BondSymbol with identical attributes
        if not isinstance(other, BondSymbol):
            return False
        return self._key() == other._key()

    def __hash__(self):
        # hash via the string form so it stays consistent with __eq__
        return hash(str(self))

    def __str__(self) -> str:
        # NOTE: the trailing ', ' is kept as-is; hashes depend on this exact text
        return f'is_aromatic={self.is_aromatic}, '\
               f'bond_type={self.bond_type}, '\
               f'stereo={self.stereo}, '
diff --git a/graph_grammar/graph_grammar/utils.py b/graph_grammar/graph_grammar/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..18c6bcad4b8ff341f2ffb844dd2265976cca2803
--- /dev/null
+++ b/graph_grammar/graph_grammar/utils.py
@@ -0,0 +1,130 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Rhizome
+# Version beta 0.0, August 2023
+# Property of IBM Research, Accelerated Discovery
+#
+
+"""
+PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
+OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
+THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
+"""
+
+""" Title """
+
+__author__ = "Hiroshi Kajino "
+__copyright__ = "(c) Copyright IBM Corp. 2018"
+__version__ = "0.1"
+__date__ = "Jun 4 2018"
+
+from ..hypergraph import Hypergraph
+from copy import deepcopy
+from typing import List
+import numpy as np
+
+
def common_node_list(hg1: 'Hypergraph', hg2: 'Hypergraph') -> 'Tuple[List[str], bool]':
    """ return the nodes shared by two hypergraphs.

    Parameters
    ----------
    hg1, hg2 : Hypergraph

    Returns
    -------
    node_list : list of str
        common nodes, ordered by each node's 'order4hrg' attribute when
        available, otherwise by the hash of its symbol.
    edge_exists : bool
        True iff the common nodes span a non-terminal edge in `hg1`
        (docstring fixed: this function has always returned a 2-tuple,
        not a bare list).
    """
    if hg1 is None or hg2 is None:
        return [], False
    node_set = hg1.nodes.intersection(hg2.nodes)
    node_dict = {}
    # assumes all nodes agree on whether 'order4hrg' is present — TODO confirm
    if 'order4hrg' in hg1.node_attr(list(hg1.nodes)[0]):
        for each_node in node_set:
            node_dict[each_node] = hg1.node_attr(each_node)['order4hrg']
    else:
        for each_node in node_set:
            node_dict[each_node] = hg1.node_attr(each_node)['symbol'].__hash__()
    node_list = [each_key for each_key, _
                 in sorted(node_dict.items(), key=lambda x: x[1])]
    edge_name = hg1.has_edge(node_list, ignore_order=True)
    if edge_name:
        if not hg1.edge_attr(edge_name).get('terminal', True):
            # the common nodes form a non-terminal edge; use its node order
            return hg1.nodes_in_edge(edge_name), True
    return node_list, False
+
+
+def _node_match(node1, node2):
+ # if the nodes are hyperedges, `atom_attr` determines the match
+ if node1['bipartite'] == 'edge' and node2['bipartite'] == 'edge':
+ return node1["attr_dict"]['symbol'] == node2["attr_dict"]['symbol']
+ elif node1['bipartite'] == 'node' and node2['bipartite'] == 'node':
+ # bond_symbol
+ return node1['attr_dict']['symbol'] == node2['attr_dict']['symbol']
+ else:
+ return False
+
+def _easy_node_match(node1, node2):
+ # if the nodes are hyperedges, `atom_attr` determines the match
+ if node1['bipartite'] == 'edge' and node2['bipartite'] == 'edge':
+ return node1["attr_dict"].get('symbol', None) == node2["attr_dict"].get('symbol', None)
+ elif node1['bipartite'] == 'node' and node2['bipartite'] == 'node':
+ # bond_symbol
+ return node1['attr_dict'].get('ext_id', -1) == node2['attr_dict'].get('ext_id', -1)\
+ and node1['attr_dict']['symbol'] == node2['attr_dict']['symbol']
+ else:
+ return False
+
+
+def _node_match_prod_rule(node1, node2, ignore_order=False):
+ # if the nodes are hyperedges, `atom_attr` determines the match
+ if node1['bipartite'] == 'edge' and node2['bipartite'] == 'edge':
+ return node1["attr_dict"]['symbol'] == node2["attr_dict"]['symbol']
+ elif node1['bipartite'] == 'node' and node2['bipartite'] == 'node':
+ # ext_id, order4hrg, bond_symbol
+ if ignore_order:
+ return node1['attr_dict']['symbol'] == node2['attr_dict']['symbol']
+ else:
+ return node1['attr_dict']['symbol'] == node2['attr_dict']['symbol']\
+ and node1['attr_dict'].get('ext_id', -1) == node2['attr_dict'].get('ext_id', -1)
+ else:
+ return False
+
+
+def _edge_match(edge1, edge2, ignore_order=False):
+ #return True
+ if ignore_order:
+ return True
+ else:
+ return edge1["order"] == edge2["order"]
+
def masked_softmax(logit, mask):
    ''' compute a probability distribution from logit, restricted to the
    dimensions whose mask entry is non-zero.

    Parameters
    ----------
    logit : array-like, length D
        each element indicates how each dimension is likely to be chosen
        (the larger, the more likely)
    mask : array-like, length D
        each element is either 0 or 1.
        if 0, the dimension is ignored
        when computing the probability distribution.

    Returns
    -------
    prob_dist : array, length D
        probability distribution computed from logit.
        if `mask[d] = 0`, `prob_dist[d] = 0`.

    Raises
    ------
    ValueError
        if `logit` and `mask` have different shapes
    '''
    if logit.shape != mask.shape:
        raise ValueError('logit and mask must have the same shape')
    # subtract the max for numerical stability before exponentiating
    shifted = logit - np.max(logit)
    masked_exp = np.exp(shifted) * mask
    return masked_exp / (masked_exp @ mask)
diff --git a/graph_grammar/hypergraph.py b/graph_grammar/hypergraph.py
new file mode 100644
index 0000000000000000000000000000000000000000..15448755e3b40ac0f2b5f80c2b84511a904c2dd1
--- /dev/null
+++ b/graph_grammar/hypergraph.py
@@ -0,0 +1,544 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Rhizome
+# Version beta 0.0, August 2023
+# Property of IBM Research, Accelerated Discovery
+#
+
+"""
+PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
+OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
+THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
+"""
+
+""" Title """
+
+__author__ = "Hiroshi Kajino "
+__copyright__ = "(c) Copyright IBM Corp. 2018"
+__version__ = "0.1"
+__date__ = "Jan 31 2018"
+
+from copy import deepcopy
+from typing import List, Dict, Tuple
+import networkx as nx
+import numpy as np
+import os
+
+
class Hypergraph(object):
    '''
    A class of a hypergraph.
    Each hyperedge can be ordered. For the ordered case,
    edges adjacent to the hyperedge node are labeled by their orders.

    Attributes
    ----------
    hg : nx.Graph
        a bipartite graph representation of a hypergraph
    edge_idx : int
        total number of hyperedges that exist so far
    '''
    def __init__(self):
        # bipartite representation: both nodes and hyperedges are stored as
        # nx.Graph nodes, distinguished by their 'bipartite' attribute.
        self.hg = nx.Graph()
        self.edge_idx = 0
        self.nodes = set([])
        self.num_nodes = 0
        self.edges = set([])
        self.num_edges = 0
        # maps each hyperedge name to the (ordered) nodes it connects
        self.nodes_in_edge_dict = {}

    def add_node(self, node: str, attr_dict=None):
        ''' add a node to hypergraph

        Parameters
        ----------
        node : str
            node name
        attr_dict : dict
            dictionary of node attributes
        '''
        self.hg.add_node(node, bipartite='node', attr_dict=attr_dict)
        if node not in self.nodes:
            self.num_nodes += 1
        self.nodes.add(node)

    def add_edge(self, node_list: List[str], attr_dict=None, edge_name=None):
        ''' add an edge consisting of nodes `node_list`

        Parameters
        ----------
        node_list : list
            ordered list of nodes that consist the edge
        attr_dict : dict
            dictionary of edge attributes
        edge_name : str or None
            if None, a fresh name 'e{edge_idx}' is generated

        Returns
        -------
        str
            the name of the added edge
        '''
        if edge_name is None:
            edge = 'e{}'.format(self.edge_idx)
        else:
            assert edge_name not in self.edges
            edge = edge_name
        self.hg.add_node(edge, bipartite='edge', attr_dict=attr_dict)
        if edge not in self.edges:
            self.num_edges += 1
        self.edges.add(edge)
        self.nodes_in_edge_dict[edge] = node_list
        if type(node_list) == list:
            # ordered edge: bipartite links are labeled by position
            for node_idx, each_node in enumerate(node_list):
                self.hg.add_edge(edge, each_node, order=node_idx)
                if each_node not in self.nodes:
                    self.num_nodes += 1
                self.nodes.add(each_node)

        elif type(node_list) == set:
            # unordered edge: order label -1 marks "no order"
            for each_node in node_list:
                self.hg.add_edge(edge, each_node, order=-1)
                if each_node not in self.nodes:
                    self.num_nodes += 1
                self.nodes.add(each_node)
        else:
            raise ValueError
        self.edge_idx += 1
        return edge

    def remove_node(self, node: str, remove_connected_edges=True):
        ''' remove a node

        Parameters
        ----------
        node : str
            node name
        remove_connected_edges : bool
            if True, remove edges that are adjacent to the node
        '''
        # NOTE(review): with remove_connected_edges=False, adjacent edges keep
        # stale entries in nodes_in_edge_dict — confirm callers rely on this.
        if remove_connected_edges:
            connected_edges = deepcopy(self.adj_edges(node))
            for each_edge in connected_edges:
                self.remove_edge(each_edge)
        self.hg.remove_node(node)
        self.num_nodes -= 1
        self.nodes.remove(node)

    def remove_nodes(self, node_iter, remove_connected_edges=True):
        ''' remove a set of nodes

        Parameters
        ----------
        node_iter : iterator of strings
            nodes to be removed
        remove_connected_edges : bool
            if True, remove edges that are adjacent to the node
        '''
        for each_node in node_iter:
            self.remove_node(each_node, remove_connected_edges)

    def remove_edge(self, edge: str):
        ''' remove an edge

        Parameters
        ----------
        edge : str
            edge to be removed
        '''
        self.hg.remove_node(edge)
        self.edges.remove(edge)
        self.num_edges -= 1
        self.nodes_in_edge_dict.pop(edge)

    def remove_edges(self, edge_iter):
        ''' remove a set of edges

        Parameters
        ----------
        edge_iter : iterator of strings
            edges to be removed
        '''
        for each_edge in edge_iter:
            self.remove_edge(each_edge)

    def remove_edges_with_attr(self, edge_attr_dict):
        ''' remove every edge whose attributes match all of `edge_attr_dict`.

        Parameters
        ----------
        edge_attr_dict : dict
            an edge is removed iff, for every key, its attribute equals the
            given value (a missing key counts as a mismatch).
        '''
        remove_edge_list = []
        for each_edge in self.edges:
            satisfy = True
            for each_key, each_val in edge_attr_dict.items():
                if not satisfy:
                    break
                try:
                    if self.edge_attr(each_edge)[each_key] != each_val:
                        satisfy = False
                except KeyError:
                    satisfy = False
            if satisfy:
                remove_edge_list.append(each_edge)
        self.remove_edges(remove_edge_list)

    def remove_subhg(self, subhg):
        ''' remove subhypergraph.
        all of the hyperedges are removed.
        each node of subhg is removed if its degree becomes 0 after removing hyperedges.

        Parameters
        ----------
        subhg : Hypergraph
        '''
        for each_edge in subhg.edges:
            self.remove_edge(each_edge)
        for each_node in subhg.nodes:
            if self.degree(each_node) == 0:
                self.remove_node(each_node)

    def nodes_in_edge(self, edge):
        ''' return an ordered list of nodes in a given edge.

        Parameters
        ----------
        edge : str
            edge whose nodes are returned

        Returns
        -------
        list or set
            ordered list or set of nodes that belong to the edge
        '''
        # NOTE(review): the fast path assumes auto-generated edge names start
        # with 'e'; externally supplied edge_names that do not are resolved
        # via the slower adjacency scan below — confirm this is intended.
        if edge.startswith('e'):
            return self.nodes_in_edge_dict[edge]
        else:
            adj_node_list = self.hg.adj[edge]
            adj_node_order_list = []
            adj_node_name_list = []
            for each_node in adj_node_list:
                adj_node_order_list.append(adj_node_list[each_node]['order'])
                adj_node_name_list.append(each_node)
            # all orders equal to -1 means the edge is unordered (a set)
            if adj_node_order_list == [-1] * len(adj_node_order_list):
                return set(adj_node_name_list)
            else:
                return [adj_node_name_list[each_idx] for each_idx
                        in np.argsort(adj_node_order_list)]

    def adj_edges(self, node):
        ''' return a dict of adjacent hyperedges

        Parameters
        ----------
        node : str

        Returns
        -------
        set
            set of edges that are adjacent to `node`
        '''
        return self.hg.adj[node]

    def adj_nodes(self, node):
        ''' return a set of adjacent nodes

        Parameters
        ----------
        node : str

        Returns
        -------
        set
            set of nodes that are adjacent to `node`
        '''
        node_set = set([])
        for each_adj_edge in self.adj_edges(node):
            node_set.update(set(self.nodes_in_edge(each_adj_edge)))
        node_set.discard(node)
        return node_set

    def has_edge(self, node_list, ignore_order=False):
        ''' return the name of an edge over exactly `node_list`, or False.

        Parameters
        ----------
        node_list : list
            nodes to look for
        ignore_order : bool
            if True, compare as sets instead of ordered lists

        Returns
        -------
        str or False
            the matching edge name, or False when no edge matches
        '''
        for each_edge in self.edges:
            if ignore_order:
                if set(self.nodes_in_edge(each_edge)) == set(node_list):
                    return each_edge
            else:
                if self.nodes_in_edge(each_edge) == node_list:
                    return each_edge
        return False

    def degree(self, node):
        # number of hyperedges adjacent to `node`
        return len(self.hg.adj[node])

    def degrees(self):
        # dict: node -> degree, over all nodes
        return {each_node: self.degree(each_node) for each_node in self.nodes}

    def edge_degree(self, edge):
        # number of nodes the edge connects
        return len(self.nodes_in_edge(edge))

    def edge_degrees(self):
        # dict: edge -> edge degree, over all edges
        return {each_edge: self.edge_degree(each_edge) for each_edge in self.edges}

    def is_adj(self, node1, node2):
        # True iff the two nodes share at least one hyperedge
        return node1 in self.adj_nodes(node2)

    def adj_subhg(self, node, ident_node_dict=None):
        """ return a subhypergraph consisting of a set of nodes and hyperedges adjacent to `node`.
        if an adjacent node has a self-loop hyperedge, it will be also added to the subhypergraph.

        Parameters
        ----------
        node : str
        ident_node_dict : dict
            dict containing identical nodes. see `get_identical_node_dict` for more details

        Returns
        -------
        subhg : Hypergraph
        """
        if ident_node_dict is None:
            ident_node_dict = self.get_identical_node_dict()
        adj_node_set = set(ident_node_dict[node])
        adj_edge_set = set([])
        for each_node in ident_node_dict[node]:
            adj_edge_set.update(set(self.adj_edges(each_node)))
        fixed_adj_edge_set = deepcopy(adj_edge_set)
        for each_edge in fixed_adj_edge_set:
            other_nodes = self.nodes_in_edge(each_edge)
            adj_node_set.update(other_nodes)

            # if the adjacent node has self-loop edge, it will be appended to adj_edge_list.
            for each_node in other_nodes:
                for other_edge in set(self.adj_edges(each_node)) - set([each_edge]):
                    if len(set(self.nodes_in_edge(other_edge)) \
                           - set(self.nodes_in_edge(each_edge))) == 0:
                        adj_edge_set.update(set([other_edge]))
        subhg = Hypergraph()
        for each_node in adj_node_set:
            subhg.add_node(each_node, attr_dict=self.node_attr(each_node))
        for each_edge in adj_edge_set:
            subhg.add_edge(self.nodes_in_edge(each_edge),
                           attr_dict=self.edge_attr(each_edge),
                           edge_name=each_edge)
        # keep the edge counter in sync so new edges get fresh names
        subhg.edge_idx = self.edge_idx
        return subhg

    def get_subhg(self, node_list, edge_list, ident_node_dict=None):
        """ return a subhypergraph consisting of the given nodes and hyperedges
        (plus any nodes identical to the given ones).

        Parameters
        ----------
        node_list : list of str
            nodes to include
        edge_list : list of str
            edges to include
        ident_node_dict : dict
            dict containing identical nodes. see `get_identical_node_dict` for more details

        Returns
        -------
        subhg : Hypergraph
        """
        if ident_node_dict is None:
            ident_node_dict = self.get_identical_node_dict()
        adj_node_set = set([])
        for each_node in node_list:
            adj_node_set.update(set(ident_node_dict[each_node]))
        adj_edge_set = set(edge_list)

        subhg = Hypergraph()
        for each_node in adj_node_set:
            subhg.add_node(each_node,
                           attr_dict=deepcopy(self.node_attr(each_node)))
        for each_edge in adj_edge_set:
            subhg.add_edge(self.nodes_in_edge(each_edge),
                           attr_dict=deepcopy(self.edge_attr(each_edge)),
                           edge_name=each_edge)
        # keep the edge counter in sync so new edges get fresh names
        subhg.edge_idx = self.edge_idx
        return subhg

    def copy(self):
        ''' return a copy of the object

        Returns
        -------
        Hypergraph
        '''
        return deepcopy(self)

    def node_attr(self, node):
        # attribute dictionary attached to `node` (mutable reference)
        return self.hg.nodes[node]['attr_dict']

    def edge_attr(self, edge):
        # attribute dictionary attached to `edge` (mutable reference)
        return self.hg.nodes[edge]['attr_dict']

    def set_node_attr(self, node, attr_dict):
        # merge `attr_dict` into the node's existing attributes
        for each_key, each_val in attr_dict.items():
            self.hg.nodes[node]['attr_dict'][each_key] = each_val

    def set_edge_attr(self, edge, attr_dict):
        # merge `attr_dict` into the edge's existing attributes
        for each_key, each_val in attr_dict.items():
            self.hg.nodes[edge]['attr_dict'][each_key] = each_val

    def get_identical_node_dict(self):
        ''' get identical nodes
        nodes are identical if they share the same set of adjacent edges.

        Returns
        -------
        ident_node_dict : dict
            ident_node_dict[node] returns a list of nodes that are identical to `node`.
        '''
        # O(n^2) pairwise comparison over all nodes
        ident_node_dict = {}
        for each_node in self.nodes:
            ident_node_list = []
            for each_other_node in self.nodes:
                if each_other_node == each_node:
                    ident_node_list.append(each_other_node)
                elif self.adj_edges(each_node) == self.adj_edges(each_other_node) \
                     and len(self.adj_edges(each_node)) != 0:
                    ident_node_list.append(each_other_node)
            ident_node_dict[each_node] = ident_node_list
        return ident_node_dict
        '''
        ident_node_dict = {}
        for each_node in self.nodes:
            ident_node_dict[each_node] = [each_node]
        return ident_node_dict
        '''

    def get_leaf_edge(self):
        ''' get an edge that is incident only to one edge

        Returns
        -------
        if exists, return a leaf edge. otherwise, return None.
        '''
        for each_edge in self.edges:
            if len(self.adj_nodes(each_edge)) == 1:
                if 'tmp' not in self.edge_attr(each_edge):
                    return each_edge
        return None

    def get_nontmp_edge(self):
        # return any edge not flagged 'tmp', or None when all are temporary
        for each_edge in self.edges:
            if 'tmp' not in self.edge_attr(each_edge):
                return each_edge
        return None

    def is_subhg(self, hg):
        ''' return whether this hypergraph is a subhypergraph of `hg`

        Returns
        -------
        bool
            True if every node and edge of `self` is contained in `hg`,
            False otherwise.
        '''
        for each_node in self.nodes:
            if each_node not in hg.nodes:
                return False
        for each_edge in self.edges:
            if each_edge not in hg.edges:
                return False
        return True

    def in_cycle(self, node, visited=None, parent='', root_node='') -> bool:
        ''' if `node` is in a cycle, then return True. otherwise, False.

        Parameters
        ----------
        node : str
            node in a hypergraph
        visited : list
            list of visited nodes, used for recursion
        parent : str
            parent node, used to eliminate a cycle consisting of two nodes and one edge.
        root_node : str
            the node the search started from, used for recursion

        Returns
        -------
        bool
        '''
        if visited is None:
            visited = []
        # '' marks the initial (non-recursive) call
        if parent == '':
            visited = []
        if root_node == '':
            root_node = node
        visited.append(node)
        for each_adj_node in self.adj_nodes(node):
            if each_adj_node not in visited:
                if self.in_cycle(each_adj_node, visited, node, root_node):
                    return True
            elif each_adj_node != parent and each_adj_node == root_node:
                # reached the root again via a different neighbor: a cycle
                return True
        return False


    def draw(self, file_path=None, with_node=False, with_edge_name=False):
        ''' draw hypergraph

        Parameters
        ----------
        file_path : str or None
            if given, the rendered PNG is written there
        with_node : bool
            if True, draw plain nodes explicitly
        with_edge_name : bool
            if True, append edge names to the labels

        Returns
        -------
        graphviz.Graph
        '''
        import graphviz
        G = graphviz.Graph(format='png')
        for each_node in self.nodes:
            # external nodes are drawn as filled black dots
            if 'ext_id' in self.node_attr(each_node):
                G.node(each_node, label='',
                       shape='circle', width='0.1', height='0.1', style='filled',
                       fillcolor='black')
            else:
                if with_node:
                    G.node(each_node, label='',
                           shape='circle', width='0.1', height='0.1', style='filled',
                           fillcolor='gray')
        edge_list = []
        for each_edge in self.edges:
            if self.edge_attr(each_edge).get('terminal', False):
                G.node(each_edge,
                       label=self.edge_attr(each_edge)['symbol'].symbol if not with_edge_name \
                       else self.edge_attr(each_edge)['symbol'].symbol + ', ' + each_edge,
                       fontcolor='black', shape='square')
            elif self.edge_attr(each_edge).get('tmp', False):
                G.node(each_edge, label='tmp' if not with_edge_name else 'tmp, ' + each_edge,
                       fontcolor='black', shape='square')
            else:
                # non-terminal edges are drawn filled
                G.node(each_edge,
                       label=self.edge_attr(each_edge)['symbol'].symbol if not with_edge_name \
                       else self.edge_attr(each_edge)['symbol'].symbol + ', ' + each_edge,
                       fontcolor='black', shape='square', style='filled')
            if with_node:
                for each_node in self.nodes_in_edge(each_edge):
                    G.edge(each_edge, each_node)
            else:
                for each_node in self.nodes_in_edge(each_edge):
                    if 'ext_id' in self.node_attr(each_node)\
                       and set([each_node, each_edge]) not in edge_list:
                        G.edge(each_edge, each_node)
                        edge_list.append(set([each_node, each_edge]))
                for each_other_edge in self.adj_nodes(each_edge):
                    if set([each_edge, each_other_edge]) not in edge_list:
                        # draw one line per bond order between the two hyperedges
                        num_bond = 0
                        common_node_set = set(self.nodes_in_edge(each_edge))\
                            .intersection(set(self.nodes_in_edge(each_other_edge)))
                        for each_node in common_node_set:
                            if self.node_attr(each_node)['symbol'].bond_type in [1, 2, 3]:
                                num_bond += self.node_attr(each_node)['symbol'].bond_type
                            elif self.node_attr(each_node)['symbol'].bond_type in [12]:
                                # bond_type 12 presumably denotes an aromatic
                                # bond (RDKit convention) — TODO confirm
                                num_bond += 1
                            else:
                                raise NotImplementedError('unsupported bond type')
                        for _ in range(num_bond):
                            G.edge(each_edge, each_other_edge)
                        edge_list.append(set([each_edge, each_other_edge]))
        if file_path is not None:
            G.render(file_path, cleanup=True)
            #os.remove(file_path)
        return G

    def is_dividable(self, node):
        # True iff removing `node` disconnects the bipartite representation
        _hg = deepcopy(self.hg)
        _hg.remove_node(node)
        return (not nx.is_connected(_hg))

    def divide(self, node):
        ''' split the hypergraph at `node` into the connected components
        obtained by removing it; `node` itself is kept in every part.

        Parameters
        ----------
        node : str

        Returns
        -------
        subhg_list : list of Hypergraph
        '''
        subhg_list = []

        hg_wo_node = deepcopy(self)
        hg_wo_node.remove_node(node, remove_connected_edges=False)
        connected_components = nx.connected_components(hg_wo_node.hg)
        for each_component in connected_components:
            node_list = [node]
            edge_list = []
            # NOTE(review): membership is inferred from the 'bond_'/'e' name
            # prefixes — confirm all node/edge names follow this convention.
            node_list.extend([each_node for each_node in each_component
                              if each_node.startswith('bond_')])
            edge_list.extend([each_edge for each_edge in each_component
                              if each_edge.startswith('e')])
            subhg_list.append(self.get_subhg(node_list, edge_list))
            #subhg_list[-1].set_node_attr(node, {'divided': True})
        return subhg_list
+
diff --git a/graph_grammar/io/__init__.py b/graph_grammar/io/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..85e6131daba8a4f601ae72d37e6eb035d9503045
--- /dev/null
+++ b/graph_grammar/io/__init__.py
@@ -0,0 +1,20 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Rhizome
+# Version beta 0.0, August 2023
+# Property of IBM Research, Accelerated Discovery
+#
+
+"""
+PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
+OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
+THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
+"""
+
+""" Title """
+
+__author__ = "Hiroshi Kajino "
+__copyright__ = "(c) Copyright IBM Corp. 2018"
+__version__ = "0.1"
+__date__ = "Jan 1 2018"
+
diff --git a/graph_grammar/io/__pycache__/__init__.cpython-310.pyc b/graph_grammar/io/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c4f6e140d8246ba4cb6b25bcf7d89f2511581804
Binary files /dev/null and b/graph_grammar/io/__pycache__/__init__.cpython-310.pyc differ
diff --git a/graph_grammar/io/__pycache__/smi.cpython-310.pyc b/graph_grammar/io/__pycache__/smi.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..682a16d3069fb140568cc08d7251f4bef4240126
Binary files /dev/null and b/graph_grammar/io/__pycache__/smi.cpython-310.pyc differ
diff --git a/graph_grammar/io/smi.py b/graph_grammar/io/smi.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd17428fdcb7888365cffc59867976f415841d79
--- /dev/null
+++ b/graph_grammar/io/smi.py
@@ -0,0 +1,559 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Rhizome
+# Version beta 0.0, August 2023
+# Property of IBM Research, Accelerated Discovery
+#
+
+"""
+PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
+OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
+THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
+"""
+
+""" Title """
+
+__author__ = "Hiroshi Kajino "
+__copyright__ = "(c) Copyright IBM Corp. 2018"
+__version__ = "0.1"
+__date__ = "Jan 12 2018"
+
+from copy import deepcopy
+from rdkit import Chem
+from rdkit import RDLogger
+import networkx as nx
+import numpy as np
+from ..hypergraph import Hypergraph
+from ..graph_grammar.symbols import TSymbol, BondSymbol
+
+# suppress warnings
+lg = RDLogger.logger()
+lg.setLevel(RDLogger.CRITICAL)
+
+
class HGGen(object):
    """
    Iterator that loads a .smi file and yields one hypergraph per molecule.

    Attributes
    ----------
    path_to_file : str
        path to .smi file
    kekulize : bool
        kekulize or not
    add_Hs : bool
        add implicit hydrogens to the molecule or not.
    all_single : bool
        if True, all multiple bonds are summarized into a single bond with some attributes

    Yields
    ------
    Hypergraph
    """
    def __init__(self, path_to_file, kekulize=True, add_Hs=False, all_single=True):
        # 1-based line counter, only used for error reporting
        self.num_line = 1
        self.mol_gen = Chem.SmilesMolSupplier(path_to_file, titleLine=False)
        self.kekulize = kekulize
        self.add_Hs = add_Hs
        self.all_single = all_single

    def __iter__(self):
        return self

    def __next__(self):
        # parse errors are NOT skipped: a malformed SMILES aborts iteration.
        # StopIteration from the supplier propagates to the caller.
        each_mol = next(self.mol_gen)
        if each_mol is None:
            raise ValueError(f'incorrect smiles in line {self.num_line}')
        self.num_line += 1
        return mol_to_hg(each_mol, self.kekulize, self.add_Hs)
+
+
def mol_to_bipartite(mol, kekulize):
    """
    get a bipartite representation of a molecule.

    Nodes named ``atom_<i>`` and ``bond_<j>`` are connected whenever bond
    *j* is incident to atom *i*.

    Parameters
    ----------
    mol : rdkit.Chem.rdchem.Mol
        molecule object
    kekulize : bool
        kekulize the molecule before extracting attributes.

    Returns
    -------
    nx.Graph
        a bipartite graph representing which bond is connected to which atoms.
    """
    try:
        mol = standardize_stereo(mol)
    except KeyError:
        # dump the offending molecule before re-raising for debugging
        print(Chem.MolToSmiles(mol))
        raise KeyError

    if kekulize:
        Chem.Kekulize(mol)

    graph = nx.Graph()
    for atom in mol.GetAtoms():
        graph.add_node(f"atom_{atom.GetIdx()}",
                       atom_attr=atom_attr(atom, kekulize))

    for bond in mol.GetBonds():
        idx = bond.GetIdx()
        graph.add_node(f"bond_{idx}",
                       bond_attr=bond_attr(bond, kekulize))
        graph.add_edge(f"atom_{bond.GetBeginAtomIdx()}", f"bond_{idx}")
        graph.add_edge(f"atom_{bond.GetEndAtomIdx()}", f"bond_{idx}")
    return graph
+
+
def mol_to_hg(mol, kekulize, add_Hs):
    """
    convert a molecule into a hypergraph.

    Atoms become hyperedges and bonds become nodes: each atom's hyperedge
    spans the set of nodes corresponding to its incident bonds.

    Parameters
    ----------
    mol : rdkit.Chem.rdchem.Mol
        molecule object
    kekulize : bool
        kekulize or not
    add_Hs : bool
        add implicit hydrogens to the molecule or not.

    Returns
    -------
    Hypergraph
    """
    if add_Hs:
        mol = Chem.AddHs(mol)
    if kekulize:
        Chem.Kekulize(mol)

    bipartite = mol_to_bipartite(mol, kekulize)
    hg = Hypergraph()
    atom_nodes = [each for each in bipartite.nodes() if each.startswith('atom_')]
    for atom_node in atom_nodes:
        incident_bonds = set()
        for bond_node in bipartite.adj[atom_node]:
            hg.add_node(bond_node,
                        attr_dict=bipartite.nodes[bond_node]['bond_attr'])
            incident_bonds.add(bond_node)
        hg.add_edge(incident_bonds,
                    attr_dict=bipartite.nodes[atom_node]['atom_attr'])
    return hg
+
+
def hg_to_mol(hg, verbose=False):
    """ convert a hypergraph into a Mol object

    Hyperedges carry atom symbols and nodes carry bond symbols (the dual of
    the usual molecular graph), so edges become atoms and nodes become bonds.

    Parameters
    ----------
    hg : Hypergraph
    verbose : bool
        if True, also return whether stereo information could be restored.

    Returns
    -------
    mol : Chem.Mol
    is_stereo : bool
        returned only when `verbose` is True.

    Raises
    ------
    ValueError
        if a node carries an unsupported bond type.
    RuntimeError
        if no valid molecule can be reconstructed.
    """
    mol = Chem.RWMol()
    atom_dict = {}       # hyperedge name -> RDKit atom index
    bond_set = set([])   # concatenated edge-name pairs already bonded
    for each_edge in hg.edges:
        atom = Chem.Atom(hg.edge_attr(each_edge)['symbol'].symbol)
        atom.SetNumExplicitHs(hg.edge_attr(each_edge)['symbol'].num_explicit_Hs)
        atom.SetFormalCharge(hg.edge_attr(each_edge)['symbol'].formal_charge)
        atom.SetChiralTag(
            Chem.rdchem.ChiralType.values[
                hg.edge_attr(each_edge)['symbol'].chirality])
        atom_idx = mol.AddAtom(atom)
        atom_dict[each_edge] = atom_idx

    for each_node in hg.nodes:
        edge_1, edge_2 = hg.adj_edges(each_node)
        # NOTE(review): de-duplication keys on string concatenation of edge
        # names; assumes names cannot collide under concatenation — confirm.
        if edge_1+edge_2 not in bond_set:
            if hg.node_attr(each_node)['symbol'].bond_type <= 3:
                num_bond = hg.node_attr(each_node)['symbol'].bond_type
            elif hg.node_attr(each_node)['symbol'].bond_type == 12:
                # aromatic bonds were kekulized away; emit a single bond
                num_bond = 1
            else:
                # bug fix: attribute key is 'symbol', not 'bond_symbol' —
                # the old key raised KeyError instead of this ValueError.
                raise ValueError(f'too many bonds; {hg.node_attr(each_node)["symbol"].bond_type}')
            _ = mol.AddBond(atom_dict[edge_1],
                            atom_dict[edge_2],
                            order=Chem.rdchem.BondType.values[num_bond])
            bond_idx = mol.GetBondBetweenAtoms(atom_dict[edge_1], atom_dict[edge_2]).GetIdx()

            # stereo
            mol.GetBondWithIdx(bond_idx).SetStereo(
                Chem.rdchem.BondStereo.values[hg.node_attr(each_node)['symbol'].stereo])
            bond_set.update([edge_1+edge_2])
            bond_set.update([edge_2+edge_1])
    mol.UpdatePropertyCache()
    mol = mol.GetMol()
    not_stereo_mol = deepcopy(mol)
    if Chem.MolFromSmiles(Chem.MolToSmiles(not_stereo_mol)) is None:
        raise RuntimeError('no valid molecule was obtained.')
    try:
        mol = set_stereo(mol)
        is_stereo = True
    except Exception:
        # bug fix: was a bare `except:`, which also swallowed
        # KeyboardInterrupt/SystemExit. Stereo failures fall back below.
        import traceback
        traceback.print_exc()
        is_stereo = False
    mol_tmp = deepcopy(mol)
    Chem.SetAromaticity(mol_tmp)
    # prefer the aromatic perception if it still round-trips through SMILES
    if Chem.MolFromSmiles(Chem.MolToSmiles(mol_tmp)) is not None:
        mol = mol_tmp
    else:
        if Chem.MolFromSmiles(Chem.MolToSmiles(mol)) is None:
            mol = not_stereo_mol
    mol.UpdatePropertyCache()
    Chem.GetSymmSSSR(mol)
    mol = Chem.MolFromSmiles(Chem.MolToSmiles(mol))
    if verbose:
        return mol, is_stereo
    else:
        return mol
+
def hgs_to_mols(hg_list, ignore_error=False):
    ''' convert a list of hypergraphs into Mol objects.

    Parameters
    ----------
    hg_list : list of Hypergraph
    ignore_error : bool
        if True, a hypergraph that fails to convert yields None instead of
        raising.

    Returns
    -------
    list of Chem.Mol (entries may be None when ignore_error=True)
    '''
    if not ignore_error:
        return [hg_to_mol(each_hg) for each_hg in hg_list]
    mol_list = []
    for each_hg in hg_list:
        try:
            mol_list.append(hg_to_mol(each_hg))
        except:
            mol_list.append(None)
    return mol_list
+
def hgs_to_smiles(hg_list, ignore_error=False):
    ''' convert a list of hypergraphs into canonical SMILES strings.

    Parameters
    ----------
    hg_list : list of Hypergraph
    ignore_error : bool
        forwarded to hgs_to_mols; failed conversions yield None.

    Returns
    -------
    list of str or None
    '''
    smiles_list = []
    for each_mol in hgs_to_mols(hg_list, ignore_error):
        try:
            # round-trip through SMILES once to canonicalize
            canonical = Chem.MolToSmiles(
                Chem.MolFromSmiles(
                    Chem.MolToSmiles(each_mol)))
            smiles_list.append(canonical)
        except:
            smiles_list.append(None)
    return smiles_list
+
def atom_attr(atom, kekulize):
    """
    get atom's attributes

    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
    kekulize : bool
        kekulize or not; when kekulizing, aromaticity is dropped.

    Returns
    -------
    atom_attr : dict
        "terminal" : bool
        "is_in_ring" : bool
        "symbol" : TSymbol describing the atom.
    """
    # under kekulization the aromatic flag is forced off
    is_aromatic = False if kekulize else atom.GetIsAromatic()
    return {'terminal': True,
            'is_in_ring': atom.IsInRing(),
            'symbol': TSymbol(degree=0,
                              #degree=atom.GetTotalDegree(),
                              is_aromatic=is_aromatic,
                              symbol=atom.GetSymbol(),
                              num_explicit_Hs=atom.GetNumExplicitHs(),
                              formal_charge=atom.GetFormalCharge(),
                              chirality=atom.GetChiralTag().real
                              )}
+
def bond_attr(bond, kekulize):
    """
    get bond's attributes

    Parameters
    ----------
    bond : rdkit.Chem.rdchem.Bond
    kekulize : bool
        kekulize or not; when kekulizing, aromatic bonds (type 12) are
        rewritten as single bonds (type 1) and the aromatic flag is dropped.

    Returns
    -------
    bond_attr : dict
        "symbol" : BondSymbol
        "is_in_ring" : bool
        `bond_type` is the integer value of rdkit.Chem.rdchem.BondType,
        e.g. 1: SINGLE, 2: DOUBLE, 3: TRIPLE, 12: AROMATIC, 21: ZERO.
    """
    bond_type = bond.GetBondType().real
    if kekulize:
        is_aromatic = False
        if bond_type == 12:
            bond_type = 1
    else:
        is_aromatic = bond.GetIsAromatic()
    return {'symbol': BondSymbol(is_aromatic=is_aromatic,
                                 bond_type=bond_type,
                                 stereo=int(bond.GetStereo())),
            'is_in_ring': bond.IsInRing()}
+
+
def standardize_stereo(mol):
    '''
    Normalize double-bond stereo annotations in place.

    For every stereo double bond (BondStereo 2 = Z, 3 = E), the two
    reference atoms stored in StereoAtoms are rewritten so that, on each end
    of the double bond, the neighbor with the *lower* CIP rank becomes the
    reference.  Whenever only one reference is swapped the Z/E flag flips;
    swapping both preserves it.  Single-bond directions are reset through
    safe_set_bond_dir to stay consistent.

    BondDir integer values used below:
    0: rdkit.Chem.rdchem.BondDir.NONE,
    1: rdkit.Chem.rdchem.BondDir.BEGINWEDGE,
    2: rdkit.Chem.rdchem.BondDir.BEGINDASH,
    3: rdkit.Chem.rdchem.BondDir.ENDDOWNRIGHT,
    4: rdkit.Chem.rdchem.BondDir.ENDUPRIGHT,

    Parameters
    ----------
    mol : rdkit.Chem.rdchem.Mol
        molecule with `_CIPRank` atom properties available.

    Returns
    -------
    rdkit.Chem.rdchem.Mol
        the same Mol object, modified in place.
    '''
    # mol = Chem.AddHs(mol) # this removes CIPRank !!!
    for each_bond in mol.GetBonds():
        if int(each_bond.GetStereo()) in [2, 3]: #2=Z (same side), 3=E
            begin_stereo_atom_idx = each_bond.GetBeginAtomIdx()
            end_stereo_atom_idx = each_bond.GetEndAtomIdx()
            atom_idx_1 = each_bond.GetStereoAtoms()[0]
            atom_idx_2 = each_bond.GetStereoAtoms()[1]
            # orient the stereo atoms: the one bonded to the begin atom of
            # the double bond becomes begin_atom_idx, the other end_atom_idx
            if mol.GetBondBetweenAtoms(atom_idx_1, begin_stereo_atom_idx):
                begin_atom_idx = atom_idx_1
                end_atom_idx = atom_idx_2
            else:
                begin_atom_idx = atom_idx_2
                end_atom_idx = atom_idx_1

            # find the remaining (non-reference) neighbor on each end, if any;
            # assumes each double-bond end has at most 3 neighbors
            begin_another_atom_idx = None
            assert len(mol.GetAtomWithIdx(begin_stereo_atom_idx).GetNeighbors()) <= 3
            for each_neighbor in mol.GetAtomWithIdx(begin_stereo_atom_idx).GetNeighbors():
                each_neighbor_idx = each_neighbor.GetIdx()
                if each_neighbor_idx not in [end_stereo_atom_idx, begin_atom_idx]:
                    begin_another_atom_idx = each_neighbor_idx

            end_another_atom_idx = None
            assert len(mol.GetAtomWithIdx(end_stereo_atom_idx).GetNeighbors()) <= 3
            for each_neighbor in mol.GetAtomWithIdx(end_stereo_atom_idx).GetNeighbors():
                each_neighbor_idx = each_neighbor.GetIdx()
                if each_neighbor_idx not in [begin_stereo_atom_idx, end_atom_idx]:
                    end_another_atom_idx = each_neighbor_idx

            '''
            relationship between begin_atom_idx and end_atom_idx is encoded in GetStereo
            '''
            begin_atom_rank = int(mol.GetAtomWithIdx(begin_atom_idx).GetProp('_CIPRank'))
            end_atom_rank = int(mol.GetAtomWithIdx(end_atom_idx).GetProp('_CIPRank'))
            # a missing "other" neighbor ranks as +inf, so the existing
            # reference atom always wins the comparison below
            try:
                begin_another_atom_rank = int(mol.GetAtomWithIdx(begin_another_atom_idx).GetProp('_CIPRank'))
            except:
                begin_another_atom_rank = np.inf
            try:
                end_another_atom_rank = int(mol.GetAtomWithIdx(end_another_atom_idx).GetProp('_CIPRank'))
            except:
                end_another_atom_rank = np.inf
            if begin_atom_rank < begin_another_atom_rank\
               and end_atom_rank < end_another_atom_rank:
                # both references already have the lower CIP rank: keep as is
                pass
            elif begin_atom_rank < begin_another_atom_rank\
                 and end_atom_rank > end_another_atom_rank:
                # (begin_atom_idx +) end_another_atom_idx should be in StereoAtoms
                # swapping one reference flips Z <-> E
                if each_bond.GetStereo() == 2:
                    # set stereo
                    each_bond.SetStereo(Chem.rdchem.BondStereo.values[3])
                    # set bond dir
                    mol = safe_set_bond_dir(mol, begin_atom_idx, begin_stereo_atom_idx, 3)
                    mol = safe_set_bond_dir(mol, begin_another_atom_idx, begin_stereo_atom_idx, 0)
                    mol = safe_set_bond_dir(mol, end_atom_idx, end_stereo_atom_idx, 0)
                    mol = safe_set_bond_dir(mol, end_another_atom_idx, end_stereo_atom_idx, 3)
                elif each_bond.GetStereo() == 3:
                    # set stereo
                    each_bond.SetStereo(Chem.rdchem.BondStereo.values[2])
                    # set bond dir
                    mol = safe_set_bond_dir(mol, begin_atom_idx, begin_stereo_atom_idx, 3)
                    mol = safe_set_bond_dir(mol, begin_another_atom_idx, begin_stereo_atom_idx, 0)
                    mol = safe_set_bond_dir(mol, end_atom_idx, end_stereo_atom_idx, 0)
                    mol = safe_set_bond_dir(mol, end_another_atom_idx, end_stereo_atom_idx, 4)
                else:
                    raise ValueError
                each_bond.SetStereoAtoms(begin_atom_idx, end_another_atom_idx)
            elif begin_atom_rank > begin_another_atom_rank\
                 and end_atom_rank < end_another_atom_rank:
                # (end_atom_idx +) begin_another_atom_idx should be in StereoAtoms
                # swapping one reference flips Z <-> E
                if each_bond.GetStereo() == 2:
                    # set stereo
                    each_bond.SetStereo(Chem.rdchem.BondStereo.values[3])
                    # set bond dir
                    mol = safe_set_bond_dir(mol, begin_atom_idx, begin_stereo_atom_idx, 0)
                    mol = safe_set_bond_dir(mol, begin_another_atom_idx, begin_stereo_atom_idx, 4)
                    mol = safe_set_bond_dir(mol, end_atom_idx, end_stereo_atom_idx, 4)
                    mol = safe_set_bond_dir(mol, end_another_atom_idx, end_stereo_atom_idx, 0)
                elif each_bond.GetStereo() == 3:
                    # set stereo
                    each_bond.SetStereo(Chem.rdchem.BondStereo.values[2])
                    # set bond dir
                    mol = safe_set_bond_dir(mol, begin_atom_idx, begin_stereo_atom_idx, 0)
                    mol = safe_set_bond_dir(mol, begin_another_atom_idx, begin_stereo_atom_idx, 4)
                    mol = safe_set_bond_dir(mol, end_atom_idx, end_stereo_atom_idx, 3)
                    mol = safe_set_bond_dir(mol, end_another_atom_idx, end_stereo_atom_idx, 0)
                else:
                    raise ValueError
                each_bond.SetStereoAtoms(begin_another_atom_idx, end_atom_idx)
            elif begin_atom_rank > begin_another_atom_rank\
                 and end_atom_rank > end_another_atom_rank:
                # begin_another_atom_idx + end_another_atom_idx should be in StereoAtoms
                # both references swap, so the Z/E flag is unchanged
                if each_bond.GetStereo() == 2:
                    # set bond dir
                    mol = safe_set_bond_dir(mol, begin_atom_idx, begin_stereo_atom_idx, 0)
                    mol = safe_set_bond_dir(mol, begin_another_atom_idx, begin_stereo_atom_idx, 4)
                    mol = safe_set_bond_dir(mol, end_atom_idx, end_stereo_atom_idx, 0)
                    mol = safe_set_bond_dir(mol, end_another_atom_idx, end_stereo_atom_idx, 3)
                elif each_bond.GetStereo() == 3:
                    # set bond dir
                    mol = safe_set_bond_dir(mol, begin_atom_idx, begin_stereo_atom_idx, 0)
                    mol = safe_set_bond_dir(mol, begin_another_atom_idx, begin_stereo_atom_idx, 4)
                    mol = safe_set_bond_dir(mol, end_atom_idx, end_stereo_atom_idx, 0)
                    mol = safe_set_bond_dir(mol, end_another_atom_idx, end_stereo_atom_idx, 4)
                else:
                    raise ValueError
                each_bond.SetStereoAtoms(begin_another_atom_idx, end_another_atom_idx)
            else:
                # NOTE(review): equal CIP ranks fall through to here — the
                # original author treated that as an unexpected state
                raise RuntimeError
    return mol
+
+
def set_stereo(mol):
    '''
    Re-apply double-bond stereo flags after a SMILES round-trip and set the
    corresponding single-bond directions so the stereo survives writing.

    BondDir integer values used below:
    0: rdkit.Chem.rdchem.BondDir.NONE,
    1: rdkit.Chem.rdchem.BondDir.BEGINWEDGE,
    2: rdkit.Chem.rdchem.BondDir.BEGINDASH,
    3: rdkit.Chem.rdchem.BondDir.ENDDOWNRIGHT,
    4: rdkit.Chem.rdchem.BondDir.ENDUPRIGHT,

    Parameters
    ----------
    mol : rdkit.Chem.rdchem.Mol

    Returns
    -------
    rdkit.Chem.rdchem.Mol
        a new Mol (rebuilt from SMILES) with stereo flags and bond
        directions assigned.

    Raises
    ------
    ValueError
        if the round-tripped molecule cannot be matched onto the input.
    '''
    # rebuild the molecule from its canonical SMILES, then map atoms of the
    # rebuilt molecule back onto the input via substructure match
    _mol = Chem.MolFromSmiles(Chem.MolToSmiles(mol))
    Chem.Kekulize(_mol, True)
    substruct_match = mol.GetSubstructMatch(_mol)
    if not substruct_match:
        ''' mol and _mol are kekulized.
        sometimes, the order of '=' and '-' changes, which causes mol and _mol not matched.
        '''
        Chem.SetAromaticity(mol)
        Chem.SetAromaticity(_mol)
        substruct_match = mol.GetSubstructMatch(_mol)
    try:
        atom_match = {substruct_match[_mol_atom_idx]: _mol_atom_idx for _mol_atom_idx in range(_mol.GetNumAtoms())} # mol to _mol
    except:
        raise ValueError('two molecules obtained from the same data do not match.')

    # copy every stereo flag from the input onto the rebuilt molecule
    for each_bond in mol.GetBonds():
        begin_atom_idx = each_bond.GetBeginAtomIdx()
        end_atom_idx = each_bond.GetEndAtomIdx()
        _bond = _mol.GetBondBetweenAtoms(atom_match[begin_atom_idx], atom_match[end_atom_idx])
        _bond.SetStereo(each_bond.GetStereo())

    mol = _mol
    for each_bond in mol.GetBonds():
        if int(each_bond.GetStereo()) in [2, 3]: #2=Z (same side), 3=E
            begin_stereo_atom_idx = each_bond.GetBeginAtomIdx()
            end_stereo_atom_idx = each_bond.GetEndAtomIdx()
            # neighbors of each double-bond end, excluding the opposite end
            begin_atom_idx_set = set([each_neighbor.GetIdx()
                                      for each_neighbor
                                      in mol.GetAtomWithIdx(begin_stereo_atom_idx).GetNeighbors()
                                      if each_neighbor.GetIdx() != end_stereo_atom_idx])
            end_atom_idx_set = set([each_neighbor.GetIdx()
                                    for each_neighbor
                                    in mol.GetAtomWithIdx(end_stereo_atom_idx).GetNeighbors()
                                    if each_neighbor.GetIdx() != begin_stereo_atom_idx])
            # a terminal double bond cannot carry stereo: clear the flag
            if not begin_atom_idx_set:
                each_bond.SetStereo(Chem.rdchem.BondStereo(0))
                continue
            if not end_atom_idx_set:
                each_bond.SetStereo(Chem.rdchem.BondStereo(0))
                continue
            # pick the lower-CIP-rank neighbor on each end as the reference
            if len(begin_atom_idx_set) == 1:
                begin_atom_idx = begin_atom_idx_set.pop()
                begin_another_atom_idx = None
            if len(end_atom_idx_set) == 1:
                end_atom_idx = end_atom_idx_set.pop()
                end_another_atom_idx = None
            if len(begin_atom_idx_set) == 2:
                atom_idx_1 = begin_atom_idx_set.pop()
                atom_idx_2 = begin_atom_idx_set.pop()
                if int(mol.GetAtomWithIdx(atom_idx_1).GetProp('_CIPRank')) < int(mol.GetAtomWithIdx(atom_idx_2).GetProp('_CIPRank')):
                    begin_atom_idx = atom_idx_1
                    begin_another_atom_idx = atom_idx_2
                else:
                    begin_atom_idx = atom_idx_2
                    begin_another_atom_idx = atom_idx_1
            if len(end_atom_idx_set) == 2:
                atom_idx_1 = end_atom_idx_set.pop()
                atom_idx_2 = end_atom_idx_set.pop()
                if int(mol.GetAtomWithIdx(atom_idx_1).GetProp('_CIPRank')) < int(mol.GetAtomWithIdx(atom_idx_2).GetProp('_CIPRank')):
                    end_atom_idx = atom_idx_1
                    end_another_atom_idx = atom_idx_2
                else:
                    end_atom_idx = atom_idx_2
                    end_another_atom_idx = atom_idx_1

            # encode Z/E as single-bond directions around the double bond
            if each_bond.GetStereo() == 2: # same side
                mol = safe_set_bond_dir(mol, begin_atom_idx, begin_stereo_atom_idx, 3)
                mol = safe_set_bond_dir(mol, end_atom_idx, end_stereo_atom_idx, 4)
                each_bond.SetStereoAtoms(begin_atom_idx, end_atom_idx)
            elif each_bond.GetStereo() == 3: # opposite side
                mol = safe_set_bond_dir(mol, begin_atom_idx, begin_stereo_atom_idx, 3)
                mol = safe_set_bond_dir(mol, end_atom_idx, end_stereo_atom_idx, 3)
                each_bond.SetStereoAtoms(begin_atom_idx, end_atom_idx)
            else:
                raise ValueError
    return mol
+
+
def safe_set_bond_dir(mol, atom_idx_1, atom_idx_2, bond_dir_val):
    ''' set the direction of the bond between two atoms, skipping missing atoms.

    Parameters
    ----------
    mol : rdkit.Chem.rdchem.Mol
    atom_idx_1, atom_idx_2 : int or None
        endpoints of the bond; if either is None this is a no-op.
    bond_dir_val : int
        integer value of rdkit.Chem.rdchem.BondDir.

    Returns
    -------
    rdkit.Chem.rdchem.Mol
        the same Mol object.
    '''
    if atom_idx_1 is not None and atom_idx_2 is not None:
        bond = mol.GetBondBetweenAtoms(atom_idx_1, atom_idx_2)
        bond.SetBondDir(Chem.rdchem.BondDir.values[bond_dir_val])
    return mol
+
diff --git a/graph_grammar/nn/__init__.py b/graph_grammar/nn/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a2912d326a2ed386143a635e5dc674b3ea7ce01f
--- /dev/null
+++ b/graph_grammar/nn/__init__.py
@@ -0,0 +1,11 @@
+# -*- coding:utf-8 -*-
+# Rhizome
+# Version beta 0.0, August 2023
+# Property of IBM Research, Accelerated Discovery
+#
+
+"""
+PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
+OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
+THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
+"""
diff --git a/graph_grammar/nn/__pycache__/__init__.cpython-310.pyc b/graph_grammar/nn/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a19d9c8a3800a7939024885b5a30651bd8bb1695
Binary files /dev/null and b/graph_grammar/nn/__pycache__/__init__.cpython-310.pyc differ
diff --git a/graph_grammar/nn/__pycache__/decoder.cpython-310.pyc b/graph_grammar/nn/__pycache__/decoder.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..08e683bfd8bb50a3efb57086958589af1bc760b1
Binary files /dev/null and b/graph_grammar/nn/__pycache__/decoder.cpython-310.pyc differ
diff --git a/graph_grammar/nn/__pycache__/encoder.cpython-310.pyc b/graph_grammar/nn/__pycache__/encoder.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ae5d350bbfea93dd453eacbf3ab4afa188f2e61e
Binary files /dev/null and b/graph_grammar/nn/__pycache__/encoder.cpython-310.pyc differ
diff --git a/graph_grammar/nn/dataset.py b/graph_grammar/nn/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..70894b289d6413dd29c81dfd3934574c6edb1383
--- /dev/null
+++ b/graph_grammar/nn/dataset.py
@@ -0,0 +1,121 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Rhizome
+# Version beta 0.0, August 2023
+# Property of IBM Research, Accelerated Discovery
+#
+
+"""
+PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
+OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
+THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
+"""
+
+""" Title """
+
+__author__ = "Hiroshi Kajino "
+__copyright__ = "(c) Copyright IBM Corp. 2018"
+__version__ = "0.1"
+__date__ = "Apr 18 2018"
+
+from torch.utils.data import Dataset, DataLoader
+import torch
+import numpy as np
+
+
def left_padding(sentence_list, max_len, pad_idx=-1, inverse=False):
    ''' pad left

    Parameters
    ----------
    sentence_list : list of sequences of integers
    max_len : int
        maximum length of sentences.
        if a sentence is shorter than `max_len`, its left part is padded.
    pad_idx : int
        integer for padding
    inverse : bool
        if True, the sequence is inversed.

    Returns
    -------
    List of torch.LongTensor
        each sentence is left-padded.
    '''
    longest = max(len(each_sen) for each_sen in sentence_list)
    if longest > max_len:
        raise ValueError(f'`max_len` should be larger than the maximum length of input sequences, {longest}.')

    padded = []
    for each_sen in sentence_list:
        body = each_sen[::-1] if inverse else each_sen
        padded.append(torch.LongTensor([pad_idx] * (max_len - len(each_sen)) + body))
    return padded
+
+
def right_padding(sentence_list, max_len, pad_idx=-1):
    ''' pad right

    Parameters
    ----------
    sentence_list : list of sequences of integers
    max_len : int
        maximum length of sentences.
        if a sentence is shorter than `max_len`, its right part is padded.
    pad_idx : int
        integer for padding

    Returns
    -------
    List of torch.LongTensor
        each sentence is right-padded.
    '''
    longest = max(len(each_sen) for each_sen in sentence_list)
    if longest > max_len:
        raise ValueError(f'`max_len` should be larger than the maximum length of input sequences, {longest}.')

    return [torch.LongTensor(each_sen + [pad_idx] * (max_len - len(each_sen)))
            for each_sen in sentence_list]
+
+
class HRGDataset(Dataset):

    '''
    A dataset of HRG production-rule sequences.

    Parameters
    ----------
    hrg : object
        grammar object exposing `num_prod_rule` (used by `vocab_size`).
    prod_rule_seq_list : list of list of int
        production-rule index sequences.
    max_len : int
        sequences are padded to this length.
    target_val_list : list of float or None
        optional regression targets, one per sequence.
    inversed_input : bool
        if True, the left-padded input sequence is reversed.
    '''

    def __init__(self, hrg, prod_rule_seq_list, max_len, target_val_list=None, inversed_input=False):
        self.hrg = hrg
        self.left_prod_rule_seq_list = left_padding(prod_rule_seq_list,
                                                    max_len,
                                                    inverse=inversed_input)
        self.right_prod_rule_seq_list = right_padding(prod_rule_seq_list, max_len)
        # bug fix: attribute was misspelled 'inserved_input'; the old name is
        # kept as an alias for backward compatibility.
        self.inversed_input = inversed_input
        self.inserved_input = inversed_input
        self.target_val_list = target_val_list
        if target_val_list is not None:
            if len(prod_rule_seq_list) != len(target_val_list):
                raise ValueError(f'prod_rule_seq_list and target_val_list have inconsistent lengths: {len(prod_rule_seq_list)}, {len(target_val_list)}')

    def __len__(self):
        return len(self.left_prod_rule_seq_list)

    def __getitem__(self, idx):
        # returns (left_padded, right_padded[, target]) for one sequence
        if self.target_val_list is not None:
            return self.left_prod_rule_seq_list[idx], self.right_prod_rule_seq_list[idx], np.float32(self.target_val_list[idx])
        else:
            return self.left_prod_rule_seq_list[idx], self.right_prod_rule_seq_list[idx]

    @property
    def vocab_size(self):
        # number of production rules in the grammar
        return self.hrg.num_prod_rule
+
def batch_padding(each_batch, batch_size, padding_idx):
    ''' pad a trailing mini-batch up to `batch_size` rows.

    Parameters
    ----------
    each_batch : list of two 2-D int64 tensors
        [left-padded batch, right-padded batch]; modified in place.
    batch_size : int
        target number of rows.
    padding_idx : int
        value used to fill the appended rows.

    Returns
    -------
    (each_batch, num_pad) : tuple
        the (possibly padded) pair and the number of rows added.
    '''
    num_pad = batch_size - len(each_batch[0])
    if num_pad:
        for pos in (0, 1):
            filler = padding_idx * torch.ones((batch_size - len(each_batch[pos]),
                                               len(each_batch[pos][0])),
                                              dtype=torch.int64)
            each_batch[pos] = torch.cat([each_batch[pos], filler], dim=0)
    return each_batch, num_pad
diff --git a/graph_grammar/nn/decoder.py b/graph_grammar/nn/decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..a3b9c9d69e7077650355e1df201b18420de55471
--- /dev/null
+++ b/graph_grammar/nn/decoder.py
@@ -0,0 +1,158 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Rhizome
+# Version beta 0.0, August 2023
+# Property of IBM Research, Accelerated Discovery
+#
+
+"""
+PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
+OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
+THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
+"""
+
+""" Title """
+
+__author__ = "Hiroshi Kajino "
+__copyright__ = "(c) Copyright IBM Corp. 2018"
+__version__ = "0.1"
+__date__ = "Aug 9 2018"
+
+
+import abc
+import numpy as np
+import torch
+from torch import nn
+
+
class DecoderBase(nn.Module):

    ''' abstract base class for one-step sequence decoders.

    Subclasses keep their recurrent state in `hidden_dict`
    (name -> hidden tensor) and implement `forward_one_step` /
    `init_hidden`.
    '''

    def __init__(self):
        super().__init__()
        # name -> hidden state tensor(s), managed by subclasses
        self.hidden_dict = {}

    @abc.abstractmethod
    def forward_one_step(self, tgt_emb_in):
        ''' one-step forward model

        Parameters
        ----------
        tgt_emb_in : Tensor, shape (batch_size, input_dim)

        Returns
        -------
        Tensor, shape (batch_size, hidden_dim)
        '''
        return None

    @abc.abstractmethod
    def init_hidden(self):
        ''' initialize the hidden states
        '''
        pass

    @abc.abstractmethod
    def feed_hidden(self, hidden_dict_0):
        # overwrite the first layer of every hidden state with the given one
        for key in self.hidden_dict.keys():
            self.hidden_dict[key][0] = hidden_dict_0[key]
+
+
class GRUDecoder(DecoderBase):

    ''' single-direction GRU decoder.

    Keeps its recurrent state in `hidden_dict['h']` across calls to
    `forward_one_step`.
    '''

    def __init__(self, input_dim: int, hidden_dim: int, num_layers: int,
                 dropout: float, batch_size: int, use_gpu: bool,
                 no_dropout=False):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.dropout = dropout
        self.batch_size = batch_size
        self.use_gpu = use_gpu
        # dropout is disabled entirely when no_dropout is requested
        gru_dropout = 0 if no_dropout else self.dropout
        self.model = nn.GRU(input_size=self.input_dim,
                            hidden_size=self.hidden_dim,
                            num_layers=self.num_layers,
                            batch_first=True,
                            bidirectional=False,
                            dropout=gru_dropout)
        if self.use_gpu:
            self.model.cuda()
        self.init_hidden()

    def init_hidden(self):
        ''' reset the GRU hidden state to zeros '''
        h0 = torch.zeros((self.num_layers, self.batch_size, self.hidden_dim),
                         requires_grad=False)
        self.hidden_dict['h'] = h0.cuda() if self.use_gpu else h0

    def forward_one_step(self, tgt_emb_in):
        ''' one-step forward model

        Parameters
        ----------
        tgt_emb_in : Tensor, shape (batch_size, input_dim)

        Returns
        -------
        Tensor, shape (batch_size, 1, hidden_dim)
        '''
        step_in = tgt_emb_in.view(self.batch_size, 1, -1)
        tgt_emb_out, self.hidden_dict['h'] = self.model(step_in, self.hidden_dict['h'])
        return tgt_emb_out
+
+
class LSTMDecoder(DecoderBase):

    ''' single-direction LSTM decoder.

    Keeps its recurrent state in `hidden_dict['h']` / `hidden_dict['c']`
    across calls to `forward_one_step`.
    '''

    def __init__(self, input_dim: int, hidden_dim: int, num_layers: int,
                 dropout: float, batch_size: int, use_gpu: bool,
                 no_dropout=False):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.dropout = dropout
        self.batch_size = batch_size
        self.use_gpu = use_gpu
        self.model = nn.LSTM(input_size=self.input_dim,
                             hidden_size=self.hidden_dim,
                             num_layers=self.num_layers,
                             batch_first=True,
                             bidirectional=False,
                             dropout=self.dropout if not no_dropout else 0)
        if self.use_gpu:
            self.model.cuda()
        self.init_hidden()

    def init_hidden(self):
        ''' reset the LSTM hidden and cell states to zeros '''
        self.hidden_dict['h'] = torch.zeros((self.num_layers,
                                             self.batch_size,
                                             self.hidden_dim),
                                            requires_grad=False)
        self.hidden_dict['c'] = torch.zeros((self.num_layers,
                                             self.batch_size,
                                             self.hidden_dim),
                                            requires_grad=False)
        if self.use_gpu:
            for each_hidden in self.hidden_dict.keys():
                self.hidden_dict[each_hidden] = self.hidden_dict[each_hidden].cuda()

    def forward_one_step(self, tgt_emb_in):
        ''' one-step forward model

        Parameters
        ----------
        tgt_emb_in : Tensor, shape (batch_size, input_dim)

        Returns
        -------
        Tensor, shape (batch_size, 1, hidden_dim)
        '''
        # bug fix: nn.LSTM takes and returns the recurrent state as a
        # (h, c) tuple; the old code passed h and c as separate positional
        # arguments and unpacked three values, raising TypeError at runtime.
        tgt_hidden_out, (self.hidden_dict['h'], self.hidden_dict['c']) \
            = self.model(tgt_emb_in.view(self.batch_size, 1, -1),
                         (self.hidden_dict['h'], self.hidden_dict['c']))
        return tgt_hidden_out
diff --git a/graph_grammar/nn/encoder.py b/graph_grammar/nn/encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..a59226d9980479503ed482be26976dd4917b9953
--- /dev/null
+++ b/graph_grammar/nn/encoder.py
@@ -0,0 +1,199 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Rhizome
+# Version beta 0.0, August 2023
+# Property of IBM Research, Accelerated Discovery
+#
+
+"""
+PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
+OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
+THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
+"""
+
+""" Title """
+
+__author__ = "Hiroshi Kajino "
+__copyright__ = "(c) Copyright IBM Corp. 2018"
+__version__ = "0.1"
+__date__ = "Aug 9 2018"
+
+
+import abc
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+from typing import List
+
+
class EncoderBase(nn.Module):

    ''' abstract base class for sequence encoders. '''

    def __init__(self):
        super().__init__()

    @abc.abstractmethod
    def forward(self, in_seq):
        ''' forward model

        Parameters
        ----------
        in_seq_emb : Variable, shape (batch_size, max_len, input_dim)

        Returns
        -------
        hidden_seq_emb : Tensor, shape (batch_size, max_len, 1 + bidirectional, hidden_dim)
        '''
        pass

    @abc.abstractmethod
    def init_hidden(self):
        ''' initialize the hidden states
        '''
        pass
+
+
class GRUEncoder(EncoderBase):

    ''' (optionally bidirectional) GRU encoder.

    The hidden state `h0` persists across forward calls.
    '''

    def __init__(self, input_dim: int, hidden_dim: int, num_layers: int,
                 bidirectional: bool, dropout: float, batch_size: int, use_gpu: bool,
                 no_dropout=False):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.dropout = dropout
        self.batch_size = batch_size
        self.use_gpu = use_gpu
        self.model = nn.GRU(input_size=self.input_dim,
                            hidden_size=self.hidden_dim,
                            num_layers=self.num_layers,
                            batch_first=True,
                            bidirectional=self.bidirectional,
                            dropout=0 if no_dropout else self.dropout)
        if self.use_gpu:
            self.model.cuda()
        self.init_hidden()

    def init_hidden(self):
        ''' reset the GRU hidden state to zeros '''
        # one state per layer and direction
        num_states = (self.bidirectional + 1) * self.num_layers
        self.h0 = torch.zeros((num_states, self.batch_size, self.hidden_dim),
                              requires_grad=False)
        if self.use_gpu:
            self.h0 = self.h0.cuda()

    def forward(self, in_seq_emb):
        ''' forward model

        Parameters
        ----------
        in_seq_emb : Tensor, shape (batch_size, max_len, input_dim)

        Returns
        -------
        hidden_seq_emb : Tensor, shape (batch_size, max_len, 1 + bidirectional, hidden_dim)
        '''
        max_len = in_seq_emb.size(1)
        out, self.h0 = self.model(in_seq_emb, self.h0)
        # split the flat feature axis into (direction, hidden_dim)
        return out.view(self.batch_size,
                        max_len,
                        1 + self.bidirectional,
                        self.hidden_dim)
+
+
class LSTMEncoder(EncoderBase):

    ''' (optionally bidirectional) LSTM encoder.

    The hidden/cell states `h0`/`c0` persist across forward calls.
    '''

    def __init__(self, input_dim: int, hidden_dim: int, num_layers: int,
                 bidirectional: bool, dropout: float, batch_size: int, use_gpu: bool,
                 no_dropout=False):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.dropout = dropout
        self.batch_size = batch_size
        self.use_gpu = use_gpu
        self.model = nn.LSTM(input_size=self.input_dim,
                             hidden_size=self.hidden_dim,
                             num_layers=self.num_layers,
                             batch_first=True,
                             bidirectional=self.bidirectional,
                             dropout=0 if no_dropout else self.dropout)
        if self.use_gpu:
            self.model.cuda()
        self.init_hidden()

    def init_hidden(self):
        ''' reset the LSTM hidden and cell states to zeros '''
        # one state per layer and direction
        num_states = (self.bidirectional + 1) * self.num_layers
        state_shape = (num_states, self.batch_size, self.hidden_dim)
        self.h0 = torch.zeros(state_shape, requires_grad=False)
        self.c0 = torch.zeros(state_shape, requires_grad=False)
        if self.use_gpu:
            self.h0 = self.h0.cuda()
            self.c0 = self.c0.cuda()

    def forward(self, in_seq_emb):
        ''' forward model

        Parameters
        ----------
        in_seq_emb : Tensor, shape (batch_size, max_len, input_dim)

        Returns
        -------
        hidden_seq_emb : Tensor, shape (batch_size, max_len, 1 + bidirectional, hidden_dim)
        '''
        max_len = in_seq_emb.size(1)
        out, (self.h0, self.c0) = self.model(in_seq_emb, (self.h0, self.c0))
        # split the flat feature axis into (direction, hidden_dim)
        return out.view(self.batch_size,
                        max_len,
                        1 + self.bidirectional,
                        self.hidden_dim)
+
+
class FullConnectedEncoder(EncoderBase):

    ''' fully-connected (MLP) encoder over a flattened input sequence.

    The input of shape (batch_size, max_len, input_dim) is flattened and
    passed through a stack of ReLU-activated linear layers whose widths are
    given by `hidden_dim_list`, ending at `hidden_dim`.
    '''

    def __init__(self, input_dim: int, hidden_dim: int, max_len: int, hidden_dim_list: List[int],
                 batch_size: int, use_gpu: bool):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.max_len = max_len
        self.hidden_dim_list = hidden_dim_list
        self.use_gpu = use_gpu
        # layer widths: flattened input -> hidden_dim_list... -> hidden_dim
        dims = [input_dim * max_len] + list(hidden_dim_list) + [hidden_dim]
        self.linear_list = nn.ModuleList(
            [nn.Linear(dims[idx], dims[idx + 1]) for idx in range(len(dims) - 1)])

    def forward(self, in_seq_emb):
        ''' forward model

        Parameters
        ----------
        in_seq_emb : Tensor, shape (batch_size, max_len, input_dim)

        Returns
        -------
        Tensor, shape (batch_size, 1, hidden_dim)
        '''
        batch_size = in_seq_emb.size(0)
        x = in_seq_emb.view(batch_size, -1)
        for layer in self.linear_list:
            x = F.relu(layer(x))
        return x.view(batch_size, 1, -1)

    def init_hidden(self):
        # stateless encoder: nothing to initialize
        pass
diff --git a/graph_grammar/nn/graph.py b/graph_grammar/nn/graph.py
new file mode 100644
index 0000000000000000000000000000000000000000..a2da699b2000ddf81c856c2c636f8ccdb864c81c
--- /dev/null
+++ b/graph_grammar/nn/graph.py
@@ -0,0 +1,313 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Rhizome
+# Version beta 0.0, August 2023
+# Property of IBM Research, Accelerated Discovery
+#
+
+"""
+PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
+OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
+THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
+"""
+
+""" Title """
+
+__author__ = "Hiroshi Kajino "
+__copyright__ = "(c) Copyright IBM Corp. 2018"
+__version__ = "0.1"
+__date__ = "Jan 1 2018"
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from graph_grammar.graph_grammar.hrg import ProductionRuleCorpus
+from torch import nn
+from torch.autograd import Variable
+
class MolecularProdRuleEmbedding(nn.Module):

    ''' Molecular fingerprint layer for production rules.

    Embeds each production rule by running a few rounds of message passing
    over the rule's right-hand-side (RHS) hypergraph, summing a per-layer
    readout into a fixed-size output vector.

    Parameters
    ----------
    prod_rule_corpus : ProductionRuleCorpus
        Provides ``prod_rule_list`` and the symbol / ext-id counts.
    layer2layer_activation : callable
        Activation applied to layer-to-layer messages.
    layer2out_activation : callable
        Activation applied to the per-layer readout.
    out_dim : int
        Dimension of the output fingerprint.
    element_embed_dim : int
        Dimension of the per-node / per-edge embeddings.
    num_layers : int
        Number of message-passing rounds.
    padding_idx : int or None
        If given, must be -1; an index equal to
        ``len(prod_rule_corpus.prod_rule_list)`` is treated as padding.
    use_gpu : bool
        If True, move all parameters to CUDA.
    '''

    def __init__(self, prod_rule_corpus, layer2layer_activation, layer2out_activation,
                 out_dim=32, element_embed_dim=32,
                 num_layers=3, padding_idx=None, use_gpu=False):
        super().__init__()
        if padding_idx is not None:
            assert padding_idx == -1, 'padding_idx must be -1.'
        self.prod_rule_corpus = prod_rule_corpus
        self.layer2layer_activation = layer2layer_activation
        self.layer2out_activation = layer2out_activation
        self.out_dim = out_dim
        self.element_embed_dim = element_embed_dim
        self.num_layers = num_layers
        self.padding_idx = padding_idx
        self.use_gpu = use_gpu

        # FIX: previously the embeddings were plain tensors (and calling
        # `.cuda()` on a requires_grad tensor even made them non-leaf) and
        # the linear layers lived in plain Python lists, so none of them
        # were registered with the module: no optimizer updates, no
        # state_dict entries, no device moves.  Register them properly.
        self.atom_embed = nn.Parameter(
            torch.randn(self.prod_rule_corpus.num_edge_symbol, self.element_embed_dim))
        self.bond_embed = nn.Parameter(
            torch.randn(self.prod_rule_corpus.num_node_symbol, self.element_embed_dim))
        self.ext_id_embed = nn.Parameter(
            torch.randn(self.prod_rule_corpus.num_ext_id, self.element_embed_dim))
        self.layer2layer_list = nn.ModuleList(
            nn.Linear(self.element_embed_dim, self.element_embed_dim)
            for _ in range(num_layers))
        self.layer2out_list = nn.ModuleList(
            nn.Linear(self.element_embed_dim, self.out_dim)
            for _ in range(num_layers))
        if self.use_gpu:
            self.cuda()

    def forward(self, prod_rule_idx_seq):
        ''' forward model for mini-batch

        Parameters
        ----------
        prod_rule_idx_seq : (batch_size, length)
            Integer indices into the corpus; the value
            ``len(prod_rule_corpus.prod_rule_list)`` marks padding.

        Returns
        -------
        Tensor, shape (batch_size, length, out_dim)
        '''
        batch_size, length = prod_rule_idx_seq.shape
        out = torch.zeros((batch_size, length, self.out_dim),
                          device=self.atom_embed.device)
        for each_batch_idx in range(batch_size):
            for each_idx in range(length):
                rule_idx = int(prod_rule_idx_seq[each_batch_idx, each_idx])
                if rule_idx == len(self.prod_rule_corpus.prod_rule_list):
                    # padding slot: leave the zero vector
                    continue
                each_prod_rule = self.prod_rule_corpus.prod_rule_list[rule_idx]
                # initial embeddings: one vector per hyperedge (atom symbol)
                # and per node (bond symbol)
                layer_wise_embed_dict = {each_edge: self.atom_embed[
                    each_prod_rule.rhs.edge_attr(each_edge)['symbol_idx']]
                                         for each_edge in each_prod_rule.rhs.edges}
                layer_wise_embed_dict.update({each_node: self.bond_embed[
                    each_prod_rule.rhs.node_attr(each_node)['symbol_idx']]
                                              for each_node in each_prod_rule.rhs.nodes})
                # external nodes receive an extra ext-id embedding
                for each_node in each_prod_rule.rhs.nodes:
                    if 'ext_id' in each_prod_rule.rhs.node_attr(each_node):
                        layer_wise_embed_dict[each_node] \
                            = layer_wise_embed_dict[each_node] \
                            + self.ext_id_embed[each_prod_rule.rhs.node_attr(each_node)['ext_id']]

                # message passing: edges aggregate from incident nodes and
                # vice versa; every round adds a readout into `out`.
                for each_layer in range(self.num_layers):
                    next_layer_embed_dict = {}
                    for each_edge in each_prod_rule.rhs.edges:
                        v = layer_wise_embed_dict[each_edge]
                        for each_node in each_prod_rule.rhs.nodes_in_edge(each_edge):
                            v = v + layer_wise_embed_dict[each_node]
                        next_layer_embed_dict[each_edge] \
                            = self.layer2layer_activation(self.layer2layer_list[each_layer](v))
                        out[each_batch_idx, each_idx, :] \
                            = out[each_batch_idx, each_idx, :] \
                            + self.layer2out_activation(self.layer2out_list[each_layer](v))
                    for each_node in each_prod_rule.rhs.nodes:
                        v = layer_wise_embed_dict[each_node]
                        for each_edge in each_prod_rule.rhs.adj_edges(each_node):
                            v = v + layer_wise_embed_dict[each_edge]
                        next_layer_embed_dict[each_node] \
                            = self.layer2layer_activation(self.layer2layer_list[each_layer](v))
                        out[each_batch_idx, each_idx, :] \
                            = out[each_batch_idx, each_idx, :] \
                            + self.layer2out_activation(self.layer2out_list[each_layer](v))
                    layer_wise_embed_dict = next_layer_embed_dict

        return out
+
+
class MolecularProdRuleEmbeddingLastLayer(nn.Module):

    ''' molecular fingerprint layer

    Variant that reads the output only from the final message-passing
    layer instead of summing a readout after every layer.
    '''

    def __init__(self, prod_rule_corpus, layer2layer_activation, layer2out_activation,
                 out_dim=32, element_embed_dim=32,
                 num_layers=3, padding_idx=None, use_gpu=False):
        super().__init__()
        if padding_idx is not None:
            assert padding_idx == -1, 'padding_idx must be -1.'
        self.prod_rule_corpus = prod_rule_corpus
        self.layer2layer_activation = layer2layer_activation
        self.layer2out_activation = layer2out_activation
        self.out_dim = out_dim
        self.element_embed_dim = element_embed_dim
        self.num_layers = num_layers
        self.padding_idx = padding_idx
        self.use_gpu = use_gpu

        # NOTE(review): plain Python lists — these Linear layers are NOT
        # registered as submodules, so .parameters(), state_dict() and
        # device moves miss them.  Consider nn.ModuleList.
        self.layer2layer_list = []
        self.layer2out_list = []

        if self.use_gpu:
            self.atom_embed = nn.Embedding(self.prod_rule_corpus.num_edge_symbol, self.element_embed_dim).cuda()
            self.bond_embed = nn.Embedding(self.prod_rule_corpus.num_node_symbol, self.element_embed_dim).cuda()
            # num_layers + 1 layers so that index self.num_layers is valid
            # for the final readout in forward().
            for _ in range(num_layers+1):
                self.layer2layer_list.append(nn.Linear(self.element_embed_dim, self.element_embed_dim).cuda())
                self.layer2out_list.append(nn.Linear(self.element_embed_dim, self.out_dim).cuda())
        else:
            self.atom_embed = nn.Embedding(self.prod_rule_corpus.num_edge_symbol, self.element_embed_dim)
            self.bond_embed = nn.Embedding(self.prod_rule_corpus.num_node_symbol, self.element_embed_dim)
            for _ in range(num_layers+1):
                self.layer2layer_list.append(nn.Linear(self.element_embed_dim, self.element_embed_dim))
                self.layer2out_list.append(nn.Linear(self.element_embed_dim, self.out_dim))


    def forward(self, prod_rule_idx_seq):
        ''' forward model for mini-batch

        Parameters
        ----------
        prod_rule_idx_seq : (batch_size, length)

        Returns
        -------
        Variable, shape (batch_size, length, out_dim)
        '''
        batch_size, length = prod_rule_idx_seq.shape
        if self.use_gpu:
            out = Variable(torch.zeros((batch_size, length, self.out_dim))).cuda()
        else:
            out = Variable(torch.zeros((batch_size, length, self.out_dim)))
        for each_batch_idx in range(batch_size):
            for each_idx in range(length):
                # an index equal to the corpus size marks a padding slot
                if int(prod_rule_idx_seq[each_batch_idx, each_idx]) == len(self.prod_rule_corpus.prod_rule_list):
                    continue
                else:
                    each_prod_rule = self.prod_rule_corpus.prod_rule_list[int(prod_rule_idx_seq[each_batch_idx, each_idx])]

                    # initial embeddings: one per hyperedge (atom symbol) and
                    # one per node (bond symbol)
                    if self.use_gpu:
                        layer_wise_embed_dict = {each_edge: self.atom_embed(
                            Variable(torch.LongTensor(
                                [each_prod_rule.rhs.edge_attr(each_edge)['symbol_idx']]
                            ), requires_grad=False).cuda())
                                                 for each_edge in each_prod_rule.rhs.edges}
                        layer_wise_embed_dict.update({each_node: self.bond_embed(
                            Variable(
                                torch.LongTensor([
                                    each_prod_rule.rhs.node_attr(each_node)['symbol_idx']]),
                                requires_grad=False).cuda()
                        ) for each_node in each_prod_rule.rhs.nodes})
                    else:
                        layer_wise_embed_dict = {each_edge: self.atom_embed(
                            Variable(torch.LongTensor(
                                [each_prod_rule.rhs.edge_attr(each_edge)['symbol_idx']]
                            ), requires_grad=False))
                                                 for each_edge in each_prod_rule.rhs.edges}
                        layer_wise_embed_dict.update({each_node: self.bond_embed(
                            Variable(
                                torch.LongTensor([
                                    each_prod_rule.rhs.node_attr(each_node)['symbol_idx']]),
                                requires_grad=False)
                        ) for each_node in each_prod_rule.rhs.nodes})

                    # message passing between edges and their incident nodes
                    for each_layer in range(self.num_layers):
                        next_layer_embed_dict = {}
                        for each_edge in each_prod_rule.rhs.edges:
                            v = layer_wise_embed_dict[each_edge]
                            # NOTE(review): `v +=` mutates the stored tensor
                            # in place; the sibling class uses out-of-place
                            # `v = v + ...` — confirm this is intended.
                            for each_node in each_prod_rule.rhs.nodes_in_edge(each_edge):
                                v += layer_wise_embed_dict[each_node]
                            next_layer_embed_dict[each_edge] = self.layer2layer_activation(self.layer2layer_list[each_layer](v))
                        for each_node in each_prod_rule.rhs.nodes:
                            v = layer_wise_embed_dict[each_node]
                            for each_edge in each_prod_rule.rhs.adj_edges(each_node):
                                v += layer_wise_embed_dict[each_edge]
                            next_layer_embed_dict[each_node] = self.layer2layer_activation(self.layer2layer_list[each_layer](v))
                        layer_wise_embed_dict = next_layer_embed_dict
                    # NOTE(review): the two loops below are byte-identical,
                    # every iteration overwrites the same output slot, and
                    # `v` is whatever was left from the last node iteration
                    # above — this looks like a copy/paste bug (only the
                    # final assignment survives); confirm intended readout.
                    for each_edge in each_prod_rule.rhs.edges:
                        out[each_batch_idx, each_idx, :] = self.layer2out_activation(self.layer2out_list[self.num_layers](v))
                    for each_edge in each_prod_rule.rhs.edges:
                        out[each_batch_idx, each_idx, :] = self.layer2out_activation(self.layer2out_list[self.num_layers](v))

        return out
+
+
class MolecularProdRuleEmbeddingUsingFeatures(nn.Module):

    ''' Molecular fingerprint layer using precomputed feature vectors.

    Like :class:`MolecularProdRuleEmbedding`, but initial node/edge
    embeddings come from the corpus' precomputed feature vectors and
    message passing is a dense adjacency-matrix product.

    Parameters
    ----------
    prod_rule_corpus : ProductionRuleCorpus
        Must provide ``construct_feature_vectors()`` and ``prod_rule_list``.
    layer2layer_activation, layer2out_activation : callable
        Activations for layer-to-layer messages and the readout.
    out_dim : int
        Dimension of the output fingerprint.
    num_layers : int
        Number of message-passing rounds.
    padding_idx : int or None
        If given, must be -1; an index equal to
        ``len(prod_rule_corpus.prod_rule_list)`` is treated as padding.
    use_gpu : bool
        If True, densify the feature vectors and move everything to CUDA.
    '''

    def __init__(self, prod_rule_corpus, layer2layer_activation, layer2out_activation,
                 out_dim=32, num_layers=3, padding_idx=None, use_gpu=False):
        super().__init__()
        if padding_idx is not None:
            assert padding_idx == -1, 'padding_idx must be -1.'
        self.feature_dict, self.feature_dim = prod_rule_corpus.construct_feature_vectors()
        self.prod_rule_corpus = prod_rule_corpus
        self.layer2layer_activation = layer2layer_activation
        self.layer2out_activation = layer2out_activation
        self.out_dim = out_dim
        self.num_layers = num_layers
        self.padding_idx = padding_idx
        self.use_gpu = use_gpu

        # FIX: the linear layers were held in plain Python lists, so they
        # were invisible to the module (no optimizer updates, no state_dict
        # entries, no device moves).  Register them via nn.ModuleList.
        self.layer2layer_list = nn.ModuleList(
            nn.Linear(self.feature_dim, self.feature_dim) for _ in range(num_layers))
        self.layer2out_list = nn.ModuleList(
            nn.Linear(self.feature_dim, self.out_dim) for _ in range(num_layers))
        if self.use_gpu:
            # the feature vectors are plain tensors, not parameters:
            # densify and move them by hand, then move the registered layers
            for each_key in self.feature_dict:
                self.feature_dict[each_key] = self.feature_dict[each_key].to_dense().cuda()
            self.cuda()

    def forward(self, prod_rule_idx_seq):
        ''' forward model for mini-batch

        Parameters
        ----------
        prod_rule_idx_seq : (batch_size, length)
            Integer rule indices; ``len(prod_rule_corpus.prod_rule_list)``
            marks padding.

        Returns
        -------
        Tensor, shape (batch_size, length, out_dim)
        '''
        batch_size, length = prod_rule_idx_seq.shape
        device = next(self.parameters()).device
        out = torch.zeros((batch_size, length, self.out_dim), device=device)
        for each_batch_idx in range(batch_size):
            for each_idx in range(length):
                rule_idx = int(prod_rule_idx_seq[each_batch_idx, each_idx])
                if rule_idx == len(self.prod_rule_corpus.prod_rule_list):
                    # padding slot: leave the zero vector
                    continue
                each_prod_rule = self.prod_rule_corpus.prod_rule_list[rule_idx]
                edge_list = sorted(each_prod_rule.rhs.edges)
                node_list = sorted(each_prod_rule.rhs.nodes)
                # adjacency over the combined [edges; nodes] index space,
                # with self-loops added
                adj_mat = torch.FloatTensor(
                    each_prod_rule.rhs_adj_mat(edge_list + node_list).todense()
                    + np.identity(len(edge_list) + len(node_list)))
                if self.use_gpu:
                    adj_mat = adj_mat.cuda()
                layer_wise_embed = [
                    self.feature_dict[each_prod_rule.rhs.edge_attr(each_edge)['symbol']]
                    for each_edge in edge_list] \
                    + [self.feature_dict[each_prod_rule.rhs.node_attr(each_node)['symbol']]
                       for each_node in node_list]
                # external nodes receive an additional ext-id feature
                for each_node in each_prod_rule.ext_node.values():
                    pos = each_prod_rule.rhs.num_edges + node_list.index(each_node)
                    layer_wise_embed[pos] = layer_wise_embed[pos] \
                        + self.feature_dict[('ext_id', each_prod_rule.rhs.node_attr(each_node)['ext_id'])]
                layer_wise_embed = torch.stack(layer_wise_embed)

                # message passing by dense matrix product; every round adds
                # a pooled readout into the output slot.
                for each_layer in range(self.num_layers):
                    message = adj_mat @ layer_wise_embed
                    next_layer_embed = self.layer2layer_activation(
                        self.layer2layer_list[each_layer](message))
                    out[each_batch_idx, each_idx, :] \
                        = out[each_batch_idx, each_idx, :] \
                        + self.layer2out_activation(self.layer2out_list[each_layer](message)).sum(dim=0)
                    layer_wise_embed = next_layer_embed
        return out
diff --git a/images/mhg_example.png b/images/mhg_example.png
new file mode 100644
index 0000000000000000000000000000000000000000..3a7dd8ce73476fba75ed242e67147946d99740eb
Binary files /dev/null and b/images/mhg_example.png differ
diff --git a/images/mhg_example1.png b/images/mhg_example1.png
new file mode 100644
index 0000000000000000000000000000000000000000..150b71f10580655433a6f59a60cbc2afc07d8dc8
Binary files /dev/null and b/images/mhg_example1.png differ
diff --git a/images/mhg_example2.png b/images/mhg_example2.png
new file mode 100644
index 0000000000000000000000000000000000000000..b00f97a7fb3bec25c0e6e42990d18aaa216eff2d
Binary files /dev/null and b/images/mhg_example2.png differ
diff --git a/load.py b/load.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe3208aa7b5c5efc97691fec0069c20f56f64c15
--- /dev/null
+++ b/load.py
@@ -0,0 +1,83 @@
+# -*- coding:utf-8 -*-
+# Rhizome
+# Version beta 0.0, August 2023
+# Property of IBM Research, Accelerated Discovery
+#
+
+import os
+import pickle
+import sys
+
+from rdkit import Chem
+import torch
+from torch_geometric.utils.smiles import from_smiles
+
+from typing import Any, Dict, List, Optional, Union
+from typing_extensions import Self
+
+from .graph_grammar.io.smi import hg_to_mol
+from .models.mhgvae import GrammarGINVAE
+
+
class PretrainedModelWrapper:
    """Convenience wrapper around a pretrained :class:`GrammarGINVAE`.

    Rebuilds the model from a pickled dictionary (hyper-parameters plus
    ``model_state_dict``), puts it in eval mode, and exposes simple
    ``encode`` (SMILES -> latent vectors) and ``decode`` (latent vectors ->
    SMILES) helpers.
    """
    model: GrammarGINVAE

    def __init__(self, model_dict: Dict[str, Any]) -> None:
        json_params = model_dict['gnn_params']
        encoder_params = json_params['encoder_params']
        encoder_params['node_feature_size'] = model_dict['num_features']
        encoder_params['edge_feature_size'] = model_dict['num_edge_features']
        self.model = GrammarGINVAE(model_dict['hrg'], rank=-1, encoder_params=encoder_params,
                                   decoder_params=json_params['decoder_params'],
                                   prod_rule_embed_params=json_params["prod_rule_embed_params"],
                                   batch_size=512, max_len=model_dict['max_length'])
        self.model.load_state_dict(model_dict['model_state_dict'])
        # inference only: disable dropout/batch-norm updates
        self.model.eval()

    def to(self, device: Union[str, int, torch.device]) -> Self:
        """Move the wrapped model to *device* and return self (fluent)."""
        if not isinstance(device, torch.device):
            if isinstance(device, str) or torch.cuda.is_available():
                device = torch.device(device)
            else:
                # integer index without CUDA: assume Apple MPS
                device = torch.device("mps", device)
        self.model = self.model.to(device)
        return self

    def encode(self, data: List[str]) -> List[torch.Tensor]:
        """Encode a list of SMILES strings into latent vectors."""
        # FIX: the old device check compared a Parameter tensor against the
        # string 'cpu' and used `g.cpu()` (which returns a copy) as a
        # boolean, then dropped the result of `g.to(...)`.  Just move each
        # graph to the model's device unconditionally.
        device = next(self.model.parameters()).device
        output = []
        for smiles in data:
            g = from_smiles(smiles).to(device)
            ltvec = self.model.graph_embed(g.x, g.edge_index, g.edge_attr, g.batch)
            output.append(ltvec[0])
        return output

    def decode(self, data: List[torch.Tensor]) -> List[Optional[str]]:
        """Decode latent vectors into SMILES; None for failed decodes."""
        output = []
        for latent in data:
            mu, logvar = self.model.get_mean_var(latent.unsqueeze(0))
            z = self.model.reparameterize(mu, logvar)
            flags, _, hgs = self.model.decode(z)
            if flags[0]:
                reconstructed_mol, _ = hg_to_mol(hgs[0], True)
                output.append(Chem.MolToSmiles(reconstructed_mol))
            else:
                output.append(None)
        return output
+
+
def load(model_name: str = "models/mhg_model/pickles/mhggnn_pretrained_model_0724_2023.pickle") -> Optional[
        PretrainedModelWrapper]:
    """Search ``sys.path`` for *model_name* and unpickle it.

    Returns a :class:`PretrainedModelWrapper` built from the first matching
    file, or ``None`` when the file is not found on any path entry.
    """
    for search_dir in sys.path:
        candidate = search_dir + "/" + model_name
        if not os.path.isfile(candidate):
            continue
        with open(candidate, "rb") as fh:
            return PretrainedModelWrapper(pickle.load(fh))
    return None
diff --git a/mhg_gnn.egg-info/PKG-INFO b/mhg_gnn.egg-info/PKG-INFO
new file mode 100644
index 0000000000000000000000000000000000000000..205fc9ded7ff7f8a67d589d1bc0fa997a77a067c
--- /dev/null
+++ b/mhg_gnn.egg-info/PKG-INFO
@@ -0,0 +1,102 @@
+Metadata-Version: 2.1
+Name: mhg-gnn
+Version: 0.0
+Summary: Package for mhg-gnn
+Author: team
+License: TBD
+Classifier: Programming Language :: Python :: 3
+Classifier: Programming Language :: Python :: 3.9
+Description-Content-Type: text/markdown
+Requires-Dist: networkx>=2.8
+Requires-Dist: numpy<2.0.0,>=1.23.5
+Requires-Dist: pandas>=1.5.3
+Requires-Dist: rdkit-pypi<2023.9.6,>=2022.9.4
+Requires-Dist: torch>=2.0.0
+Requires-Dist: torchinfo>=1.8.0
+Requires-Dist: torch-geometric>=2.3.1
+
+# mhg-gnn
+
+This repository provides PyTorch source code associated with our publication, "MHG-GNN: Combination of Molecular Hypergraph Grammar with Graph Neural Network"
+
+**Paper:** [Arxiv Link](https://arxiv.org/pdf/2309.16374)
+
+For more information contact: SEIJITKD@jp.ibm.com
+
+![mhg-gnn](images/mhg_example1.png)
+
+## Introduction
+
+We present MHG-GNN, an autoencoder architecture
+that has an encoder based on GNN and a decoder based on a sequential model with MHG.
+Since the encoder is a GNN variant, MHG-GNN can accept any molecule as input, and
+demonstrate high predictive performance on molecular graph data.
+In addition, the decoder inherits the theoretical guarantee of MHG on always generating a structurally valid molecule as output.
+
+## Table of Contents
+
+1. [Getting Started](#getting-started)
+ 1. [Pretrained Models and Training Logs](#pretrained-models-and-training-logs)
+ 2. [Replicating Conda Environment](#replicating-conda-environment)
+2. [Feature Extraction](#feature-extraction)
+
+## Getting Started
+
+**This code and environment have been tested on Intel E5-2667 CPUs at 3.30GHz and NVIDIA A100 Tensor Core GPUs.**
+
+### Pretrained Models and Training Logs
+
+We provide checkpoints of the MHG-GNN model pre-trained on a dataset of ~1.34M molecules curated from PubChem. (later) For model weights: [HuggingFace Link]()
+
+Add the MHG-GNN `pre-trained weights.pt` to the `models/` directory according to your needs.
+
+### Replicating Conda Environment
+
+Follow these steps to replicate our Conda environment and install the necessary libraries:
+
+```
+conda create --name mhg-gnn-env python=3.8.18
+conda activate mhg-gnn-env
+```
+
+#### Install Packages with Conda
+
+```
+conda install -c conda-forge networkx=2.8
+conda install numpy=1.23.5
+# conda install -c conda-forge rdkit=2022.9.4
+conda install pytorch=2.0.0 torchvision torchaudio -c pytorch
+conda install -c conda-forge torchinfo=1.8.0
+conda install pyg -c pyg
+```
+
+#### Install Packages with pip
+```
+pip install rdkit torch-nl==0.3 torch-scatter torch-sparse
+```
+
+## Feature Extraction
+
+The example notebook [mhg-gnn_encoder_decoder_example.ipynb](notebooks/mhg-gnn_encoder_decoder_example.ipynb) contains code to load checkpoint files and use the pre-trained model for encoder and decoder tasks.
+
+To load mhg-gnn, you can simply use:
+
+```python
+import torch
+import load
+
+model = load.load()
+```
+
+To encode SMILES into embeddings, you can use:
+
+```python
+with torch.no_grad():
+ repr = model.encode(["CCO", "O=C=O", "OC(=O)c1ccccc1C(=O)O"])
+```
+
+To decode embeddings back into SMILES strings, you can use:
+
+```python
+orig = model.decode(repr)
+```
diff --git a/mhg_gnn.egg-info/SOURCES.txt b/mhg_gnn.egg-info/SOURCES.txt
new file mode 100644
index 0000000000000000000000000000000000000000..f6429c60d226eadd5f9fce9ba11de93451412c34
--- /dev/null
+++ b/mhg_gnn.egg-info/SOURCES.txt
@@ -0,0 +1,46 @@
+README.md
+setup.cfg
+setup.py
+./graph_grammar/__init__.py
+./graph_grammar/hypergraph.py
+./graph_grammar/algo/__init__.py
+./graph_grammar/algo/tree_decomposition.py
+./graph_grammar/graph_grammar/__init__.py
+./graph_grammar/graph_grammar/base.py
+./graph_grammar/graph_grammar/corpus.py
+./graph_grammar/graph_grammar/hrg.py
+./graph_grammar/graph_grammar/symbols.py
+./graph_grammar/graph_grammar/utils.py
+./graph_grammar/io/__init__.py
+./graph_grammar/io/smi.py
+./graph_grammar/nn/__init__.py
+./graph_grammar/nn/dataset.py
+./graph_grammar/nn/decoder.py
+./graph_grammar/nn/encoder.py
+./graph_grammar/nn/graph.py
+./models/__init__.py
+./models/mhgvae.py
+graph_grammar/__init__.py
+graph_grammar/hypergraph.py
+graph_grammar/algo/__init__.py
+graph_grammar/algo/tree_decomposition.py
+graph_grammar/graph_grammar/__init__.py
+graph_grammar/graph_grammar/base.py
+graph_grammar/graph_grammar/corpus.py
+graph_grammar/graph_grammar/hrg.py
+graph_grammar/graph_grammar/symbols.py
+graph_grammar/graph_grammar/utils.py
+graph_grammar/io/__init__.py
+graph_grammar/io/smi.py
+graph_grammar/nn/__init__.py
+graph_grammar/nn/dataset.py
+graph_grammar/nn/decoder.py
+graph_grammar/nn/encoder.py
+graph_grammar/nn/graph.py
+mhg_gnn.egg-info/PKG-INFO
+mhg_gnn.egg-info/SOURCES.txt
+mhg_gnn.egg-info/dependency_links.txt
+mhg_gnn.egg-info/requires.txt
+mhg_gnn.egg-info/top_level.txt
+models/__init__.py
+models/mhgvae.py
\ No newline at end of file
diff --git a/mhg_gnn.egg-info/dependency_links.txt b/mhg_gnn.egg-info/dependency_links.txt
new file mode 100644
index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc
--- /dev/null
+++ b/mhg_gnn.egg-info/dependency_links.txt
@@ -0,0 +1 @@
+
diff --git a/mhg_gnn.egg-info/requires.txt b/mhg_gnn.egg-info/requires.txt
new file mode 100644
index 0000000000000000000000000000000000000000..54aa1b371905a3d46c6cbc15741cfb9b8a376c7d
--- /dev/null
+++ b/mhg_gnn.egg-info/requires.txt
@@ -0,0 +1,7 @@
+networkx>=2.8
+numpy<2.0.0,>=1.23.5
+pandas>=1.5.3
+rdkit-pypi<2023.9.6,>=2022.9.4
+torch>=2.0.0
+torchinfo>=1.8.0
+torch-geometric>=2.3.1
diff --git a/mhg_gnn.egg-info/top_level.txt b/mhg_gnn.egg-info/top_level.txt
new file mode 100644
index 0000000000000000000000000000000000000000..d606741958961ccbe1f156f690406bcd74658ad4
--- /dev/null
+++ b/mhg_gnn.egg-info/top_level.txt
@@ -0,0 +1,2 @@
+graph_grammar
+models
diff --git a/models/__init__.py b/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1c71cf8133e8cd3edbcb23b4ecf2cd326ec0316
--- /dev/null
+++ b/models/__init__.py
@@ -0,0 +1,5 @@
+# -*- coding:utf-8 -*-
+# Rhizome
+# Version beta 0.0, August 2023
+# Property of IBM Research, Accelerated Discovery
+#
diff --git a/models/__pycache__/__init__.cpython-310.pyc b/models/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ecceb80a0ac40826b1514c4006fb28c2089b66bf
Binary files /dev/null and b/models/__pycache__/__init__.cpython-310.pyc differ
diff --git a/models/__pycache__/mhgvae.cpython-310.pyc b/models/__pycache__/mhgvae.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1de425d10bca4341388de64d4d0081f16fe663ea
Binary files /dev/null and b/models/__pycache__/mhgvae.cpython-310.pyc differ
diff --git a/models/mhgvae.py b/models/mhgvae.py
new file mode 100644
index 0000000000000000000000000000000000000000..829a2c5567e7ffb2b61842840b83d428c2e2cbe0
--- /dev/null
+++ b/models/mhgvae.py
@@ -0,0 +1,956 @@
+# -*- coding:utf-8 -*-
+# Rhizome
+# Version beta 0.0, August 2023
+# Property of IBM Research, Accelerated Discovery
+#
+
+"""
+PLEASE NOTE THIS IMPLEMENTATION INCLUDES ADAPTED SOURCE CODE
+OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE,
+E.G., GRUEncoder/GRUDecoder, GrammarSeq2SeqVAE AND EVEN SOME METHODS OF GrammarGINVAE.
+THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
+"""
+
+import numpy as np
+import logging
+
+import torch
+from torch.autograd import Variable
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.modules.loss import _Loss
+
+from torch_geometric.nn import MessagePassing
+from torch_geometric.nn import global_add_pool
+
+
+from ..graph_grammar.graph_grammar.symbols import NTSymbol
+from ..graph_grammar.nn.encoder import EncoderBase
+from ..graph_grammar.nn.decoder import DecoderBase
+
def get_atom_edge_feature_dims():
    """Return the cardinality of each atom and bond feature.

    Reads torch_geometric's SMILES feature maps and returns a pair of
    lists: the number of possible values for every node (atom) feature and
    for every edge (bond) feature, in map order.
    """
    from torch_geometric.utils.smiles import x_map, e_map
    atom_dims = [len(choices) for choices in x_map.values()]
    edge_dims = [len(choices) for choices in e_map.values()]
    return atom_dims, edge_dims
+
+
class FeatureEmbedding(nn.Module):
    """Embed multi-column categorical features and sum the embeddings.

    One ``nn.Embedding`` table is created per input column; the forward
    pass adds the per-column embeddings into a single vector per row.

    Parameters
    ----------
    input_dims : sequence of int
        Cardinality of each input column (one embedding table per entry).
    embedded_dim : int
        Common embedding dimension.
    """

    def __init__(self, input_dims, embedded_dim):
        super().__init__()
        self.embedding_list = nn.ModuleList()
        for dim in input_dims:
            self.embedding_list.append(nn.Embedding(dim, embedded_dim))

    def forward(self, x):
        """Sum the column-wise embeddings of x, shape (rows, n_columns)."""
        # Hoist the device lookup out of the per-column loop (it was
        # recomputed on every iteration before).
        device = next(self.parameters()).device
        output = 0
        for i in range(x.shape[1]):
            column = x[:, i].to(torch.int)
            if column.device != device:
                column = column.to(device)
            output = output + self.embedding_list[i](column)
        return output
+
+
class GRUEncoder(EncoderBase):
    ''' GRU-based sequence encoder.

    Wraps an (optionally bidirectional) ``nn.GRU`` and keeps its hidden
    state in ``self.h0``.  ``rank`` selects a device index — CUDA when
    available, otherwise Apple MPS; ``rank < 0`` keeps everything on CPU.
    '''

    def __init__(self, input_dim: int, hidden_dim: int, num_layers: int,
                 bidirectional: bool, dropout: float, batch_size: int, rank: int=-1,
                 no_dropout: bool=False):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.dropout = dropout
        self.batch_size = batch_size
        self.rank = rank
        self.model = nn.GRU(input_size=self.input_dim,
                            hidden_size=self.hidden_dim,
                            num_layers=self.num_layers,
                            batch_first=True,
                            bidirectional=self.bidirectional,
                            dropout=self.dropout if not no_dropout else 0)
        if self.rank >= 0:
            if torch.cuda.is_available():
                self.model = self.model.to(rank)
            else:
                # support mac mps
                self.model = self.model.to(torch.device("mps", rank))
        self.init_hidden(self.batch_size)

    def init_hidden(self, bsize):
        # Zero the initial hidden state.  The batch dimension is capped at
        # the configured batch size so a smaller final batch still works.
        # NOTE(review): a bsize larger than self.batch_size would produce a
        # too-small h0 — confirm callers never exceed the configured size.
        self.h0 = torch.zeros(((self.bidirectional + 1) * self.num_layers,
                               min(self.batch_size, bsize),
                               self.hidden_dim),
                              requires_grad=False)
        if self.rank >= 0:
            if torch.cuda.is_available():
                self.h0 = self.h0.to(self.rank)
            else:
                # support mac mps
                self.h0 = self.h0.to(torch.device("mps", self.rank))

    def to(self, device):
        # Override to keep the GRU, the hidden state, and `rank` in sync
        # with the requested device.
        newself = super().to(device)
        newself.model = newself.model.to(device)
        newself.h0 = newself.h0.to(device)
        # get_device() yields -1 for CPU tensors, i.e. "no accelerator"
        newself.rank = next(newself.parameters()).get_device()
        return newself

    def forward(self, in_seq_emb):
        ''' forward model

        Parameters
        ----------
        in_seq_emb : Tensor, shape (batch_size, max_len, input_dim)

        Returns
        -------
        hidden_seq_emb : Tensor, shape (batch_size, max_len, 1 + bidirectional, hidden_dim)
        '''
        # Kishi: I think original MHG had this init_hidden()
        self.init_hidden(in_seq_emb.size(0))
        max_len = in_seq_emb.size(1)
        hidden_seq_emb, self.h0 = self.model(
            in_seq_emb, self.h0)
        # Reshape (batch, seq_len, (1 or 2) * hidden) into
        # (batch, seq_len, 1 or 2, hidden): a bidirectional GRU concatenates
        # the forward-RNN features (first half) and backward-RNN features
        # (second half), so unpack them into their own axis.
        hidden_seq_emb = hidden_seq_emb.view(-1,
                                             max_len,
                                             1 + self.bidirectional,
                                             self.hidden_dim)
        return hidden_seq_emb
+
+
class GRUDecoder(DecoderBase):
    ''' GRU-based one-step sequence decoder.

    Unidirectional ``nn.GRU`` whose hidden state lives in
    ``self.hidden_dict['h']`` between ``forward_one_step`` calls.
    NOTE(review): ``self.hidden_dict`` is presumably created by
    ``DecoderBase`` — confirm against the base class.
    '''

    def __init__(self, input_dim: int, hidden_dim: int, num_layers: int,
                 dropout: float, batch_size: int, rank: int=-1,
                 no_dropout: bool=False):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.dropout = dropout
        self.batch_size = batch_size
        self.rank = rank
        self.model = nn.GRU(input_size=self.input_dim,
                            hidden_size=self.hidden_dim,
                            num_layers=self.num_layers,
                            batch_first=True,
                            bidirectional=False,
                            dropout=self.dropout if not no_dropout else 0
                            )
        if self.rank >= 0:
            if torch.cuda.is_available():
                self.model = self.model.to(self.rank)
            else:
                # support mac mps
                self.model = self.model.to(torch.device("mps", self.rank))
        self.init_hidden(self.batch_size)

    def init_hidden(self, bsize):
        # Zero the hidden state; the batch dimension is capped at the
        # configured batch size so a smaller final batch still works.
        self.hidden_dict['h'] = torch.zeros((self.num_layers,
                                             min(self.batch_size, bsize),
                                             self.hidden_dim),
                                            requires_grad=False)
        if self.rank >= 0:
            if torch.cuda.is_available():
                self.hidden_dict['h'] = self.hidden_dict['h'].to(self.rank)
            else:
                # support mac mps
                self.hidden_dict['h'] = self.hidden_dict['h'].to(torch.device("mps", self.rank))

    def to(self, device):
        # Override to keep the GRU, all hidden-state tensors, and `rank`
        # in sync with the requested device.
        newself = super().to(device)
        newself.model = newself.model.to(device)
        for k in self.hidden_dict.keys():
            newself.hidden_dict[k] = newself.hidden_dict[k].to(device)
        # get_device() yields -1 for CPU tensors, i.e. "no accelerator"
        newself.rank = next(newself.parameters()).get_device()
        return newself

    def forward_one_step(self, tgt_emb_in):
        ''' one-step forward model

        Feeds a single time step through the GRU, updating the stored
        hidden state in place.

        Parameters
        ----------
        tgt_emb_in : Tensor, shape (batch_size, input_dim)

        Returns
        -------
        Tensor, shape (batch_size, hidden_dim)
        '''
        bsize = tgt_emb_in.size(0)
        tgt_emb_out, self.hidden_dict['h'] \
            = self.model(tgt_emb_in.view(bsize, 1, -1),
                         self.hidden_dict['h'])
        return tgt_emb_out
+
+
class NodeMLP(nn.Module):
    """Two-layer MLP for node features: Linear -> BatchNorm -> ReLU -> Linear."""

    def __init__(self, input_size, output_size, hidden_size):
        super().__init__()
        self.lin1 = nn.Linear(input_size, hidden_size)
        self.nbat = nn.BatchNorm1d(hidden_size)
        self.lin2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        hidden = self.nbat(self.lin1(x)).relu()
        return self.lin2(hidden)
+
+
class GINLayer(MessagePassing):
    """One GIN-style message-passing layer with embedded edge features.

    Messages are the neighbour features plus an embedding of the edge
    attributes (ReLU-ed); the aggregate is combined with the central node
    via the learnable ``(1 + eps)`` weighting and passed through a small
    MLP.
    """

    def __init__(self, node_input_size, node_output_size, node_hidden_size, edge_input_size):
        super().__init__()
        self.node_mlp = NodeMLP(node_input_size, node_output_size, node_hidden_size)
        self.edge_mlp = FeatureEmbedding(edge_input_size, node_output_size)
        self.eps = nn.Parameter(torch.tensor([0.0]))

    def forward(self, x, edge_index, edge_attr):
        aggregated = self.propagate(edge_index, x=x, edge_attr=edge_attr)
        combined = (1.0 + self.eps) * x + aggregated
        return self.node_mlp(combined.relu())

    def message(self, x_j, edge_attr):
        return (x_j + self.edge_mlp(edge_attr)).relu()

    def update(self, aggr_out):
        # identity: aggregation output is used as-is
        return aggr_out
+
+#TODO implement the case where features of atoms and edges are considered
+# Check GraphMVP and ogb (open graph benchmark) to realize this
+class GIN(torch.nn.Module):
+ def __init__(self, node_feature_size, edge_feature_size, hidden_channels=64,
+ proximity_size=3, dropout=0.1):
+ super().__init__()
+ #print("(num node features, num edge features)=", (node_feature_size, edge_feature_size))
+ hsize = hidden_channels * 2
+ atom_dim, edge_dim = get_atom_edge_feature_dims()
+ self.trans = FeatureEmbedding(atom_dim, hidden_channels)
+ ml = []
+ for _ in range(proximity_size):
+ ml.append(GINLayer(hidden_channels, hidden_channels, hsize, edge_dim))
+ self.mlist = nn.ModuleList(ml)
+ #It is possible to calculate relu with x.relu() where x is an output
+ #self.activations = nn.ModuleList(actl)
+ self.dropout = dropout
+ self.proximity_size = proximity_size
+
+ def forward(self, x, edge_index, edge_attr, batch_size):
+ x = x.to(torch.float)
+ #print("before: edge_weight.shape=", edge_attr.shape)
+ edge_attr = edge_attr.to(torch.float)
+ #print("after: edge_weight.shape=", edge_attr.shape)
+ x = self.trans(x)
+ # TODO Check if this x is consistent with global_add_pool
+ hlist = [global_add_pool(x, batch_size)]
+ for id, m in enumerate(self.mlist):
+ x = m(x, edge_index=edge_index, edge_attr=edge_attr)
+ #print("Done with one layer")
+ ###if id != self.proximity_size - 1:
+ x = x.relu()
+ x = F.dropout(x, p=self.dropout, training=self.training)
+ #h = global_mean_pool(x, batch_size)
+ h = global_add_pool(x, batch_size)
+ hlist.append(h)
+ #print("Done with one relu call: x.shape=", x.shape)
+ #print("calling golbal mean pool")
+ #print("calling dropout x.shape=", x.shape)
+ #print("x=", x)
+ #print("hlist[0].shape=", hlist[0].shape)
+ x = torch.cat(hlist, dim=1)
+ #print("x.shape=", x.shape)
+ x = F.dropout(x, p=self.dropout, training=self.training)
+
+ return x
+
+
# TODO copied from MHG implementation and adapted here.
class GrammarSeq2SeqVAE(nn.Module):

    '''
    Variational seq2seq autoencoder over grammar production-rule sequences.

    A GRU encoder maps a sequence of production-rule indices to a Gaussian
    posterior (mu, logvar); a GRU decoder maps a latent sample back to rule
    logits and, in generation mode, applies the sampled rules to grow
    hypergraphs.

    TODO: rewrite this class using mixin
    '''

    def __init__(self, hrg, rank=-1, latent_dim=64, max_len=80,
                 batch_size=64, padding_idx=-1,
                 encoder_params=None,
                 decoder_params=None,
                 prod_rule_embed_params=None,
                 no_dropout=False):
        '''
        Parameters
        ----------
        hrg : hyperedge-replacement grammar providing the rule vocabulary.
        rank : device rank; negative means CPU.
        latent_dim : size of the latent vector.
        max_len : maximum decoded sequence length.
        padding_idx : padding rule index (taken modulo vocab_size).
        encoder_params / decoder_params / prod_rule_embed_params : dicts of
            sub-module hyper-parameters; None selects the documented defaults.
        '''
        super().__init__()
        # BUGFIX(style): use None sentinels instead of mutable default dicts
        # so defaults are never shared across instances.
        if encoder_params is None:
            encoder_params = {'hidden_dim': 384, 'num_layers': 3,
                              'bidirectional': True, 'dropout': 0.1}
        if decoder_params is None:
            decoder_params = {'hidden_dim': 384, 'num_layers': 3, 'dropout': 0.1}
        if prod_rule_embed_params is None:
            prod_rule_embed_params = {'out_dim': 128}

        # TODO USE GRU FOR ENCODING AND DECODING
        self.hrg = hrg
        self.rank = rank
        self.prod_rule_corpus = hrg.prod_rule_corpus
        self.prod_rule_embed_params = prod_rule_embed_params

        # +1 reserves the last vocabulary slot for padding.
        self.vocab_size = hrg.num_prod_rule + 1
        self.batch_size = batch_size
        self.padding_idx = np.mod(padding_idx, self.vocab_size)
        self.no_dropout = no_dropout

        self.latent_dim = latent_dim
        self.max_len = max_len
        self.encoder_params = encoder_params
        self.decoder_params = decoder_params

        # TODO Simple embedding is used. Check if a domain-dependent embedding works or not.
        embed_out_dim = self.prod_rule_embed_params['out_dim']
        #use MolecularProdRuleEmbedding later on
        self.src_embedding = nn.Embedding(self.vocab_size, embed_out_dim,
                                          padding_idx=self.padding_idx)
        self.tgt_embedding = nn.Embedding(self.vocab_size, embed_out_dim,
                                          padding_idx=self.padding_idx)

        # USE a GRU-based encoder in MHG
        self.encoder = GRUEncoder(input_dim=embed_out_dim, batch_size=self.batch_size,
                                  rank=self.rank, no_dropout=self.no_dropout,
                                  **self.encoder_params)

        # Bidirectional GRUs expose twice the hidden width.
        lin_dim = (self.encoder_params.get('bidirectional', False) + 1) * self.encoder_params['hidden_dim']
        self.hidden2mean = nn.Linear(lin_dim, self.latent_dim, bias=False)
        self.hidden2logvar = nn.Linear(lin_dim, self.latent_dim)

        # USE a GRU-based decoder in MHG
        self.decoder = GRUDecoder(input_dim=embed_out_dim, batch_size=self.batch_size,
                                  rank=self.rank, no_dropout=self.no_dropout, **self.decoder_params)
        self.latent2tgt_emb = nn.Linear(self.latent_dim, embed_out_dim)
        self.latent2hidden_dict = nn.ModuleDict()
        dec_lin_out_dim = self.decoder_params['hidden_dim']
        for each_hidden in self.decoder.hidden_dict.keys():
            self.latent2hidden_dict[each_hidden] = nn.Linear(self.latent_dim, dec_lin_out_dim)
            if self.rank >= 0:
                if torch.cuda.is_available():
                    self.latent2hidden_dict[each_hidden] = self.latent2hidden_dict[each_hidden].to(self.rank)
                else:
                    # support mac mps
                    self.latent2hidden_dict[each_hidden] = self.latent2hidden_dict[each_hidden].to(torch.device("mps", self.rank))

        self.dec2vocab = nn.Linear(dec_lin_out_dim, self.vocab_size)

        # TODO Do we need this?
        if hasattr(self.src_embedding, 'weight'):
            self.src_embedding.weight.data.uniform_(-0.1, 0.1)
        if hasattr(self.tgt_embedding, 'weight'):
            self.tgt_embedding.weight.data.uniform_(-0.1, 0.1)

        # BUGFIX: the original initialized the hidden states twice; once is enough.
        self.encoder.init_hidden(self.batch_size)
        self.decoder.init_hidden(self.batch_size)

    def to(self, device):
        newself = super().to(device)
        newself.src_embedding = newself.src_embedding.to(device)
        newself.tgt_embedding = newself.tgt_embedding.to(device)
        newself.encoder = newself.encoder.to(device)
        newself.decoder = newself.decoder.to(device)
        newself.dec2vocab = newself.dec2vocab.to(device)
        newself.hidden2mean = newself.hidden2mean.to(device)
        newself.hidden2logvar = newself.hidden2logvar.to(device)
        newself.latent2tgt_emb = newself.latent2tgt_emb.to(device)
        newself.latent2hidden_dict = newself.latent2hidden_dict.to(device)
        return newself

    def forward(self, in_seq, out_seq):
        ''' forward model

        Parameters
        ----------
        in_seq : torch.Tensor, shape (batch_size, length)
            each element corresponds to word index,
            where the index should be less than `vocab_size`

        Returns
        -------
        (logit, mu, logvar) where logit has shape
        (batch_size, length, vocab_size); applying softmax yields the
        probability of each word.
        '''
        mu, logvar = self.encode(in_seq)
        z = self.reparameterize(mu, logvar)
        return self.decode(z, out_seq), mu, logvar

    def encode(self, in_seq):
        '''Return (mu, logvar) of the approximate posterior for in_seq.'''
        src_emb = self.src_embedding(in_seq)
        src_h = self.encoder.forward(src_emb)
        if self.encoder_params.get('bidirectional', False):
            # Last step of the forward direction + first step of the backward
            # direction (assumes GRUEncoder returns (batch, time, dir, hid) —
            # TODO confirm against GRUEncoder).
            concat_src_h = torch.cat((src_h[:, -1, 0, :], src_h[:, 0, 1, :]), dim=1)
            return self.hidden2mean(concat_src_h), self.hidden2logvar(concat_src_h)
        return self.hidden2mean(src_h[:, -1, :]), self.hidden2logvar(src_h[:, -1, :])

    def reparameterize(self, mu, logvar, training=True):
        '''Sample z ~ N(mu, exp(logvar)) with the reparameterization trick.'''
        if not training:
            return mu
        std = logvar.mul(0.5).exp_()
        # BUGFIX: the original compared a torch.device with get_device()'s int
        # and discarded the result of eps.to(device) (Tensor.to is not
        # in-place). randn_like creates eps on std's device/dtype directly.
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu)

    #TODO Not tested. Need to implement this in case of molecular structure generation
    def sample(self, sample_size=-1, deterministic=True, return_z=False):
        '''Sample hypergraphs from the prior N(0, I).'''
        self.eval()
        if sample_size == -1:
            sample_size = self.batch_size
        device = next(self.parameters()).device

        num_iter = int(np.ceil(sample_size / self.batch_size))
        hg_list = []
        z_list = []
        for _ in range(num_iter):
            # BUGFIX: draw z with shape (batch, latent) on the model's device;
            # the original passed a flat (batch*latent,) std tensor to
            # torch.normal and unconditionally called .cuda().
            z = torch.randn(self.batch_size, self.latent_dim, device=device)
            # BUGFIX: decode returns (gen_finish_list, vocab_logit, hg_list);
            # the original unpacked it into two names.
            _, _, each_hg_list = self.decode(z, deterministic=deterministic)
            z_list.append(z)
            hg_list += each_hg_list
        z = torch.cat(z_list)[:sample_size]
        hg_list = hg_list[:sample_size]
        if return_z:
            return hg_list, z.cpu().detach().numpy()
        return hg_list

    def decode(self, z=None, out_seq=None, deterministic=True):
        '''Teacher-forced decoding when out_seq is given (returns logits);
        otherwise free-running generation (returns
        (gen_finish_list, vocab_logit, hg_list)).'''
        device = next(self.parameters()).device
        if z is None:
            # BUGFIX: mean/std shapes must match; sample directly from N(0, I).
            z = torch.randn(self.batch_size, self.latent_dim)
        if self.rank >= 0:
            z = z.to(device)

        # Seed every decoder hidden state from the latent vector.
        hidden_dict_0 = {}
        for each_hidden in self.latent2hidden_dict.keys():
            hidden_dict_0[each_hidden] = self.latent2hidden_dict[each_hidden](z)
        bsize = z.size(0)
        self.decoder.init_hidden(bsize)
        self.decoder.feed_hidden(hidden_dict_0)

        if out_seq is not None:
            # Teacher forcing: shift the target embeddings right by one and
            # prepend a latent-derived start token.
            tgt_emb0 = self.latent2tgt_emb(z)
            tgt_emb0 = tgt_emb0.view(tgt_emb0.shape[0], 1, tgt_emb0.shape[1])
            out_seq_emb = self.tgt_embedding(out_seq)
            tgt_emb = torch.cat((tgt_emb0, out_seq_emb), dim=1)[:, :-1, :]
            tgt_emb_pred_list = []
            for each_idx in range(self.max_len):
                tgt_emb_pred = self.decoder.forward_one_step(tgt_emb[:, each_idx, :].view(bsize, 1, -1))
                tgt_emb_pred_list.append(tgt_emb_pred)
            return self.dec2vocab(torch.cat(tgt_emb_pred_list, dim=1))

        with torch.no_grad():
            tgt_emb = self.latent2tgt_emb(z)
            tgt_emb = tgt_emb.view(tgt_emb.shape[0], 1, tgt_emb.shape[1])
            tgt_emb_pred_list = []
            # Per-example generation state: stack of open non-terminals,
            # partially built hypergraph, current non-terminal, done flag.
            stack_list = [[] for _ in range(bsize)]
            hg_list = [None] * bsize
            nt_symbol_list = [NTSymbol(degree=0,
                                       is_aromatic=False,
                                       bond_symbol_list=[])
                              for _ in range(bsize)]
            nt_edge_list = [None] * bsize
            gen_finish_list = [False] * bsize

            for _ in range(self.max_len):
                tgt_emb_pred = self.decoder.forward_one_step(tgt_emb)
                tgt_emb_pred_list.append(tgt_emb_pred)
                vocab_logit = self.dec2vocab(tgt_emb_pred)
                for each_batch_idx in range(bsize):
                    if not gen_finish_list[each_batch_idx]:  # generation not finished
                        # sample a production rule compatible with the current
                        # non-terminal (greedy when deterministic=True);
                        # the padding slot (last logit) is masked off.
                        prod_rule = self.hrg.prod_rule_corpus.sample(
                            vocab_logit[each_batch_idx, :, :-1].squeeze().cpu().numpy(),
                            nt_symbol_list[each_batch_idx],
                            deterministic=deterministic)
                        tgt_id = self.hrg.prod_rule_list.index(prod_rule)
                        # apply the production rule
                        hg_list[each_batch_idx], nt_edges = prod_rule.applied_to(
                            hg_list[each_batch_idx], nt_edge_list[each_batch_idx])
                        # push new non-terminals (rightmost first)
                        stack_list[each_batch_idx].extend(nt_edges[::-1])
                        if len(stack_list[each_batch_idx]) == 0:
                            # empty stack: generation has finished
                            gen_finish_list[each_batch_idx] = True
                        else:
                            nt_edge_list[each_batch_idx] = stack_list[each_batch_idx].pop()
                            nt_symbol_list[each_batch_idx] = hg_list[each_batch_idx].edge_attr(
                                nt_edge_list[each_batch_idx])['symbol']
                    else:
                        tgt_id = np.mod(self.padding_idx, self.vocab_size)
                    indice_tensor = torch.LongTensor([tgt_id]).to(device)
                    # feed the chosen rule back as the next decoder input
                    tgt_emb[each_batch_idx, :] = self.tgt_embedding(indice_tensor)
            vocab_logit = self.dec2vocab(torch.cat(tgt_emb_pred_list, dim=1))
            return gen_finish_list, vocab_logit, hg_list
+
+
# TODO A lot of duplicates with GrammarVAE. Clean up it if necessary
class GrammarGINVAE(nn.Module):

    '''
    Variational autoencoder based on GIN and grammar.

    A GIN graph encoder maps a molecular graph to a Gaussian posterior
    (mu, logvar); a GRU decoder maps a latent sample to a sequence of
    hyperedge-replacement production rules and, in generation mode, applies
    them to grow hypergraphs.
    '''

    def __init__(self, hrg, rank=-1, max_len=80,
                 batch_size=64, padding_idx=-1,
                 encoder_params=None,
                 decoder_params=None,
                 prod_rule_embed_params=None,
                 no_dropout=False):
        '''
        Parameters
        ----------
        hrg : hyperedge-replacement grammar providing the rule vocabulary.
        rank : device rank; negative means CPU.
        max_len : maximum decoded sequence length.
        padding_idx : padding rule index (taken modulo vocab_size).
        encoder_params / decoder_params / prod_rule_embed_params : dicts of
            sub-module hyper-parameters; None selects the documented defaults.
        '''
        super().__init__()
        # BUGFIX(style): use None sentinels instead of mutable default dicts
        # so defaults are never shared across instances.
        if encoder_params is None:
            encoder_params = {'node_feature_size': 4, 'edge_feature_size': 3,
                              'hidden_channels': 64, 'proximity_size': 3,
                              'dropout': 0.1}
        if decoder_params is None:
            decoder_params = {'hidden_dim': 384, 'num_layers': 3, 'dropout': 0.1}
        if prod_rule_embed_params is None:
            prod_rule_embed_params = {'out_dim': 128}

        # TODO USE GRU FOR ENCODING AND DECODING
        self.hrg = hrg
        self.rank = rank
        self.prod_rule_corpus = hrg.prod_rule_corpus
        self.prod_rule_embed_params = prod_rule_embed_params

        # +1 reserves the last vocabulary slot for padding.
        self.vocab_size = hrg.num_prod_rule + 1
        self.batch_size = batch_size
        self.padding_idx = np.mod(padding_idx, self.vocab_size)
        self.no_dropout = no_dropout
        self.max_len = max_len
        self.encoder_params = encoder_params
        self.decoder_params = decoder_params

        # TODO Simple embedding is used. Check if a domain-dependent embedding works or not.
        embed_out_dim = self.prod_rule_embed_params['out_dim']
        #use MolecularProdRuleEmbedding later on
        self.tgt_embedding = nn.Embedding(self.vocab_size, embed_out_dim,
                                          padding_idx=self.padding_idx)

        self.encoder = GIN(**self.encoder_params)
        self.latent_dim = self.encoder_params['hidden_channels']
        self.proximity_size = self.encoder_params['proximity_size']
        hidden_dim = self.decoder_params['hidden_dim']
        # The GIN encoder concatenates (proximity_size + 1) pooled readouts,
        # hence the input width of the posterior heads.
        self.hidden2mean = nn.Linear(self.latent_dim * (1 + self.proximity_size), self.latent_dim, bias=False)
        self.hidden2logvar = nn.Linear(self.latent_dim * (1 + self.proximity_size), self.latent_dim)

        self.decoder = GRUDecoder(input_dim=embed_out_dim, batch_size=self.batch_size,
                                  rank=self.rank, no_dropout=self.no_dropout, **self.decoder_params)
        self.latent2tgt_emb = nn.Linear(self.latent_dim, embed_out_dim)
        self.latent2hidden_dict = nn.ModuleDict()
        for each_hidden in self.decoder.hidden_dict.keys():
            self.latent2hidden_dict[each_hidden] = nn.Linear(self.latent_dim, hidden_dim)
            if self.rank >= 0:
                if torch.cuda.is_available():
                    self.latent2hidden_dict[each_hidden] = self.latent2hidden_dict[each_hidden].to(self.rank)
                else:
                    # support mac mps
                    self.latent2hidden_dict[each_hidden] = self.latent2hidden_dict[each_hidden].to(torch.device("mps", self.rank))

        self.dec2vocab = nn.Linear(hidden_dim, self.vocab_size)

        # TODO Do we need this?
        if hasattr(self.tgt_embedding, 'weight'):
            self.tgt_embedding.weight.data.uniform_(-0.1, 0.1)
        # BUGFIX: the original initialized the decoder hidden state twice.
        self.decoder.init_hidden(self.batch_size)

    def to(self, device):
        newself = super().to(device)
        newself.encoder = newself.encoder.to(device)
        newself.decoder = newself.decoder.to(device)
        # get_device() returns -1 for CPU tensors, matching rank's convention.
        newself.rank = next(newself.encoder.parameters()).get_device()
        return newself

    def forward(self, x, edge_index, edge_attr, batch_size, out_seq=None, sched_prob=None):
        '''Encode a graph, sample a latent vector, and decode it.

        Returns (decode_output, mu, logvar); decode_output is the logits when
        out_seq is given, otherwise (gen_finish_list, vocab_logit, hg_list).
        '''
        mu, logvar = self.encode(x, edge_index, edge_attr, batch_size)
        z = self.reparameterize(mu, logvar)
        return self.decode(z, out_seq, sched_prob=sched_prob), mu, logvar

    #TODO Not tested. Need to implement this in case of molecular structure generation
    def sample(self, sample_size=-1, deterministic=True, return_z=False):
        '''Sample hypergraphs from the prior N(0, I).'''
        self.eval()
        if sample_size == -1:
            sample_size = self.batch_size
        device = next(self.parameters()).device

        num_iter = int(np.ceil(sample_size / self.batch_size))
        hg_list = []
        z_list = []
        for _ in range(num_iter):
            # BUGFIX: draw z with shape (batch, latent) on the model's device;
            # the original passed a flat (batch*latent,) std tensor to
            # torch.normal and unconditionally called .cuda().
            z = torch.randn(self.batch_size, self.latent_dim, device=device)
            # BUGFIX: decode returns (gen_finish_list, vocab_logit, hg_list);
            # the original unpacked it into two names.
            _, _, each_hg_list = self.decode(z, deterministic=deterministic)
            z_list.append(z)
            hg_list += each_hg_list
        z = torch.cat(z_list)[:sample_size]
        hg_list = hg_list[:sample_size]
        if return_z:
            return hg_list, z.cpu().detach().numpy()
        return hg_list

    def decode(self, z=None, out_seq=None, deterministic=True, sched_prob=None):
        '''Teacher-forced decoding when out_seq is given (with optional
        scheduled sampling controlled by sched_prob); otherwise free-running
        generation returning (gen_finish_list, vocab_logit, hg_list).'''
        device = next(self.parameters()).device
        if z is None:
            # BUGFIX: mean/std shapes must match; sample directly from N(0, I).
            z = torch.randn(self.batch_size, self.latent_dim)
        if self.rank >= 0:
            z = z.to(device)

        # Seed every decoder hidden state from the latent vector.
        hidden_dict_0 = {}
        for each_hidden in self.latent2hidden_dict.keys():
            hidden_dict_0[each_hidden] = self.latent2hidden_dict[each_hidden](z)
        bsize = z.size(0)
        self.decoder.init_hidden(bsize)
        self.decoder.feed_hidden(hidden_dict_0)

        if out_seq is not None:
            # Teacher forcing: shift the target embeddings right by one and
            # prepend a latent-derived start token.
            tgt_emb0 = self.latent2tgt_emb(z)
            tgt_emb0 = tgt_emb0.view(tgt_emb0.shape[0], 1, tgt_emb0.shape[1])
            out_seq_emb = self.tgt_embedding(out_seq)
            tgt_emb = torch.cat((tgt_emb0, out_seq_emb), dim=1)[:, :-1, :]
            tgt_emb_pred_list = []
            tgt_emb_pred = None
            for each_idx in range(self.max_len):
                # Scheduled sampling: with probability sched_prob feed the
                # ground-truth embedding, otherwise feed the embedding of the
                # previous argmax prediction.
                if tgt_emb_pred is None or sched_prob is None or torch.rand(1)[0] <= sched_prob:
                    inp = tgt_emb[:, each_idx, :].view(bsize, 1, -1)
                else:
                    cur_logit = self.dec2vocab(tgt_emb_pred)
                    yi = torch.argmax(cur_logit, dim=2)
                    inp = self.tgt_embedding(yi)
                tgt_emb_pred = self.decoder.forward_one_step(inp)
                tgt_emb_pred_list.append(tgt_emb_pred)
            return self.dec2vocab(torch.cat(tgt_emb_pred_list, dim=1))

        with torch.no_grad():
            tgt_emb = self.latent2tgt_emb(z)
            tgt_emb = tgt_emb.view(tgt_emb.shape[0], 1, tgt_emb.shape[1])
            tgt_emb_pred_list = []
            # Per-example generation state: stack of open non-terminals,
            # partially built hypergraph, current non-terminal, done flag.
            stack_list = [[] for _ in range(bsize)]
            hg_list = [None] * bsize
            nt_symbol_list = [NTSymbol(degree=0,
                                       is_aromatic=False,
                                       bond_symbol_list=[])
                              for _ in range(bsize)]
            nt_edge_list = [None] * bsize
            gen_finish_list = [False] * bsize

            for _ in range(self.max_len):
                tgt_emb_pred = self.decoder.forward_one_step(tgt_emb)
                tgt_emb_pred_list.append(tgt_emb_pred)
                vocab_logit = self.dec2vocab(tgt_emb_pred)
                for each_batch_idx in range(bsize):
                    if not gen_finish_list[each_batch_idx]:  # generation not finished
                        # sample a production rule compatible with the current
                        # non-terminal (greedy when deterministic=True);
                        # the padding slot (last logit) is masked off.
                        prod_rule = self.hrg.prod_rule_corpus.sample(
                            vocab_logit[each_batch_idx, :, :-1].squeeze().cpu().numpy(),
                            nt_symbol_list[each_batch_idx],
                            deterministic=deterministic)
                        tgt_id = self.hrg.prod_rule_list.index(prod_rule)
                        # apply the production rule
                        hg_list[each_batch_idx], nt_edges = prod_rule.applied_to(
                            hg_list[each_batch_idx], nt_edge_list[each_batch_idx])
                        # push new non-terminals (rightmost first)
                        stack_list[each_batch_idx].extend(nt_edges[::-1])
                        if len(stack_list[each_batch_idx]) == 0:
                            # empty stack: generation has finished
                            gen_finish_list[each_batch_idx] = True
                        else:
                            nt_edge_list[each_batch_idx] = stack_list[each_batch_idx].pop()
                            nt_symbol_list[each_batch_idx] = hg_list[each_batch_idx].edge_attr(
                                nt_edge_list[each_batch_idx])['symbol']
                    else:
                        tgt_id = np.mod(self.padding_idx, self.vocab_size)
                    indice_tensor = torch.LongTensor([tgt_id])
                    if self.rank >= 0:
                        indice_tensor = indice_tensor.to(device)
                    # feed the chosen rule back as the next decoder input
                    tgt_emb[each_batch_idx, :] = self.tgt_embedding(indice_tensor)
            vocab_logit = self.dec2vocab(torch.cat(tgt_emb_pred_list, dim=1))
            return gen_finish_list, vocab_logit, hg_list

    #TODO Not tested. Need to implement this in case of molecular structure generation
    def conditional_distribution(self, z, tgt_id_list):
        '''Replay the rule sequences in tgt_id_list and return, per example,
        the masked log-probabilities over the next rule (None if finished).'''
        self.eval()
        device = next(self.parameters()).device
        # BUGFIX: the original called nonexistent self.init_hidden() and
        # unconditionally z.cuda(); initialize the decoder and move z to the
        # model's device instead.
        z = z.to(device)
        self.decoder.init_hidden(z.size(0))

        hidden_dict_0 = {}
        for each_hidden in self.latent2hidden_dict.keys():
            hidden_dict_0[each_hidden] = self.latent2hidden_dict[each_hidden](z)
        self.decoder.feed_hidden(hidden_dict_0)

        with torch.no_grad():
            tgt_emb = self.latent2tgt_emb(z)
            tgt_emb = tgt_emb.view(tgt_emb.shape[0], 1, tgt_emb.shape[1])
            nt_symbol_list = [NTSymbol(degree=0,
                                       is_aromatic=False,
                                       bond_symbol_list=[])
                              for _ in range(self.batch_size)]
            stack_list = [[] for _ in range(self.batch_size)]
            hg_list = [None] * self.batch_size
            nt_edge_list = [None] * self.batch_size
            gen_finish_list = [False] * self.batch_size

            for each_position in range(len(tgt_id_list[0])):
                tgt_emb_pred = self.decoder.forward_one_step(tgt_emb)
                for each_batch_idx in range(self.batch_size):
                    if not gen_finish_list[each_batch_idx]:  # generation not finished
                        # use the prespecified target ids
                        tgt_id = tgt_id_list[each_batch_idx][each_position]
                        prod_rule = self.hrg.prod_rule_list[tgt_id]
                        # apply the production rule
                        hg_list[each_batch_idx], nt_edges = prod_rule.applied_to(
                            hg_list[each_batch_idx], nt_edge_list[each_batch_idx])
                        # push new non-terminals (rightmost first)
                        stack_list[each_batch_idx].extend(nt_edges[::-1])
                        if len(stack_list[each_batch_idx]) == 0:
                            gen_finish_list[each_batch_idx] = True
                        else:
                            nt_edge_list[each_batch_idx] = stack_list[each_batch_idx].pop()
                            nt_symbol_list[each_batch_idx] = hg_list[each_batch_idx].edge_attr(
                                nt_edge_list[each_batch_idx])['symbol']
                    else:
                        tgt_id = np.mod(self.padding_idx, self.vocab_size)
                    indice_tensor = torch.LongTensor([tgt_id]).to(device)
                    tgt_emb[each_batch_idx, :] = self.tgt_embedding(indice_tensor)

            # last one step
            conditional_logprob_list = []
            tgt_emb_pred = self.decoder.forward_one_step(tgt_emb)
            vocab_logit = self.dec2vocab(tgt_emb_pred)
            for each_batch_idx in range(self.batch_size):
                if not gen_finish_list[each_batch_idx]:
                    # log-probabilities restricted to rules whose LHS matches
                    # the current non-terminal; padding logit masked off.
                    masked_logprob = self.hrg.prod_rule_corpus.masked_logprob(
                        vocab_logit[each_batch_idx, :, :-1].squeeze().cpu().numpy(),
                        nt_symbol_list[each_batch_idx])
                    conditional_logprob_list.append(masked_logprob)
                else:
                    conditional_logprob_list.append(None)
        return conditional_logprob_list

    #TODO Not tested. Need to implement this in case of molecular structure generation
    def decode_with_beam_search(self, z, beam_width=1):
        ''' Decode a latent vector using beam search.

        Parameters
        ----------
        z
            latent vector
        beam_width : int
            parameter for beam search

        Returns
        -------
        List of Hypergraphs
        '''
        if self.batch_size != 1:
            raise ValueError('this method works only under batch_size=1')
        if self.padding_idx != -1:
            raise ValueError('this method works only under padding_idx=-1')
        top_k_tgt_id_list = [[]] * beam_width
        logprob_list = [0.] * beam_width

        for each_len in range(self.max_len):
            # Expand every beam by every vocabulary entry; the last slot per
            # beam is reserved for "stay finished" (padding).
            expanded_logprob_list = np.repeat(logprob_list, self.vocab_size)
            expanded_length_list = np.array([0] * (beam_width * self.vocab_size))
            for each_beam_idx, each_candidate in enumerate(top_k_tgt_id_list):
                conditional_logprob = self.conditional_distribution(z, [each_candidate])[0]
                if conditional_logprob is None:
                    # finished candidate: only the padding continuation is allowed
                    expanded_logprob_list[(each_beam_idx + 1) * self.vocab_size - 1]\
                        = logprob_list[each_beam_idx]
                    expanded_logprob_list[each_beam_idx * self.vocab_size : (each_beam_idx + 1) * self.vocab_size - 1]\
                        = -np.inf
                    expanded_length_list[each_beam_idx * self.vocab_size : (each_beam_idx + 1) * self.vocab_size]\
                        = len(each_candidate)
                else:
                    expanded_logprob_list[each_beam_idx * self.vocab_size : (each_beam_idx + 1) * self.vocab_size - 1]\
                        = logprob_list[each_beam_idx] + conditional_logprob
                    expanded_logprob_list[(each_beam_idx + 1) * self.vocab_size - 1]\
                        = -np.inf
                    expanded_length_list[each_beam_idx * self.vocab_size : (each_beam_idx + 1) * self.vocab_size]\
                        = len(each_candidate) + 1
            # Length-normalized scores; at step 0 all beams are identical, so
            # only the first beam's expansions are ranked.
            score_list = np.array(expanded_logprob_list) / np.array(expanded_length_list)
            if each_len == 0:
                top_k_list = np.argsort(score_list[:self.vocab_size])[::-1][:beam_width]
            else:
                top_k_list = np.argsort(score_list)[::-1][:beam_width]
            next_top_k_tgt_id_list = []
            next_logprob_list = []
            for each_top_k in top_k_list:
                beam_idx = each_top_k // self.vocab_size
                vocab_idx = each_top_k % self.vocab_size
                if vocab_idx == self.vocab_size - 1:
                    # padding continuation: candidate carried over unchanged
                    next_top_k_tgt_id_list.append(top_k_tgt_id_list[beam_idx])
                else:
                    next_top_k_tgt_id_list.append(top_k_tgt_id_list[beam_idx] + [vocab_idx])
                next_logprob_list.append(expanded_logprob_list[each_top_k])
            top_k_tgt_id_list = next_top_k_tgt_id_list
            logprob_list = next_logprob_list

        # construct hypergraphs from the surviving rule sequences
        hg_list = []
        for each_tgt_id_list in top_k_tgt_id_list:
            hg = None
            stack = []
            nt_edge = None
            for each_idx, each_prod_rule_id in enumerate(each_tgt_id_list):
                prod_rule = self.hrg.prod_rule_list[each_prod_rule_id]
                hg, nt_edges = prod_rule.applied_to(hg, nt_edge)
                stack.extend(nt_edges[::-1])
                try:
                    nt_edge = stack.pop()
                except IndexError:
                    if each_idx == len(each_tgt_id_list) - 1:
                        break
                    raise ValueError('some bugs')
            hg_list.append(hg)
        return hg_list

    def graph_embed(self, x, edge_index, edge_attr, batch_size):
        '''Run the GIN encoder and return the concatenated graph readout.'''
        return self.encoder.forward(x, edge_index, edge_attr, batch_size)

    def encode(self, x, edge_index, edge_attr, batch_size):
        '''Return (mu, logvar) of the approximate posterior for a graph batch.'''
        src_h = self.graph_embed(x, edge_index, edge_attr, batch_size)
        return self.get_mean_var(src_h)

    def get_mean_var(self, src_h):
        '''Project the graph embedding to a tanh-squashed (mu, logvar) pair.'''
        mu = torch.tanh(self.hidden2mean(src_h))
        lv = torch.tanh(self.hidden2logvar(src_h))
        return mu, lv

    def reparameterize(self, mu, logvar, training=True):
        '''Sample z ~ N(mu, exp(logvar)) with the reparameterization trick.'''
        if not training:
            return mu
        std = logvar.mul(0.5).exp_()
        # randn_like creates eps on std's device/dtype directly.
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu)
+
# Copied from the MHG implementation and adapted
class GrammarVAELoss(_Loss):

    '''
    a loss function for Grammar VAE

    The reconstruction term masks, per position, the logits of production
    rules whose left-hand side does not match the target rule's LHS, then
    takes a summed cross-entropy; the KL term regularizes the posterior
    toward N(0, I).

    Attributes
    ----------
    hrg : HyperedgeReplacementGrammar
    beta : float
        coefficient of KL divergence
    '''

    def __init__(self, rank, hrg, beta=1.0, **kwargs):
        super().__init__(**kwargs)
        self.hrg = hrg
        self.beta = beta
        self.rank = rank

    def forward(self, mu, logvar, in_seq_pred, in_seq):
        ''' compute VAE loss

        Parameters
        ----------
        in_seq_pred : torch.Tensor, shape (batch_size, max_len, vocab_size)
            logit
        in_seq : torch.Tensor, shape (batch_size, max_len)
            each element corresponds to a word id in vocabulary.
        mu : torch.Tensor, shape (batch_size, hidden_dim)
        logvar : torch.Tensor, shape (batch_size, hidden_dim)
            mean and log variance of the normal distribution
        '''
        batch_size, max_len, vocab_size = in_seq_pred.shape
        mask = torch.zeros(in_seq_pred.shape)

        for each_batch in range(batch_size):
            for each_idx in range(max_len):
                prod_rule_idx = in_seq[each_batch, each_idx]
                if prod_rule_idx == vocab_size - 1:
                    # padding position: keep only the padding logit
                    mask[each_batch, each_idx, prod_rule_idx] = 1
                    continue
                # keep only rules whose LHS matches the target rule's LHS
                lhs = self.hrg.prod_rule_corpus.prod_rule_list[prod_rule_idx].lhs_nt_symbol
                lhs_idx = self.hrg.prod_rule_corpus.nt_symbol_list.index(lhs)
                mask[each_batch, each_idx, :-1] = torch.FloatTensor(
                    self.hrg.prod_rule_corpus.lhs_in_prod_rule[lhs_idx])
        # BUGFIX: _Loss registers no parameters, so the original
        # next(self.parameters()) raised StopIteration whenever rank >= 0.
        # Move the mask to the predictions' device instead (no-op on CPU).
        mask = mask.to(in_seq_pred.device)
        # NOTE(review): zeroing logits is not the same as masking them to
        # -inf before softmax; masked rules still receive probability mass.
        in_seq_pred = mask * in_seq_pred

        cross_entropy = F.cross_entropy(
            in_seq_pred.view(-1, vocab_size),
            in_seq.view(-1),
            reduction='sum',
        )
        kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return cross_entropy + self.beta * kl_div
+
+
class VAELoss(_Loss):
    '''Plain VAE objective: summed reconstruction cross-entropy plus a
    beta-weighted KL divergence of the posterior from N(0, I).'''

    def __init__(self, beta=0.01):
        super().__init__()
        # beta: weight of the KL term relative to the reconstruction term.
        self.beta = beta

    def forward(self, mean, log_var, dec_outputs, targets):
        '''Return reconstruction + beta * KL for one batch.

        mean, log_var : posterior parameters, shape (batch, latent).
        dec_outputs : decoder logits, shape (batch, length, vocab).
        targets : target token ids, shape (batch, length).
        '''
        # Keep targets on the same device as the model outputs
        # (get_device() is -1 for CPU tensors).
        if mean.get_device() >= 0:
            targets = targets.to(mean.get_device())
        flat_logits = dec_outputs.view(-1, dec_outputs.size(2))
        reconstruction = F.cross_entropy(flat_logits, targets.view(-1),
                                         reduction='sum')
        # Negative KL in closed form; subtracting it adds the divergence.
        neg_kl = 0.5 * torch.sum(1 + log_var - mean ** 2 - torch.exp(log_var))
        return reconstruction - self.beta * neg_kl
diff --git a/notebooks/mhg-gnn_encoder_decoder_example.ipynb b/notebooks/mhg-gnn_encoder_decoder_example.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..84ff55ec04adaa87452caec876def45757592879
--- /dev/null
+++ b/notebooks/mhg-gnn_encoder_decoder_example.ipynb
@@ -0,0 +1,114 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "829ddc03",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import sys\n",
+ "sys.path.append('..')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "ea820e23",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "import load"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "b9a51fa8",
+ "metadata": {},
+ "source": [
+ "# Load MHG-GNN"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "c6ea1fc8",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model_ckp = \"models/model_checkpoints/mhg_model/pickles/mhggnn_pretrained_model_radius7_1116_2023.pickle\"\n",
+ "\n",
+ "model = load.load(model_name = model_ckp)\n",
+ "if model is None:\n",
+ " print(\"Model not loaded, please check you have MHG pickle file\")\n",
+ "else:\n",
+ " print(\"MHG model loaded\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "b4a0b557",
+ "metadata": {},
+ "source": [
+ "# Embeddings\n",
+ "\n",
+ "※ replace the smiles exaple list with your dataset"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "c63a6be6",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "with torch.no_grad():\n",
+ " repr = model.encode([\"CCO\", \"O=C=O\", \"OC(=O)c1ccccc1C(=O)O\"])\n",
+ " \n",
+ "# Print the latent vectors\n",
+ "print(repr)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "a59f9442",
+ "metadata": {},
+ "source": [
+ "# Decoding"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "6a0d8a41",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "orig = model.decode(repr)\n",
+ "print(orig)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.10"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/paper/MHG-GNN_Combination of Molecular Hypergraph Grammar with Graph Neural Network.pdf b/paper/MHG-GNN_Combination of Molecular Hypergraph Grammar with Graph Neural Network.pdf
new file mode 100644
index 0000000000000000000000000000000000000000..a7dcc1270d1f444f77366013ad2d3d93ebb426ab
Binary files /dev/null and b/paper/MHG-GNN_Combination of Molecular Hypergraph Grammar with Graph Neural Network.pdf differ
diff --git a/pickles/.DS_Store b/pickles/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..5008ddfcf53c02e82d7eee2e57c38e5672ef89f6
Binary files /dev/null and b/pickles/.DS_Store differ
diff --git a/pickles/mhggnn_pretrained_model_0724_2023.pickle b/pickles/mhggnn_pretrained_model_0724_2023.pickle
new file mode 100644
index 0000000000000000000000000000000000000000..24d5ec31658ddc4d59876cc69996c399727a4e7e
--- /dev/null
+++ b/pickles/mhggnn_pretrained_model_0724_2023.pickle
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0e5fd6991c7f3a1791c0ba5b2e3805cb6f8f3f734cfdeda145d99cceed4a8533
+size 261888810
diff --git a/setup.cfg b/setup.cfg
new file mode 100644
index 0000000000000000000000000000000000000000..12d7e158b8963bb101610dcb0a81fd3cc04eaae0
--- /dev/null
+++ b/setup.cfg
@@ -0,0 +1,37 @@
+[metadata]
+name = mhg-gnn
+version = attr: .__version__
+description = Package for mhg-gnn
+author= team
+long_description_content_type=text/markdown
+long_description = file: README.md
+python_requires = >= 3.9.7
+license = TBD
+
+classifiers =
+ Programming Language :: Python :: 3
+ Programming Language :: Python :: 3.9
+
+[options]
+install_requires =
+ networkx>=2.8
+ numpy>=1.23.5, <2.0.0
+ pandas>=1.5.3
+ rdkit-pypi>=2022.9.4, <2023.9.6
+ torch>=2.0.0
+ torchinfo>=1.8.0
+ torch-geometric>=2.3.1
+ requests>=2.32.2
+ scikit-learn>=1.5.0
+ urllib3>=2.2.2
+
+
+setup_requires =
+ setuptools
+package_dir =
+ = .
+packages=find:
+include_package_data = True
+
+[options.packages.find]
+where = .
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..45f160da695819ef6906f6dd332e8398cf419b8e
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,6 @@
#!/usr/bin/env python

# Thin packaging shim: all project metadata and options live in setup.cfg,
# which setuptools reads automatically when setup() is called with no args.
import setuptools

if __name__ == "__main__":
    setuptools.setup()
\ No newline at end of file